diff options
-rw-r--r-- | src/Data/Array/Nested.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 149 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace.hs | 4 |
3 files changed, 77 insertions, 78 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index bb22d29..c3635e9 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -77,7 +77,7 @@ module Data.Array.Nested ( mtoXArrayPrim, mfromXArrayPrim, mcast, mcastToShaped, mtoRanked, - castCastable, Castable(..), + convert, Conversion(..), -- ** Additional arithmetic operations -- -- $integralRealFloat diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d07bab9..723e965 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -14,12 +14,12 @@ module Data.Array.Nested.Convert ( ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, -- * Array conversions - castCastable, - Castable(..), + convert, + Conversion(..), -- * Special cases of array conversions -- - -- | These functions can all be implemented using 'castCastable' in some way, + -- | These functions can all be implemented using 'convert' in some way, -- but some have fewer constraints. rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, @@ -102,108 +102,107 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -- * Array conversions -- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure +-- @'@: 'ConvXS'' and 'ConvXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and -- 'Shaped', go via 'Mixed'. -- --- The guiding principle behind 'Castable' is that it should represent the +-- The guiding principle behind 'Conversion' is that it should represent the -- array restructurings, or perhaps re-presentations, that do not change the -- underlying 'XArray's. This leads to the inclusion of some operations that do --- not look like a cast at first glance, like 'CastZip'; with the underlying --- representation in mind, however, they are very much like a cast. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a) - CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a) - - CastXR :: Elt a - => Castable (Mixed sh a) (Ranked (Rank sh) a) - CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a) - CastXS' :: (Rank sh ~ Rank sh', Elt a) +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) => ShS sh' - -> Castable (Mixed sh a) (Shaped sh' a) + -> Conversion (Mixed sh a) (Shaped sh' a) - CastXX' :: (Rank sh ~ Rank sh', Elt a) + ConvXX' :: (Rank sh ~ Rank sh', Elt a) => StaticShX sh' - -> Castable (Mixed sh a) (Mixed sh' a) - - CastRR :: Castable a b - -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b - -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b - -> Castable (Mixed sh a) (Mixed sh b) - CastT2 :: Castable a a' - -> Castable b b' - -> Castable (a, b) (a', b') - - Cast0X :: Elt a - => Castable a (Mixed '[] a) - CastX0 :: Castable (Mixed '[] a) a - - CastNest :: Elt a => StaticShX sh - -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) - CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) - - CastZip :: (Elt a, Elt b) - => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) - CastUnzip :: (Elt a, Elt b) - => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) -deriving instance Show (Castable a b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) where - -- The 'esh' is the extension shape: the casting happens under a whole + -- The 'esh' is the extension shape: the conversion happens under a whole -- bunch of additional dimensions that it does not touch. These dimensions -- are 'esh'. -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go CastRX (M_Ranked x) = x - go CastSX (M_Shaped x) = x - go (CastXR @_ @sh) (M_Nest @esh esh x) + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) = let ssx' = ssxAppend (ssxFromShX esh) (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) in M_Ranked (M_Nest esh (mcast ssx' x)) - go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x) - go (CastXS' @sh @sh' sh') (M_Nest @esh esh x) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) x)) - go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) - go Cast0X (x :: Mixed esh a) + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) | Refl <- lemAppNil @esh = M_Nest (mshape x) x - go CastX0 (M_Nest @esh _ x) + go ConvX0 (M_Nest @esh _ x) | Refl <- lemAppNil @esh = x - go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) - go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh x - go CastZip x = + go ConvZip x = -- no need to check that the two esh's are equal because they were zipped previously let (M_Nest esh x1, M_Nest _ x2) = munzip x in M_Nest esh (mzip x1 x2) - go CastUnzip (M_Nest esh x) = + go ConvUnzip (M_Nest esh x) = let (x1, x2) = munzip x in mzip (M_Nest esh x1) (M_Nest esh x2) @@ -232,7 +231,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable CastXR +mtoRanked = convert ConvXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -246,7 +245,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh) +mcastToShaped targetsh = convert (ConvXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 3581f10..8a29aa5 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -42,7 +42,7 @@ module Data.Array.Nested.Trace ( ShX(..), KnownShX(..), IShX, StaticShX(..), SMayNat(..), - Castable(..), + Conversion(..), Elt, PrimElt, @@ -69,4 +69,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'castCastable, 'mquotArray, 'mremArray, 'matan2Array]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) |