diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-11 19:56:28 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-11 19:57:12 +0100 |
commit | a3299c09e0fd12cf73c4a0a9a2ae37b8f69f9b10 (patch) | |
tree | 8f28f2cb8034530f20fc56265c64af1164b35776 /src/Data/Array/Nested/Internal | |
parent | 9570a94d331facc8961be204d7a3010d33146f97 (diff) |
Simpler API to mcast
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 35 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 8 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 24 |
3 files changed, 35 insertions, 32 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 619d9bc..8813ab3 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -321,8 +321,8 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a @@ -417,10 +417,11 @@ instance Storable a => Elt (Primitive a) where = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = + let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + sh2 = shxCast' sh1 ssh2 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -493,8 +494,8 @@ instance (Elt a, Elt b) => Elt (a, b) where unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 - mcast ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) mconcat = @@ -600,13 +601,14 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) - mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcast ssh1 sh2 _ (M_Nest sh1T arr) + mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + sh2 = shxCast' sh1 ssh2 + in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) @@ -908,3 +910,10 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) -> Mixed sh a -> Mixed sh b -> Mixed sh c mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 1966270..ed89d82 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -120,7 +120,7 @@ instance Elt a => Elt (Ranked n a) where @(NonEmpty (Mixed sh2 (Ranked n a))) $ mliftL ssh2 f (coerce l) - mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) @@ -523,10 +523,8 @@ rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked arr - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) - , Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + | Refl <- lemRankReplicate (shxRank (mshape arr)) + = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) where convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) convSh ZSX = ZSX diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index f8df5aa..0a230c4 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -118,7 +118,7 @@ instance Elt a => Elt (Shaped sh a) where @(NonEmpty (Mixed sh2 (Shaped sh a))) $ mliftL ssh2 f (coerce l) - mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) @@ -324,13 +324,13 @@ stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim +sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] stoListOuter (Shaped arr) = coerce (mtoListOuter arr) @@ -476,10 +476,8 @@ stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a mcastToShaped arr targetsh - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(MapJust sh') - , Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + | Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr @@ -487,9 +485,7 @@ stoMixed (Shaped arr) = arr -- | A more weakly-typed version of 'stoMixed' that does a runtime shape -- compatibility check. scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => IShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed shx sarr@(Shaped arr) - | Refl <- lemAppNil @sh' - , Refl <- lemAppNil @(MapJust sh) - , Refl <- lemRankMapJust (sshape sarr) - = mcast (ssxFromShape (mshape arr)) shx (Proxy @'[]) arr + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr |