aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-13 20:08:54 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-13 20:08:54 +0200
commit7c9865354442326d55094087ad6a74b6e96341fb (patch)
tree8f4350d91eb5075f7fe7ac1bedc37e8b040e4ab4 /src/Data/Array
parent4808710e311a69326c6fdef9fc1b9b2173fd009e (diff)
Replace SShape with ShS
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal.hs89
2 files changed, 37 insertions, 55 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 389a8f5..370e30a 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -17,8 +17,7 @@ module Data.Array.Nested (
Shaped,
ListS, pattern (::$), pattern ZS,
IxS(..), pattern (:.$), pattern ZIS, IIxS,
- ShS(..), pattern (:$$), pattern ZSS,
- KnownShape(..), SShape(..),
+ ShS(..), KnownShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, sunScalar,
sconstant, sfromList, sfromList1, stoList, stoList1,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 18f458e..49ed7cb 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -708,35 +708,35 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
-- | The shape of a shape-typed array given as a list of 'SNat' values.
-data SShape sh where
- ShNil :: SShape '[]
- ShCons :: SNat n -> SShape sh -> SShape (n : sh)
-deriving instance Show (SShape sh)
-infixr 5 `ShCons`
+data ShS sh where
+ ZSS :: ShS '[]
+ (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)
+deriving instance Show (ShS sh)
+infixr 3 :$$
-- | A statically-known shape of a shape-typed array.
-class KnownShape sh where knownShape :: SShape sh
-instance KnownShape '[] where knownShape = ShNil
-instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons natSing knownShape
+class KnownShape sh where knownShape :: ShS sh
+instance KnownShape '[] where knownShape = ZSS
+instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape
-sshapeKnown :: SShape sh -> Dict KnownShape sh
-sshapeKnown ShNil = Dict
-sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict
+sshapeKnown :: ShS sh -> Dict KnownShape sh
+sshapeKnown ZSS = Dict
+sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
where
- go :: SShape sh' -> StaticShapeX (MapJust sh')
- go ShNil = ZSX
- go (ShCons n sh) = n :$@ go sh
+ go :: ShS sh' -> StaticShapeX (MapJust sh')
+ go ZSS = ZSX
+ go (n :$$ sh) = n :$@ go sh
lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
lemMapJustPlusApp _ _ = go (knownShape @sh1)
where
- go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ShNil = Refl
- go (ShCons _ sh) | Refl <- go sh = Refl
+ go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
+ go ZSS = Refl
+ go (_ :$$ sh) | Refl <- go sh = Refl
instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
@@ -1131,41 +1131,24 @@ unconsIxS (IxS ZS) = Nothing
type IIxS sh = IxS sh Int
-type role ShS nominal representational
-type ShS :: [Nat] -> Type -> Type
-newtype ShS sh i = ShS (ListS sh i)
- deriving (Show, Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZSS :: forall sh i. () => sh ~ '[] => ShS sh i
-pattern ZSS = ShS ZS
-
-pattern (:$$)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> ShS sh i -> ShS sh1 i
-pattern i :$$ shl <- (unconsShS -> Just (UnconsShSRes shl i))
- where i :$$ ShS shl = ShS (i ::$ shl)
-{-# COMPLETE ZSS, (:$$) #-}
-infixr 3 :$$
-
-data UnconsShSRes i sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh i) i
-unconsShS :: ShS sh1 i -> Maybe (UnconsShSRes i sh1)
-unconsShS (ShS (i ::$ shl')) = Just (UnconsShSRes (ShS shl') i)
-unconsShS (ShS ZS) = Nothing
+data UnconsShSRes sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n)
+unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1)
+unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i)
+unconsShS ZSS = Nothing
-zeroIxS :: SShape sh -> IIxS sh
-zeroIxS ShNil = ZIS
-zeroIxS (ShCons _ sh) = 0 :.$ zeroIxS sh
+zeroIxS :: ShS sh -> IIxS sh
+zeroIxS ZSS = ZIS
+zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh
-cvtSShapeIxS :: SShape sh -> IIxS sh
-cvtSShapeIxS ShNil = ZIS
-cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) :.$ cvtSShapeIxS sh
+-- TODO: this function should not exist perhaps
+cvtShSIxS :: ShS sh -> IIxS sh
+cvtShSIxS ZSS = ZIS
+cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh
-ixCvtXS :: SShape sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ShNil ZIX = ZIS
-ixCvtXS (ShCons _ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx
+ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
+ixCvtXS ZSS ZIX = ZIS
+ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
@@ -1178,7 +1161,7 @@ shapeSizeS (n :.$ sh) = n * shapeSizeS sh
-- | This does not touch the passed array, all information comes from 'KnownShape'.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh
-sshape _ = cvtSShapeIxS (knownShape @sh)
+sshape _ = cvtShSIxS (knownShape @sh)
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
@@ -1194,7 +1177,7 @@ sindexPartial (Shaped arr) idx =
sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
+ = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
@@ -1230,7 +1213,7 @@ sscalar x = Shaped (mscalar x)
sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
sfromVector v
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v)
+ = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v)
sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
=> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
@@ -1253,7 +1236,7 @@ sunScalar arr = sindex arr ZIS
sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
sconstantP x
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mconstantP (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)
+ = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x)
sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
=> a -> Shaped sh a