aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs2
-rw-r--r--src/Data/Array/Nested/Convert.hs149
-rw-r--r--src/Data/Array/Nested/Trace.hs4
3 files changed, 77 insertions, 78 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index bb22d29..c3635e9 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -77,7 +77,7 @@ module Data.Array.Nested (
mtoXArrayPrim, mfromXArrayPrim,
mcast,
mcastToShaped, mtoRanked,
- castCastable, Castable(..),
+ convert, Conversion(..),
-- ** Additional arithmetic operations
--
-- $integralRealFloat
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index d07bab9..723e965 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -14,12 +14,12 @@ module Data.Array.Nested.Convert (
ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
-- * Array conversions
- castCastable,
- Castable(..),
+ convert,
+ Conversion(..),
-- * Special cases of array conversions
--
- -- | These functions can all be implemented using 'castCastable' in some way,
+ -- | These functions can all be implemented using 'convert' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
@@ -102,108 +102,107 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh
-- * Array conversions
-- | The constructors that perform runtime shape checking are marked with a
--- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
+-- @'@: 'ConvXS'' and 'ConvXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
--
--- The guiding principle behind 'Castable' is that it should represent the
+-- The guiding principle behind 'Conversion' is that it should represent the
-- array restructurings, or perhaps re-presentations, that do not change the
-- underlying 'XArray's. This leads to the inclusion of some operations that do
--- not look like a cast at first glance, like 'CastZip'; with the underlying
--- representation in mind, however, they are very much like a cast.
-data Castable a b where
- CastId :: Castable a a
- CastCmp :: Castable b c -> Castable a b -> Castable a c
-
- CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a)
- CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a)
-
- CastXR :: Elt a
- => Castable (Mixed sh a) (Ranked (Rank sh) a)
- CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a)
- CastXS' :: (Rank sh ~ Rank sh', Elt a)
+-- not look like simple conversions (casts) at first glance, like 'ConvZip'.
+data Conversion a b where
+ ConvId :: Conversion a a
+ ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c
+
+ ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a)
+ ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a)
+
+ ConvXR :: Elt a
+ => Conversion (Mixed sh a) (Ranked (Rank sh) a)
+ ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a)
+ ConvXS' :: (Rank sh ~ Rank sh', Elt a)
=> ShS sh'
- -> Castable (Mixed sh a) (Shaped sh' a)
+ -> Conversion (Mixed sh a) (Shaped sh' a)
- CastXX' :: (Rank sh ~ Rank sh', Elt a)
+ ConvXX' :: (Rank sh ~ Rank sh', Elt a)
=> StaticShX sh'
- -> Castable (Mixed sh a) (Mixed sh' a)
-
- CastRR :: Castable a b
- -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b
- -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b
- -> Castable (Mixed sh a) (Mixed sh b)
- CastT2 :: Castable a a'
- -> Castable b b'
- -> Castable (a, b) (a', b')
-
- Cast0X :: Elt a
- => Castable a (Mixed '[] a)
- CastX0 :: Castable (Mixed '[] a) a
-
- CastNest :: Elt a => StaticShX sh
- -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
- CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
-
- CastZip :: (Elt a, Elt b)
- => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
- CastUnzip :: (Elt a, Elt b)
- => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
-deriving instance Show (Castable a b)
-
-instance Category Castable where
- id = CastId
- (.) = CastCmp
-
-castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
-castCastable = \c x -> munScalar (go c (mscalar x))
+ -> Conversion (Mixed sh a) (Mixed sh' a)
+
+ ConvRR :: Conversion a b
+ -> Conversion (Ranked n a) (Ranked n b)
+ ConvSS :: Conversion a b
+ -> Conversion (Shaped sh a) (Shaped sh b)
+ ConvXX :: Conversion a b
+ -> Conversion (Mixed sh a) (Mixed sh b)
+ ConvT2 :: Conversion a a'
+ -> Conversion b b'
+ -> Conversion (a, b) (a', b')
+
+ Conv0X :: Elt a
+ => Conversion a (Mixed '[] a)
+ ConvX0 :: Conversion (Mixed '[] a) a
+
+ ConvNest :: Elt a => StaticShX sh
+ -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
+ ConvZip :: (Elt a, Elt b)
+ => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ ConvUnzip :: (Elt a, Elt b)
+ => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+deriving instance Show (Conversion a b)
+
+instance Category Conversion where
+ id = ConvId
+ (.) = ConvCmp
+
+convert :: (Elt a, Elt b) => Conversion a b -> a -> b
+convert = \c x -> munScalar (go c (mscalar x))
where
- -- The 'esh' is the extension shape: the casting happens under a whole
+ -- The 'esh' is the extension shape: the conversion happens under a whole
-- bunch of additional dimensions that it does not touch. These dimensions
-- are 'esh'.
-- The strategy is to unwind step-by-step to a large Mixed array, and to
- -- perform the required checks and castings when re-nesting back up.
- go :: Castable a b -> Mixed esh a -> Mixed esh b
- go CastId x = x
- go (CastCmp c1 c2) x = go c1 (go c2 x)
- go CastRX (M_Ranked x) = x
- go CastSX (M_Shaped x) = x
- go (CastXR @_ @sh) (M_Nest @esh esh x)
+ -- perform the required checks and conversions when re-nesting back up.
+ go :: Conversion a b -> Mixed esh a -> Mixed esh b
+ go ConvId x = x
+ go (ConvCmp c1 c2) x = go c1 (go c2 x)
+ go ConvRX (M_Ranked x) = x
+ go ConvSX (M_Shaped x) = x
+ go (ConvXR @_ @sh) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
= let ssx' = ssxAppend (ssxFromShX esh)
(ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))
in M_Ranked (M_Nest esh (mcast ssx' x))
- go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
- go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)
+ go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
+ go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
x))
- go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x)
+ go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
- go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
- go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
- go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
- go Cast0X (x :: Mixed esh a)
+ go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
+ go Conv0X (x :: Mixed esh a)
| Refl <- lemAppNil @esh
= M_Nest (mshape x) x
- go CastX0 (M_Nest @esh _ x)
+ go ConvX0 (M_Nest @esh _ x)
| Refl <- lemAppNil @esh
= x
- go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
- go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh x
- go CastZip x =
+ go ConvZip x =
-- no need to check that the two esh's are equal because they were zipped previously
let (M_Nest esh x1, M_Nest _ x2) = munzip x
in M_Nest esh (mzip x1 x2)
- go CastUnzip (M_Nest esh x) =
+ go ConvUnzip (M_Nest esh x) =
let (x1, x2) = munzip x
in mzip (M_Nest esh x1) (M_Nest esh x2)
@@ -232,7 +231,7 @@ mcast ssh2 arr
= mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked = castCastable CastXR
+mtoRanked = convert ConvXR
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr
@@ -246,7 +245,7 @@ rcastToMixed sshx rarr@(Ranked arr)
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> ShS sh' -> Mixed sh a -> Shaped sh' a
-mcastToShaped targetsh = castCastable (CastXS' targetsh)
+mcastToShaped targetsh = convert (ConvXS' targetsh)
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 3581f10..8a29aa5 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -42,7 +42,7 @@ module Data.Array.Nested.Trace (
ShX(..), KnownShX(..), IShX,
StaticShX(..),
SMayNat(..),
- Castable(..),
+ Conversion(..),
Elt,
PrimElt,
@@ -69,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'castCastable, 'mquotArray, 'mremArray, 'matan2Array])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])