diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-11 19:56:28 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-11 19:57:12 +0100 |
commit | a3299c09e0fd12cf73c4a0a9a2ae37b8f69f9b10 (patch) | |
tree | 8f28f2cb8034530f20fc56265c64af1164b35776 /src/Data/Array/Nested/Internal/Shaped.hs | |
parent | 9570a94d331facc8961be204d7a3010d33146f97 (diff) |
Simpler API to mcast
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shaped.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 24 |
1 files changed, 10 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index f8df5aa..0a230c4 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -118,7 +118,7 @@ instance Elt a => Elt (Shaped sh a) where @(NonEmpty (Mixed sh2 (Shaped sh a))) $ mliftL ssh2 f (coerce l) - mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) @@ -324,13 +324,13 @@ stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim +sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] stoListOuter (Shaped arr) = coerce (mtoListOuter arr) @@ -476,10 +476,8 @@ stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a mcastToShaped arr targetsh - | Refl <- lemAppNil @sh - , Refl <- lemAppNil @(MapJust sh') - , Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + | Refl <- lemRankMapJust targetsh + = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr @@ -487,9 +485,7 @@ stoMixed (Shaped arr) = arr -- | A more weakly-typed version of 'stoMixed' that does a runtime shape -- compatibility check. scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => IShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed shx sarr@(Shaped arr) - | Refl <- lemAppNil @sh' - , Refl <- lemAppNil @(MapJust sh) - , Refl <- lemRankMapJust (sshape sarr) - = mcast (ssxFromShape (mshape arr)) shx (Proxy @'[]) arr + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr |