diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 133 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Trace.hs | 4 | 
2 files changed, 68 insertions, 69 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d07bab9..723e965 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -14,12 +14,12 @@ module Data.Array.Nested.Convert (    ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,    -- * Array conversions -  castCastable, -  Castable(..), +  convert, +  Conversion(..),    -- * Special cases of array conversions    -- -  -- | These functions can all be implemented using 'castCastable' in some way, +  -- | These functions can all be implemented using 'convert' in some way,    -- but some have fewer constraints.    rtoMixed, rcastToMixed, rcastToShaped,    stoMixed, scastToMixed, stoRanked, @@ -102,108 +102,107 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh  -- * Array conversions  -- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure +-- @'@: 'ConvXS'' and 'ConvXX''. For the other constructors, the types ensure  -- that the shapes are already compatible. To convert between 'Ranked' and  -- 'Shaped', go via 'Mixed'.  -- --- The guiding principle behind 'Castable' is that it should represent the +-- The guiding principle behind 'Conversion' is that it should represent the  -- array restructurings, or perhaps re-presentations, that do not change the  -- underlying 'XArray's. This leads to the inclusion of some operations that do --- not look like a cast at first glance, like 'CastZip'; with the underlying --- representation in mind, however, they are very much like a cast. -data Castable a b where -  CastId  :: Castable a a -  CastCmp :: Castable b c -> Castable a b -> Castable a c +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +data Conversion a b where +  ConvId  :: Conversion a a +  ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c -  CastRX  :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a) -  CastSX  :: Castable (Shaped sh a) (Mixed (MapJust sh) a) +  ConvRX  :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) +  ConvSX  :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) -  CastXR  :: Elt a -          => Castable (Mixed sh a) (Ranked (Rank sh) a) -  CastXS  :: Castable (Mixed (MapJust sh) a) (Shaped sh a) -  CastXS' :: (Rank sh ~ Rank sh', Elt a) +  ConvXR  :: Elt a +          => Conversion (Mixed sh a) (Ranked (Rank sh) a) +  ConvXS  :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) +  ConvXS' :: (Rank sh ~ Rank sh', Elt a)            => ShS sh' -          -> Castable (Mixed sh a) (Shaped sh' a) +          -> Conversion (Mixed sh a) (Shaped sh' a) -  CastXX' :: (Rank sh ~ Rank sh', Elt a) +  ConvXX' :: (Rank sh ~ Rank sh', Elt a)            => StaticShX sh' -          -> Castable (Mixed sh a) (Mixed sh' a) +          -> Conversion (Mixed sh a) (Mixed sh' a) -  CastRR  :: Castable a b -          -> Castable (Ranked n a) (Ranked n b) -  CastSS  :: Castable a b -          -> Castable (Shaped sh a) (Shaped sh b) -  CastXX  :: Castable a b -          -> Castable (Mixed sh a) (Mixed sh b) -  CastT2  :: Castable a a' -          -> Castable b b' -          -> Castable (a, b) (a', b') +  ConvRR  :: Conversion a b +          -> Conversion (Ranked n a) (Ranked n b) +  ConvSS  :: Conversion a b +          -> Conversion (Shaped sh a) (Shaped sh b) +  ConvXX  :: Conversion a b +          -> Conversion (Mixed sh a) (Mixed sh b) +  ConvT2  :: Conversion a a' +          -> Conversion b b' +          -> Conversion (a, b) (a', b') -  Cast0X  :: Elt a -          => Castable a (Mixed '[] a) -  CastX0  :: Castable (Mixed '[] a) a +  Conv0X  :: Elt a +          => Conversion a (Mixed '[] a) +  ConvX0  :: Conversion (Mixed '[] a) a -  CastNest   :: Elt a => StaticShX sh -             -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) -  CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) +  ConvNest   :: Elt a => StaticShX sh +             -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) +  ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) -  CastZip   :: (Elt a, Elt b) -            => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) -  CastUnzip :: (Elt a, Elt b) -            => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) -deriving instance Show (Castable a b) +  ConvZip   :: (Elt a, Elt b) +            => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) +  ConvUnzip :: (Elt a, Elt b) +            => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) -instance Category Castable where -  id = CastId -  (.) = CastCmp +instance Category Conversion where +  id = ConvId +  (.) = ConvCmp -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x))    where -    -- The 'esh' is the extension shape: the casting happens under a whole +    -- The 'esh' is the extension shape: the conversion happens under a whole      -- bunch of additional dimensions that it does not touch. These dimensions      -- are 'esh'.      -- The strategy is to unwind step-by-step to a large Mixed array, and to -    -- perform the required checks and castings when re-nesting back up. -    go :: Castable a b -> Mixed esh a -> Mixed esh b -    go CastId x = x -    go (CastCmp c1 c2) x = go c1 (go c2 x) -    go CastRX (M_Ranked x) = x -    go CastSX (M_Shaped x) = x -    go (CastXR @_ @sh) (M_Nest @esh esh x) +    -- perform the required checks and conversions when re-nesting back up. +    go :: Conversion a b -> Mixed esh a -> Mixed esh b +    go ConvId x = x +    go (ConvCmp c1 c2) x = go c1 (go c2 x) +    go ConvRX (M_Ranked x) = x +    go ConvSX (M_Shaped x) = x +    go (ConvXR @_ @sh) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)        = let ssx' = ssxAppend (ssxFromShX esh)                               (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))          in M_Ranked (M_Nest esh (mcast ssx' x)) -    go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x) -    go (CastXS' @sh @sh' sh') (M_Nest @esh esh x) +    go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) +    go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))                                      x)) -    go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x) +    go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x -    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) -    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) -    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) -    go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) -    go Cast0X (x :: Mixed esh a) +    go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) +    go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) +    go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) +    go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) +    go Conv0X (x :: Mixed esh a)        | Refl <- lemAppNil @esh        = M_Nest (mshape x) x -    go CastX0 (M_Nest @esh _ x) +    go ConvX0 (M_Nest @esh _ x)        | Refl <- lemAppNil @esh        = x -    go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x) +    go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x)        | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) -    go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) +    go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))        | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Nest esh x -    go CastZip x = +    go ConvZip x =        -- no need to check that the two esh's are equal because they were zipped previously        let (M_Nest esh x1, M_Nest _ x2) = munzip x        in M_Nest esh (mzip x1 x2) -    go CastUnzip (M_Nest esh x) = +    go ConvUnzip (M_Nest esh x) =        let (x1, x2) = munzip x        in mzip (M_Nest esh x1) (M_Nest esh x2) @@ -232,7 +231,7 @@ mcast ssh2 arr    = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr  mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable CastXR +mtoRanked = convert ConvXR  rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a  rtoMixed (Ranked arr) = arr @@ -246,7 +245,7 @@ rcastToMixed sshx rarr@(Ranked arr)  mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')                => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh) +mcastToShaped targetsh = convert (ConvXS' targetsh)  stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a  stoMixed (Shaped arr) = arr diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 3581f10..8a29aa5 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -42,7 +42,7 @@ module Data.Array.Nested.Trace (    ShX(..), KnownShX(..), IShX,    StaticShX(..),    SMayNat(..), -  Castable(..), +  Conversion(..),    Elt,    PrimElt, @@ -69,4 +69,4 @@ import Data.Array.Nested.Trace.TH  $(concat <$> mapM convertFun -    ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'castCastable, 'mquotArray, 'mremArray, 'matan2Array]) +    ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) | 
