aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-09 11:05:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-09 11:05:06 +0200
commit16ec2aaef12130b2ec6cbd6c83d8688b52fc1577 (patch)
tree55320c317ffc9ab700f979897e2721375cf43a04 /src/Data/Array/Nested/Convert.hs
parent5e77fbb23845292920cdfdc6cca8fe4fbdd0709a (diff)
Rename Castable to ConversionHEADmaster
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs149
1 files changed, 74 insertions, 75 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index d07bab9..723e965 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -14,12 +14,12 @@ module Data.Array.Nested.Convert (
ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
-- * Array conversions
- castCastable,
- Castable(..),
+ convert,
+ Conversion(..),
-- * Special cases of array conversions
--
- -- | These functions can all be implemented using 'castCastable' in some way,
+ -- | These functions can all be implemented using 'convert' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
@@ -102,108 +102,107 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-- * Array conversions
-- | The constructors that perform runtime shape checking are marked with a
--- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
+-- @'@: 'ConvXS'' and 'ConvXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
--
--- The guiding principle behind 'Castable' is that it should represent the
+-- The guiding principle behind 'Conversion' is that it should represent the
-- array restructurings, or perhaps re-presentations, that do not change the
-- underlying 'XArray's. This leads to the inclusion of some operations that do
--- not look like a cast at first glance, like 'CastZip'; with the underlying
--- representation in mind, however, they are very much like a cast.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a)
- CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a)
-
- CastXR :: Elt a
- => Castable (Mixed sh a) (Ranked (Rank sh) a)
- CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a)
- CastXS' :: (Rank sh ~ Rank sh', Elt a)
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
=> ShS sh'
- -> Castable (Mixed sh a) (Shaped sh' a)
+ -> Conversion (Mixed sh a) (Shaped sh' a)
- CastXX' :: (Rank sh ~ Rank sh', Elt a)
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
=> StaticShX sh'
- -> Castable (Mixed sh a) (Mixed sh' a)
-
- CastRR :: Castable a b
- -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b
- -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b
- -> Castable (Mixed sh a) (Mixed sh b)
- CastT2 :: Castable a a'
- -> Castable b b'
- -> Castable (a, b) (a', b')
-
- Cast0X :: Elt a
- => Castable a (Mixed '[] a)
- CastX0 :: Castable (Mixed '[] a) a
-
- CastNest :: Elt a => StaticShX sh
- -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
- CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
-
- CastZip :: (Elt a, Elt b)
- => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
- CastUnzip :: (Elt a, Elt b)
- => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
-deriving instance Show (Castable a b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
where
- -- The 'esh' is the extension shape: the casting happens under a whole
+ -- The 'esh' is the extension shape: the conversion happens under a whole
-- bunch of additional dimensions that it does not touch. These dimensions
-- are 'esh'.
-- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go CastRX (M_Ranked x) = x
- go CastSX (M_Shaped x) = x
- go (CastXR @_ @sh) (M_Nest @esh esh x)
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
= let ssx' = ssxAppend (ssxFromShX esh)
(ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))
in M_Ranked (M_Nest esh (mcast ssx' x))
- go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
- go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
x))
- go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
- go Cast0X (x :: Mixed esh a)
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
| Refl <- lemAppNil @esh
= M_Nest (mshape x) x
- go CastX0 (M_Nest @esh _ x)
+ go ConvX0 (M_Nest @esh _ x)
| Refl <- lemAppNil @esh
= x
- go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
- go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh x
- go CastZip x =
+ go ConvZip x =
-- no need to check that the two esh's are equal because they were zipped previously
let (M_Nest esh x1, M_Nest _ x2) = munzip x
in M_Nest esh (mzip x1 x2)
- go CastUnzip (M_Nest esh x) =
+ go ConvUnzip (M_Nest esh x) =
let (x1, x2) = munzip x
in mzip (M_Nest esh x1) (M_Nest esh x2)
@@ -232,7 +231,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable CastXR
+mtoRanked = convert ConvXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -246,7 +245,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh)
+mcastToShaped targetsh = convert (ConvXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr