diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 333 |
1 files changed, 333 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..2438f68 --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,333 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +module Data.Array.Nested.Convert ( + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', + + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions + -- + -- | These functions can all be implemented using 'convert' in some way, + -- but some have fewer constraints. + rtoMixed, rcastToMixed, rcastToShaped, + stoMixed, scastToMixed, stoRanked, + mcast, mcastToShaped, mtoRanked, +) where + +import Control.Category +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = convert ConvXR + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = convert (ConvXS' targetsh) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped targetsh arr |