aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-11 19:56:28 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-11 19:57:12 +0100
commita3299c09e0fd12cf73c4a0a9a2ae37b8f69f9b10 (patch)
tree8f28f2cb8034530f20fc56265c64af1164b35776 /src/Data/Array/Nested/Internal
parent9570a94d331facc8961be204d7a3010d33146f97 (diff)
Simpler API to mcast
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs35
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs8
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs24
3 files changed, 35 insertions, 32 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 619d9bc..8813ab3 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -321,8 +321,8 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
-> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
@@ -417,10 +417,11 @@ instance Storable a => Elt (Primitive a) where
= fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
- mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
+ let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ sh2 = shxCast' sh1 ssh2
in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
mtranspose perm (M_Primitive sh arr) =
@@ -493,8 +494,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2
- mcast ssh1 sh2 psh' (M_Tup2 a b) =
- M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+ mcastPartial ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b)
mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
mconcat =
@@ -600,13 +601,14 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
- mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
- mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+ = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ sh2 = shxCast' sh1 ssh2
+ in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
@@ -908,3 +910,10 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
+
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 1966270..ed89d82 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -120,7 +120,7 @@ instance Elt a => Elt (Ranked n a) where
@(NonEmpty (Mixed sh2 (Ranked n a))) $
mliftL ssh2 f (coerce l)
- mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
+ mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)
mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
@@ -523,10 +523,8 @@ rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked arr
- | Refl <- lemAppNil @sh
- , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
- , Refl <- lemRankReplicate (shxRank (mshape arr))
- = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ | Refl <- lemRankReplicate (shxRank (mshape arr))
+ = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
where
convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
convSh ZSX = ZSX
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index f8df5aa..0a230c4 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -118,7 +118,7 @@ instance Elt a => Elt (Shaped sh a) where
@(NonEmpty (Mixed sh2 (Shaped sh a))) $
mliftL ssh2 f (coerce l)
- mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
+ mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)
mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
@@ -324,13 +324,13 @@ stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
+sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
+sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim
+sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
@@ -476,10 +476,8 @@ stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh
- | Refl <- lemAppNil @sh
- , Refl <- lemAppNil @(MapJust sh')
- , Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
+ | Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr
@@ -487,9 +485,7 @@ stoMixed (Shaped arr) = arr
-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
-- compatibility check.
scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => IShX sh' -> Shaped sh a -> Mixed sh' a
-scastToMixed shx sarr@(Shaped arr)
- | Refl <- lemAppNil @sh'
- , Refl <- lemAppNil @(MapJust sh)
- , Refl <- lemRankMapJust (sshape sarr)
- = mcast (ssxFromShape (mshape arr)) shx (Proxy @'[]) arr
+ => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed sshx sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mcast sshx arr