From 2c3d1e4884eee109ca72286244eef4b357d586b8 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Sun, 21 Apr 2024 18:49:54 +0200 Subject: Flesh out shaped sized lists --- src/Data/Array/Mixed.hs | 6 +++-- src/Data/Array/Nested.hs | 2 +- src/Data/Array/Nested/Internal.hs | 56 +++++++++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index c19fbe5..d2765b6 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -44,6 +44,8 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl +-- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype +-- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)? type IxX :: Type -> [Maybe Nat] -> Type data IxX i sh where @@ -317,7 +319,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) unXArray (XArray a) = a -- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a transpose perm (XArray arr) | Dict <- lemKnownINatRankSSX (knownShapeX @sh) , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -360,7 +362,7 @@ fromList ssh l = case ssh of m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (natVal m) ++ ")" + "does not match the type (" ++ show (natVal m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) toList :: Storable a => XArray (n : sh) a -> [XArray sh a] diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index f383b99..9222210 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -13,7 +13,7 @@ module Data.Array.Nested ( -- * Shaped arrays Shaped, - IxS(..), IIxS, + IxS, pattern (:.$), pattern ZIS, IIxS, KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, sunScalar, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index e42de12..9cabdc6 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -1043,6 +1043,14 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) +type ListS :: Type -> [Nat] -> Type +data ListS i n where + ZS :: ListS i '[] + (::$) :: forall n sh i. i -> ListS i sh -> ListS i (n : sh) +deriving instance Show i => Show (ListS i n) +deriving instance Eq i => Eq (ListS i n) +infixr 3 ::$ + -- | An index into a shape-typed array. -- -- For convenience, this contains regular 'Int's instead of bounded integers @@ -1050,15 +1058,53 @@ deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped -- shape-typed array is known statically, you can also retrieve the array shape -- from a 'KnownShape' dictionary. type IxS :: Type -> [Nat] -> Type -data IxS i sh where - ZIS :: IxS i '[] - (:.$) :: forall n sh i. i -> IxS i sh -> IxS i (n : sh) -deriving instance Show i => Show (IxS i n) -deriving instance Eq i => Eq (IxS i n) +newtype IxS i sh = IxS (ListS i sh) + deriving (Show, Eq) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS i sh +pattern ZIS = IxS ZS + +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxS i sh -> IxS i sh1 +pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i)) + where i :.$ (IxS shl) = IxS (i ::$ shl) +{-# COMPLETE ZIS, (:.$) #-} infixr 3 :.$ +data UnconsIxSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS i sh) i +unconsIxS :: IxS i sh1 -> Maybe (UnconsIxSRes i sh1) +unconsIxS (IxS shl) = case shl of + i ::$ shl' -> Just (UnconsIxSRes (IxS shl') i) + ZS -> Nothing + type IIxS = IxS Int +type StaticShapeS :: Type -> [Nat] -> Type +newtype StaticShapeS i sh = StaticShapeS (ListS i sh) + deriving (Show, Eq) + +pattern ZSS :: forall sh i. () => sh ~ '[] => StaticShapeS i sh +pattern ZSS = StaticShapeS ZS + +pattern (:$$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> StaticShapeS i sh -> StaticShapeS i sh1 +pattern i :$$ shl <- (unconsStaticShapeS -> Just (UnconsStaticShapeSRes shl i)) + where i :$$ (StaticShapeS shl) = StaticShapeS (i ::$ shl) +{-# COMPLETE ZSS, (:$$) #-} +infixr 3 :$$ + +data UnconsStaticShapeSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsStaticShapeSRes (StaticShapeS i sh) i +unconsStaticShapeS :: StaticShapeS i sh1 -> Maybe (UnconsStaticShapeSRes i sh1) +unconsStaticShapeS (StaticShapeS shl) = case shl of + i ::$ shl' -> Just (UnconsStaticShapeSRes (StaticShapeS shl') i) + ZS -> Nothing + zeroIxS :: SShape sh -> IIxS sh zeroIxS ShNil = ZIS zeroIxS (ShCons _ sh) = 0 :.$ zeroIxS sh -- cgit v1.2.3-70-g09d2