diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 35 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 24 | 
3 files changed, 35 insertions, 32 deletions
| diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 619d9bc..8813ab3 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -321,8 +321,8 @@ class Elt a where           -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))           -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) -  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a +  mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 +               => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a    mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)               => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a @@ -417,10 +417,11 @@ instance Storable a => Elt (Primitive a) where      = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $          f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) -  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) -  mcast ssh1 sh2 _ (M_Primitive sh1' arr) = -    let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' +  mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 +               => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) +  mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = +    let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' +        sh2 = shxCast' sh1 ssh2      in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)    mtranspose perm (M_Primitive sh arr) = @@ -493,8 +494,8 @@ instance (Elt a, Elt b) => Elt (a, b) where          unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)      in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 -  mcast ssh1 sh2 psh' (M_Tup2 a b) = -    M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) +  mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = +    M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b)    mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)    mconcat = @@ -600,13 +601,14 @@ instance Elt a => Elt (Mixed sh' a) where          , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)          = f (ssxAppend ssh' sshT) -  mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) -  mcast ssh1 sh2 _ (M_Nest sh1T arr) +  mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 +               => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) +  mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)      | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')      , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') -    = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T -      in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) +    = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T +          sh2 = shxCast' sh1 ssh2 +      in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)    mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)               => Perm is -> Mixed sh (Mixed sh' a) @@ -908,3 +910,10 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)             -> Mixed sh a -> Mixed sh b -> Mixed sh c  mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =    fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) +      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr +  | Refl <- lemAppNil @sh1 +  , Refl <- lemAppNil @sh2 +  = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 1966270..ed89d82 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -120,7 +120,7 @@ instance Elt a => Elt (Ranked n a) where             @(NonEmpty (Mixed sh2 (Ranked n a))) $        mliftL ssh2 f (coerce l) -  mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) +  mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr)    mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) @@ -523,10 +523,8 @@ rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)  mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a  mtoRanked arr -  | Refl <- lemAppNil @sh -  , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) -  , Refl <- lemRankReplicate (shxRank (mshape arr)) -  = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) +  | Refl <- lemRankReplicate (shxRank (mshape arr)) +  = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)    where      convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)      convSh ZSX = ZSX diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index f8df5aa..0a230c4 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -118,7 +118,7 @@ instance Elt a => Elt (Shaped sh a) where             @(NonEmpty (Mixed sh2 (Shaped sh a))) $        mliftL ssh2 f (coerce l) -  mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) +  mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr)    mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) @@ -324,13 +324,13 @@ stoVector :: PrimElt a => Shaped sh a -> VS.Vector a  stoVector = coerce mtoVector  sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))  sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1  sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim +sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim  stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]  stoListOuter (Shaped arr) = coerce (mtoListOuter arr) @@ -476,10 +476,8 @@ stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)  mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')                => Mixed sh a -> ShS sh' -> Shaped sh' a  mcastToShaped arr targetsh -  | Refl <- lemAppNil @sh -  , Refl <- lemAppNil @(MapJust sh') -  , Refl <- lemRankMapJust targetsh -  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) +  | Refl <- lemRankMapJust targetsh +  = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)  stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a  stoMixed (Shaped arr) = arr @@ -487,9 +485,7 @@ stoMixed (Shaped arr) = arr  -- | A more weakly-typed version of 'stoMixed' that does a runtime shape  -- compatibility check.  scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -             => IShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed shx sarr@(Shaped arr) -  | Refl <- lemAppNil @sh' -  , Refl <- lemAppNil @(MapJust sh) -  , Refl <- lemRankMapJust (sshape sarr) -  = mcast (ssxFromShape (mshape arr)) shx (Proxy @'[]) arr +             => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  = mcast sshx arr | 
