aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Fancy.hs34
1 files changed, 32 insertions, 2 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs
index 821073e..41272f0 100644
--- a/src/Fancy.hs
+++ b/src/Fancy.hs
@@ -1,14 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE FlexibleInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Fancy where
import Control.Monad (forM_)
@@ -16,9 +18,11 @@ import Control.Monad.ST
import Data.Kind
import Data.Proxy
import Data.Type.Equality
+import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
@@ -32,6 +36,22 @@ type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
+lemCompareFalse1 = error "Incoherence"
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
+ where
+ go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing)
+ go SNat = case cmpNat (Proxy @1) (Proxy @m) of
+ LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1))
+ EQI -> () :$? SZX
+ GTI -> case cmpNat (Proxy @0) (Proxy @m) of
+ LTI -> lemCompareFalse1 (Proxy @m)
+ EQI -> SZX
+ GTI -> error "0 > natural"
+ go _ = error "COMPLETE"
+
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
@@ -223,6 +243,16 @@ newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
+
+
+instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
+ mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr
+
type IxR :: Nat -> Type
data IxR n where