aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs106
1 files changed, 76 insertions, 30 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index b3a2c63..cea2489 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -104,25 +104,52 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
+--
+-- The guiding principle behind 'Castable' is that it should represent the
+-- array restructurings, or perhaps re-presentations, that do not change the
+-- underlying 'XArray's. This leads to the inclusion of some operations that do
+-- not look like a cast at first glance, like 'CastZip'; with the underlying
+-- representation in mind, however, they are very much like a cast.
data Castable a b where
CastId :: Castable a a
CastCmp :: Castable b c -> Castable a b -> Castable a c
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-
- CastXR :: Elt b
- => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-
- CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
- -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
+ CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a)
+ CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a)
+
+ CastXR :: Elt a
+ => Castable (Mixed sh a) (Ranked (Rank sh) a)
+ CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a)
+ CastXS' :: (Rank sh ~ Rank sh', Elt a)
+ => ShS sh'
+ -> Castable (Mixed sh a) (Shaped sh' a)
+
+ CastXX' :: (Rank sh ~ Rank sh', Elt a)
+ => StaticShX sh'
+ -> Castable (Mixed sh a) (Mixed sh' a)
+
+ CastRR :: Castable a b
+ -> Castable (Ranked n a) (Ranked n b)
+ CastSS :: Castable a b
+ -> Castable (Shaped sh a) (Shaped sh b)
+ CastXX :: Castable a b
+ -> Castable (Mixed sh a) (Mixed sh b)
+ CastT2 :: Castable a a'
+ -> Castable b b'
+ -> Castable (a, b) (a', b')
+
+ Cast0X :: Elt a
+ => Castable a (Mixed '[] a)
+ CastX0 :: Castable (Mixed '[] a) a
+
+ CastNest :: Elt a => StaticShX sh
+ -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ CastZip :: (Elt a, Elt b)
+ => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ CastUnzip :: (Elt a, Elt b)
+ => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
instance Category Castable where
id = CastId
@@ -139,25 +166,44 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go :: Castable a b -> Mixed esh a -> Mixed esh b
go CastId x = x
go (CastCmp c1 c2) x = go c1 (go c2 x)
- go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
+ go CastRX (M_Ranked x) = x
+ go CastSX (M_Shaped x) = x
+ go (CastXR @_ @sh) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
- = let x' = go c x
- ssx' = ssxAppend (ssxFromShX esh)
- (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
- in M_Ranked (M_Nest esh (mcast ssx' x'))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
+ = let ssx' = ssxAppend (ssxFromShX esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))
+ in M_Ranked (M_Nest esh (mcast ssx' x))
+ go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
- (go c x)))
+ x))
+ go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
- | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
+ go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Cast0X (x :: Mixed esh a)
+ | Refl <- lemAppNil @esh
+ = M_Nest (mshape x) x
+ go CastX0 (M_Nest @esh _ x)
+ | Refl <- lemAppNil @esh
+ = x
+ go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
+ go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
+ go CastZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go CastUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'
@@ -184,7 +230,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable (CastXR CastId)
+mtoRanked = castCastable CastXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -198,7 +244,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh CastId)
+mcastToShaped targetsh = castCastable (CastXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr