aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Shape.hs')
-rw-r--r--src/Data/Array/Mixed/Shape.hs41
1 files changed, 13 insertions, 28 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index e5f8b67..bed812d 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -18,8 +18,6 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Shape where
import Control.DeepSeq (NFData(..))
@@ -35,16 +33,16 @@ import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
-import GHC.TypeLits
import Data.Array.Mixed.Types
+import Data.SNat.Peano
-- | The length of a type-level list. If the argument is a shape, then the
-- result is the rank of that shape.
type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = Rank sh + 1
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
-- * Mixed lists
@@ -91,8 +89,8 @@ listxLength :: ListX sh f -> Int
listxLength = getSum . listxFold (\_ -> Sum 1)
listxRank :: ListX sh f -> SNat (Rank sh)
-listxRank ZX = SNat
-listxRank (_ ::% l) | SNat <- listxRank l = SNat
+listxRank ZX = SZ
+listxRank (_ ::% l) = SS (listxRank l)
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
@@ -255,7 +253,7 @@ type family AddMaybe n m where
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatAdd n m)
-- | This is a newtype over 'ListX'.
@@ -288,7 +286,7 @@ instance Functor (ShX sh) where
instance NFData i => NFData (ShX sh i) where
rnf (ShX ZX) = ()
rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+ rnf (ShX (SKnown n ::% l)) = rnf n `seq` rnf (ShX l)
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
@@ -300,8 +298,8 @@ shxRank (ShX list) = listxRank list
-- dimensions) are the same.
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
-shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
+shxEqual (SKnown n :$% sh) (SKnown m :$% sh')
+ | Just Refl <- testEquality n m
, Just Refl <- shxEqual sh sh'
= Just Refl
shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
@@ -422,19 +420,6 @@ instance TestEquality StaticShX where
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l
--- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
--- class of the @some@ package.
-ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-ssxGeq ZKX ZKX = Just Refl
-ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
- | Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq _ _ = Nothing
-
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
@@ -481,7 +466,7 @@ ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
instance KnownShX '[] where knownShX = ZKX
-instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown knownNat :!% knownShX
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
@@ -490,7 +475,7 @@ withKnownShX sh = withDict @(KnownShX sh) sh
-- * Flattening
-type Flatten sh = Flatten' 1 sh
+type Flatten sh = Flatten' (S Z) sh
type family Flatten' acc sh where
Flatten' acc '[] = Just acc
@@ -499,7 +484,7 @@ type family Flatten' acc sh where
-- This function is currently unused
ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
-ssxFlatten = go (SNat @1)
+ssxFlatten = go (mkSNat @1)
where
go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
go acc ZKX = SKnown acc
@@ -507,7 +492,7 @@ ssxFlatten = go (SNat @1)
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
-shxFlatten = go (SNat @1)
+shxFlatten = go (mkSNat @1)
where
go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
go acc ZSX = SKnown acc