From 7c9865354442326d55094087ad6a74b6e96341fb Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 13 May 2024 20:08:54 +0200 Subject: Replace SShape with ShS --- src/Data/Array/Nested.hs | 3 +- src/Data/Array/Nested/Internal.hs | 89 ++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 55 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 389a8f5..370e30a 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -17,8 +17,7 @@ module Data.Array.Nested ( Shaped, ListS, pattern (::$), pattern ZS, IxS(..), pattern (:.$), pattern ZIS, IIxS, - ShS(..), pattern (:$$), pattern ZSS, - KnownShape(..), SShape(..), + ShS(..), KnownShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, sunScalar, sconstant, sfromList, sfromList1, stoList, stoList1, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 18f458e..49ed7cb 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -708,35 +708,35 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where -- | The shape of a shape-typed array given as a list of 'SNat' values. -data SShape sh where - ShNil :: SShape '[] - ShCons :: SNat n -> SShape sh -> SShape (n : sh) -deriving instance Show (SShape sh) -infixr 5 `ShCons` +data ShS sh where + ZSS :: ShS '[] + (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) +deriving instance Show (ShS sh) +infixr 3 :$$ -- | A statically-known shape of a shape-typed array. -class KnownShape sh where knownShape :: SShape sh -instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons natSing knownShape +class KnownShape sh where knownShape :: ShS sh +instance KnownShape '[] where knownShape = ZSS +instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape -sshapeKnown :: SShape sh -> Dict KnownShape sh -sshapeKnown ShNil = Dict -sshapeKnown (ShCons GHC_SNat sh) | Dict <- sshapeKnown sh = Dict +sshapeKnown :: ShS sh -> Dict KnownShape sh +sshapeKnown ZSS = Dict +sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) where - go :: SShape sh' -> StaticShapeX (MapJust sh') - go ShNil = ZSX - go (ShCons n sh) = n :$@ go sh + go :: ShS sh' -> StaticShapeX (MapJust sh') + go ZSS = ZSX + go (n :$$ sh) = n :$@ go sh lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 lemMapJustPlusApp _ _ = go (knownShape @sh1) where - go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ShNil = Refl - go (ShCons _ sh) | Refl <- go sh = Refl + go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 + go ZSS = Refl + go (_ :$$ sh) | Refl <- go sh = Refl instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr @@ -1131,41 +1131,24 @@ unconsIxS (IxS ZS) = Nothing type IIxS sh = IxS sh Int -type role ShS nominal representational -type ShS :: [Nat] -> Type -> Type -newtype ShS sh i = ShS (ListS sh i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZSS :: forall sh i. () => sh ~ '[] => ShS sh i -pattern ZSS = ShS ZS - -pattern (:$$) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> ShS sh i -> ShS sh1 i -pattern i :$$ shl <- (unconsShS -> Just (UnconsShSRes shl i)) - where i :$$ ShS shl = ShS (i ::$ shl) -{-# COMPLETE ZSS, (:$$) #-} -infixr 3 :$$ - -data UnconsShSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh i) i -unconsShS :: ShS sh1 i -> Maybe (UnconsShSRes i sh1) -unconsShS (ShS (i ::$ shl')) = Just (UnconsShSRes (ShS shl') i) -unconsShS (ShS ZS) = Nothing +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n) +unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1) +unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i) +unconsShS ZSS = Nothing -zeroIxS :: SShape sh -> IIxS sh -zeroIxS ShNil = ZIS -zeroIxS (ShCons _ sh) = 0 :.$ zeroIxS sh +zeroIxS :: ShS sh -> IIxS sh +zeroIxS ZSS = ZIS +zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh -cvtSShapeIxS :: SShape sh -> IIxS sh -cvtSShapeIxS ShNil = ZIS -cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) :.$ cvtSShapeIxS sh +-- TODO: this function should not exist perhaps +cvtShSIxS :: ShS sh -> IIxS sh +cvtShSIxS ZSS = ZIS +cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh -ixCvtXS :: SShape sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ShNil ZIX = ZIS -ixCvtXS (ShCons _ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh +ixCvtXS ZSS ZIX = ZIS +ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx ixCvtSX :: IIxS sh -> IIxX (MapJust sh) ixCvtSX ZIS = ZIX @@ -1178,7 +1161,7 @@ shapeSizeS (n :.$ sh) = n * shapeSizeS sh -- | This does not touch the passed array, all information comes from 'KnownShape'. sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh -sshape _ = cvtSShapeIxS (knownShape @sh) +sshape _ = cvtShSIxS (knownShape @sh) sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) @@ -1194,7 +1177,7 @@ sindexPartial (Shaped arr) idx = sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) + = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) @@ -1230,7 +1213,7 @@ sscalar x = Shaped (mscalar x) sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) sfromVector v | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) + = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v) sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1253,7 +1236,7 @@ sunScalar arr = sindex arr ZIS sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) sconstantP x | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) + = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x) sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) => a -> Shaped sh a -- cgit v1.2.3-70-g09d2