diff options
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r-- | src/Fancy.hs | 37 |
1 files changed, 13 insertions, 24 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs index 41272f0..8019393 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PolyKinds #-} @@ -8,9 +7,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} --- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Fancy where import Control.Monad (forM_) @@ -21,16 +17,15 @@ import Data.Type.Equality import Data.Type.Ord import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) import qualified Array as X +import Nats type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a + Replicate Z a = '[] + Replicate (S n) a = a : Replicate n a type family MapJust l where MapJust '[] = '[] @@ -39,18 +34,12 @@ type family MapJust l where lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a lemCompareFalse1 = error "Incoherence" -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) where - go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing) - go SNat = case cmpNat (Proxy @1) (Proxy @m) of - LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) -> () :$? go (SNat @(m - 1)) - EQI -> () :$? SZX - GTI -> case cmpNat (Proxy @0) (Proxy @m) of - LTI -> lemCompareFalse1 (Proxy @m) - EQI -> SZX - GTI -> error "0 > natural" - go _ = error "COMPLETE" + go :: SNat m -> StaticShapeX (Replicate m Nothing) + go SZ = SZX + go (SS n) = () :$? go n type Mixed :: [Maybe Nat] -> Type -> Type @@ -177,7 +166,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -- TODO: this is quadratic in the nesting level mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh mshape (M_Nest arr) - | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') + | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = ixAppPrefix (knownShapeX @sh) (mshape arr) where ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 @@ -213,7 +202,7 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs @@ -251,13 +240,13 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where - mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr + mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr type IxR :: Nat -> Type data IxR n where - IZR :: IxR 0 - (:::) :: Int -> IxR n -> IxR (n + 1) + IZR :: IxR Z + (:::) :: Int -> IxR n -> IxR (S n) type IxS :: [Nat] -> Type data IxS sh where |