diff options
Diffstat (limited to 'src/Data/Array/Mixed/Shape.hs')
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 41 |
1 files changed, 13 insertions, 28 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index e5f8b67..bed812d 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -18,8 +18,6 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed.Shape where import Control.DeepSeq (NFData(..)) @@ -35,16 +33,16 @@ import GHC.Exts (withDict) import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList -import GHC.TypeLits import Data.Array.Mixed.Types +import Data.SNat.Peano -- | The length of a type-level list. If the argument is a shape, then the -- result is the rank of that shape. type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) -- * Mixed lists @@ -91,8 +89,8 @@ listxLength :: ListX sh f -> Int listxLength = getSum . listxFold (\_ -> Sum 1) listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat +listxRank ZX = SZ +listxRank (_ ::% l) = SS (listxRank l) listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS listxShow f l = showString "[" . go "" l . showString "]" @@ -255,7 +253,7 @@ type family AddMaybe n m where smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatAdd n m) -- | This is a newtype over 'ListX'. @@ -288,7 +286,7 @@ instance Functor (ShX sh) where instance NFData i => NFData (ShX sh i) where rnf (ShX ZX) = () rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + rnf (ShX (SKnown n ::% l)) = rnf n `seq` rnf (ShX l) shxLength :: ShX sh i -> Int shxLength (ShX l) = listxLength l @@ -300,8 +298,8 @@ shxRank (ShX list) = listxRank list -- dimensions) are the same. shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m +shxEqual (SKnown n :$% sh) (SKnown m :$% sh') + | Just Refl <- testEquality n m , Just Refl <- shxEqual sh sh' = Just Refl shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') @@ -422,19 +420,6 @@ instance TestEquality StaticShX where ssxLength :: StaticShX sh -> Int ssxLength (StaticShX l) = listxLength l --- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ --- class of the @some@ package. -ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxGeq ZKX ZKX = Just Refl -ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') - | Just Refl <- sameNat n m - , Just Refl <- ssxGeq sh sh' - = Just Refl -ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') - | Just Refl <- ssxGeq sh sh' - = Just Refl -ssxGeq _ _ = Nothing - ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' @@ -481,7 +466,7 @@ ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 type KnownShX :: [Maybe Nat] -> Constraint class KnownShX sh where knownShX :: StaticShX sh instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown knownNat :!% knownShX instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r @@ -490,7 +475,7 @@ withKnownShX sh = withDict @(KnownShX sh) sh -- * Flattening -type Flatten sh = Flatten' 1 sh +type Flatten sh = Flatten' (S Z) sh type family Flatten' acc sh where Flatten' acc '[] = Just acc @@ -499,7 +484,7 @@ type family Flatten' acc sh where -- This function is currently unused ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) -ssxFlatten = go (SNat @1) +ssxFlatten = go (mkSNat @1) where go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) go acc ZKX = SKnown acc @@ -507,7 +492,7 @@ ssxFlatten = go (SNat @1) go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) -shxFlatten = go (SNat @1) +shxFlatten = go (mkSNat @1) where go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) go acc ZSX = SKnown acc |