aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-21 18:49:54 +0200
committerMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-21 18:54:03 +0200
commit2c3d1e4884eee109ca72286244eef4b357d586b8 (patch)
tree194427a565ecffb5101f0de2f4a9037e3097f747
parentb3c92786635568e652b98095c3d0db5b4ec312b2 (diff)
Flesh out shaped sized lists
-rw-r--r--src/Data/Array/Mixed.hs6
-rw-r--r--src/Data/Array/Nested.hs2
-rw-r--r--src/Data/Array/Nested/Internal.hs56
3 files changed, 56 insertions, 8 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index c19fbe5..d2765b6 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -44,6 +44,8 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
+-- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype
+-- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)?
type IxX :: Type -> [Maybe Nat] -> Type
data IxX i sh where
@@ -317,7 +319,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
unXArray (XArray a) = a
-- | The list argument gives indices into the original dimension list.
-transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
| Dict <- lemKnownINatRankSSX (knownShapeX @sh)
, Dict <- knownNatFromINat (Proxy @(Rank sh))
@@ -360,7 +362,7 @@ fromList ssh l
= case ssh of
m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (natVal m) ++ ")"
+ "does not match the type (" ++ show (natVal m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
toList :: Storable a => XArray (n : sh) a -> [XArray sh a]
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index f383b99..9222210 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -13,7 +13,7 @@ module Data.Array.Nested (
-- * Shaped arrays
Shaped,
- IxS(..), IIxS,
+ IxS, pattern (:.$), pattern ZIS, IIxS,
KnownShape(..), SShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, sunScalar,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index e42de12..9cabdc6 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1043,6 +1043,14 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe
deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int)
deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double)
+type ListS :: Type -> [Nat] -> Type
+data ListS i n where
+ ZS :: ListS i '[]
+ (::$) :: forall n sh i. i -> ListS i sh -> ListS i (n : sh)
+deriving instance Show i => Show (ListS i n)
+deriving instance Eq i => Eq (ListS i n)
+infixr 3 ::$
+
-- | An index into a shape-typed array.
--
-- For convenience, this contains regular 'Int's instead of bounded integers
@@ -1050,15 +1058,53 @@ deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
type IxS :: Type -> [Nat] -> Type
-data IxS i sh where
- ZIS :: IxS i '[]
- (:.$) :: forall n sh i. i -> IxS i sh -> IxS i (n : sh)
-deriving instance Show i => Show (IxS i n)
-deriving instance Eq i => Eq (IxS i n)
+newtype IxS i sh = IxS (ListS i sh)
+ deriving (Show, Eq)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS i sh
+pattern ZIS = IxS ZS
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxS i sh -> IxS i sh1
+pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i))
+ where i :.$ (IxS shl) = IxS (i ::$ shl)
+{-# COMPLETE ZIS, (:.$) #-}
infixr 3 :.$
+data UnconsIxSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS i sh) i
+unconsIxS :: IxS i sh1 -> Maybe (UnconsIxSRes i sh1)
+unconsIxS (IxS shl) = case shl of
+ i ::$ shl' -> Just (UnconsIxSRes (IxS shl') i)
+ ZS -> Nothing
+
type IIxS = IxS Int
+type StaticShapeS :: Type -> [Nat] -> Type
+newtype StaticShapeS i sh = StaticShapeS (ListS i sh)
+ deriving (Show, Eq)
+
+pattern ZSS :: forall sh i. () => sh ~ '[] => StaticShapeS i sh
+pattern ZSS = StaticShapeS ZS
+
+pattern (:$$)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> StaticShapeS i sh -> StaticShapeS i sh1
+pattern i :$$ shl <- (unconsStaticShapeS -> Just (UnconsStaticShapeSRes shl i))
+ where i :$$ (StaticShapeS shl) = StaticShapeS (i ::$ shl)
+{-# COMPLETE ZSS, (:$$) #-}
+infixr 3 :$$
+
+data UnconsStaticShapeSRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsStaticShapeSRes (StaticShapeS i sh) i
+unconsStaticShapeS :: StaticShapeS i sh1 -> Maybe (UnconsStaticShapeSRes i sh1)
+unconsStaticShapeS (StaticShapeS shl) = case shl of
+ i ::$ shl' -> Just (UnconsStaticShapeSRes (StaticShapeS shl') i)
+ ZS -> Nothing
+
zeroIxS :: SShape sh -> IIxS sh
zeroIxS ShNil = ZIS
zeroIxS (ShCons _ sh) = 0 :.$ zeroIxS sh