{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
  -- * Shape/index/list casting functions
  ixrFromIxS,
  ixrFromIxX,
  shrFromShS,
  shrFromShX,
  shrFromShX2,
  ixsFromIxX,
  shsFromShX,
  ixxFromIxR,
  ixxFromIxS,
  shxFromShR,
  shxFromShS,

  -- * Array conversions
  castCastable,
  Castable(..),

  -- * Special cases of array conversions
  --
  -- | These functions can all be implemented using 'castCastable' in some way,
  -- but some have fewer constraints.
  rtoMixed,
  rcastToMixed,
  rcastToShaped,
  stoMixed,
  scastToMixed,
  stoRanked,
  mcast,
  mcastToShaped,
  mtoRanked,
) where

import Control.Category
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits

import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types


-- * Shape/index/list casting functions

-- * To ranked

-- | Turn a shaped index into a ranked index of the same length, keeping every
-- component unchanged; only the rank survives at the type level.
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
ixrFromIxS = \case
  ZIS -> ZIR
  k :.$ rest -> k :.: ixrFromIxS rest

-- | Turn a mixed index into a ranked index of the same length, keeping every
-- component unchanged.
ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
ixrFromIxX = \case
  ZIX -> ZIR
  k :.% rest -> k :.: ixrFromIxX rest

-- | Turn a shaped shape into a ranked shape, converting each static size to a
-- runtime value with 'fromSNat''.
shrFromShS :: ShS sh -> IShR (Rank sh)
shrFromShS = \case
  ZSS -> ZSR
  sz :$$ rest -> fromSNat' sz :$: shrFromShS rest

-- shrFromShX re-exported
-- shrFromShX2 re-exported

-- * To shaped

-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
-- ixsFromIxR = \ix -> go ix _
--   where
--     go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r
--     go ZIR k = k ZIS
--     go (i :.: ix) k = go ix (i :.$)

-- | Turn a mixed index whose dimensions are all statically known into a
-- shaped index. The 'ShS' argument only drives the recursion; its sizes are
-- never looked at.
ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
ixsFromIxX ZSS ZIX = ZIS
ixsFromIxX (_ :$$ shRest) (k :.% ixRest) = k :.$ ixsFromIxX shRest ixRest

-- shsFromShX re-exported

-- * To mixed

-- | Turn a ranked index into a mixed index in which every dimension is
-- statically unknown.
ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
ixxFromIxR = \case
  ZIR -> ZIX
  k :.: (rest :: IxR m i) ->
    -- 'lemReplicateSucc' supplies the type equality needed to view the
    -- consed index at the declared result type.
    castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) $
      k :.% ixxFromIxR rest

-- | Turn a shaped index into a mixed index in which every dimension is
-- statically known.
ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
ixxFromIxS = \case
  ZIS -> ZIX
  k :.$ rest -> k :.% ixxFromIxS rest

-- | Turn a ranked shape into a mixed shape, marking every size as
-- 'SUnknown'.
shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
shxFromShR = \case
  ZSR -> ZSX
  sz :$: (rest :: ShR m i) ->
    -- Same type-level rewrite as in 'ixxFromIxR'.
    castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) $
      SUnknown sz :$% shxFromShR rest

-- | Turn a shaped shape into a mixed shape, marking every size as 'SKnown'.
shxFromShS :: ShS sh -> IShX (MapJust sh)
shxFromShS = \case
  ZSS -> ZSX
  sz :$$ rest -> SKnown sz :$% shxFromShS rest


-- * Array conversions

-- | The constructors that perform runtime shape checking are marked with a
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
data Castable a b where
  -- | The identity conversion.
  CastId :: Castable a a
  -- | Composition: @'CastCmp' c1 c2@ first applies @c2@, then @c1@.
  CastCmp :: Castable b c -> Castable a b -> Castable a c

  -- | View a ranked array as a mixed array whose dimensions are all unknown.
  CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a)
  -- | View a shaped array as a mixed array whose dimensions are all known.
  CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a)

  -- | Forget the static shape information of a mixed array, keeping its rank.
  CastXR :: Elt a => Castable (Mixed sh a) (Ranked (Rank sh) a)
  -- | View a mixed array whose dimensions are all known as a shaped array.
  CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a)
  -- | Convert a mixed array to a shaped array of the given shape, checking
  -- shape compatibility at runtime.
  CastXS' :: (Rank sh ~ Rank sh', Elt a) => ShS sh' -> Castable (Mixed sh a) (Shaped sh' a)

  -- | Convert a mixed array to another mixed shape type of the same rank,
  -- checking shape compatibility at runtime.
  CastXX' :: (Rank sh ~ Rank sh', Elt a) => StaticShX sh' -> Castable (Mixed sh a) (Mixed sh' a)

  -- | Convert the elements of a ranked array.
  CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Convert the elements of a shaped array.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Convert the elements of a mixed array.
  CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
  -- | Convert both components of a pair.
  CastT2 :: Castable a a' -> Castable b b' -> Castable (a, b) (a', b')

  -- | Wrap a value as a mixed array with an empty shape.
  Cast0X :: Elt a => Castable a (Mixed '[] a)
  -- | Unwrap a mixed array with an empty shape.
  CastX0 :: Castable (Mixed '[] a) a

instance Category Castable where
  id = CastId
  (.) = CastCmp


-- NOTE(review): 'Rec', 'RecEq', 'IsArray', 'RSplitCastable', 'rsplitCastable'
-- and 'transferElt' are neither exported nor used by 'castCastable'; they look
-- like unfinished work on normalising 'Castable' chains. Confirm whether they
-- should be completed or removed.

-- | An element-level conversion under an unchanged outer array type: either
-- the identity, or a 'Castable' applied to the elements of a 'Ranked',
-- 'Shaped' or 'Mixed' array whose rank/shape index is the same on both sides.
data Rec a b where
  RecId :: Rec a a
  RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b)
  RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b)
  RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b)

-- | Evidence that @t@ is an application @f a@ of the type constructor @f@;
-- the constructor places no constraint on @b@.
data RecEq t f b where
  RecEq :: RecEq (f a) f b

-- | Every 'Rec' keeps the outer type constructor, so if its target is @f b@,
-- its source is @f a@ for some @a@.
recEq :: Rec t (f b) -> RecEq t f b
recEq RecId = RecEq
recEq RecRR{} = RecEq
recEq RecSS{} = RecEq
recEq RecXX{} = RecEq

-- The following functions move a 'Rec' across a change of the outer array
-- type: 'RecId' stays 'RecId', and an element conversion is re-wrapped under
-- the new array type.

recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b)
recRX RecId = RecId
recRX (RecRR c) = RecXX c

recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
recSX RecId = RecId
recSX (RecSS c) = RecXX c

recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b)
recXR RecId = RecId
recXR (RecXX c) = RecRR c

recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b)
recXS RecId = RecId
recXS (RecXX c) = RecSS c

recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
recXX RecId = RecId
recXX (RecXX c) = RecXX c

-- | Compose two 'Rec's (left after right), dropping identities and composing
-- the element conversions with 'CastCmp'.
recCmp :: Rec b c -> Rec a b -> Rec a c
recCmp RecId r = r
recCmp r RecId = r
recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c')
recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c')
recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c')

-- | 'True' exactly for the array types 'Ranked', 'Shaped' and 'Mixed'.
type family IsArray t where
  IsArray (Ranked n a) = True
  IsArray (Shaped sh a) = True
  IsArray (Mixed sh a) = True
  IsArray _ = False

-- | A 'Castable' in split form. @'RSplitCastable' r m s@ stands for the
-- conversion that first applies @s@, then @m@, then the element-level @r@.
-- The identity and tuple casts get their own constructors.
data RSplitCastable a b where
  RSplitCastableId :: RSplitCastable a a
  RSplitCastable :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True
                    ,Elt c, Elt d, Elt e)
                 => Rec d e       -- possibly a recursive call
                 -> Castable c d  -- middle stuff
                 -> Castable b c  -- right endpoint (no Cmp)
                 -> RSplitCastable b e
  RSplitCastableT2 :: Castable a a' -> Castable b b' -> RSplitCastable (a, b) (a', b')

-- | Bring a 'Castable' into 'RSplitCastable' form.
--
-- NOTE(review): this function is unfinished.
--
-- * In the 'CastCmp' case, only an 'RSplitCastable' result of the recursive
--   call is handled. 'RSplitCastableId' and 'RSplitCastableT2' results, and a
--   left operand of 'Cast0X' or 'CastX0', hit a pattern-match failure.
-- * A top-level 'Cast0X' or 'CastX0' raises an explicit error, because their
--   non-array endpoint cannot satisfy the 'IsArray' constraints of
--   'RSplitCastable'.
rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b
rsplitCastable = \case
  -- Re-associate compositions to the right first, so that in the next
  -- alternative the left operand of 'CastCmp' is never itself a 'CastCmp'.
  CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3))
  CastCmp c1 c2 ->
    case rsplitCastable c2 of
      RSplitCastable inner mid right -> case c1 of
        CastId -> RSplitCastable inner mid right
        CastRX | RecEq <- recEq inner -> RSplitCastable (recRX inner) (CastRX `CastCmp` mid) right
        CastSX | RecEq <- recEq inner -> RSplitCastable (recSX inner) (CastSX `CastCmp` mid) right
        CastXR | RecEq <- recEq inner -> RSplitCastable (recXR inner) (CastXR `CastCmp` mid) right
        CastXS | RecEq <- recEq inner -> RSplitCastable (recXS inner) (CastXS `CastCmp` mid) right
        CastXS' sh | RecEq <- recEq inner -> RSplitCastable (recXS inner) (CastXS' sh `CastCmp` mid) right
        CastXX' ssh | RecEq <- recEq inner -> RSplitCastable (recXX inner) (CastXX' ssh `CastCmp` mid) right
        -- Element-level casts are absorbed into the outer 'Rec'.
        CastRR c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecRR c') inner) mid right
        CastSS c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecSS c') inner) mid right
        CastXX c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecXX c') inner) mid right
  CastId -> RSplitCastableId
  c@CastRX -> RSplitCastable RecId c CastId
  c@CastSX -> RSplitCastable RecId c CastId
  c@CastXR -> RSplitCastable RecId c CastId
  c@CastXS -> RSplitCastable RecId c CastId
  c@CastXS'{} -> RSplitCastable RecId c CastId
  c@CastXX'{} -> RSplitCastable RecId c CastId
  CastRR c | Dict <- transferElt c -> RSplitCastable (RecRR c) CastId CastId
  CastSS c | Dict <- transferElt c -> RSplitCastable (RecSS c) CastId CastId
  CastXX c | Dict <- transferElt c -> RSplitCastable (RecXX c) CastId CastId
  CastT2 c1 c2 -> RSplitCastableT2 c1 c2
  -- These previously were typed holes, which are compile errors. Make the
  -- missing support explicit instead.
  Cast0X -> error "rsplitCastable: Cast0X is not supported yet"
  CastX0 -> error "rsplitCastable: CastX0 is not supported yet"

-- | Given an 'Elt' instance for the source of a 'Castable', produce one for
-- its target.
--
-- NOTE(review): several alternatives (e.g. 'CastRR', 'CastT2', 'CastRX') need
-- an 'Elt' instance for an element or component type, while only the instance
-- for the enclosing array/tuple type is given here. Confirm this resolves with
-- the 'Elt' instances in scope.
transferElt :: Elt a => Castable a b -> Dict Elt b
transferElt = \case
  CastId -> Dict
  CastCmp c1 c2 | Dict <- transferElt c2, Dict <- transferElt c1 -> Dict
  CastRX -> Dict
  CastSX -> Dict
  CastXR -> Dict
  CastXS -> Dict
  CastXS' _ -> Dict
  CastXX' _ -> Dict
  CastRR c | Dict <- transferElt c -> Dict
  CastSS c | Dict <- transferElt c -> Dict
  CastXX c | Dict <- transferElt c -> Dict
  CastT2 c1 c2 | Dict <- transferElt c1, Dict <- transferElt c2 -> Dict
  Cast0X -> Dict
  CastX0 -> Dict

-- | Apply a 'Castable' conversion to a value.
--
-- The value is wrapped with 'mscalar', converted by the local @go@, and
-- unwrapped again with 'munScalar'.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    go CastRX (M_Ranked x) = x
    go CastSX (M_Shaped x) = x
    go (CastXR @_ @sh) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
      = let ssx' = ssxAppend (ssxFromShX esh)
                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh))))
        in M_Ranked (M_Nest esh (mcast ssx' x))
    go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x)
    go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) x))
    go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2)
    go Cast0X (x :: Mixed esh a)
      | Refl <- lemAppNil @esh
      = M_Nest (mshape x) x
    go CastX0 (M_Nest @esh _ x)
      | Refl <- lemAppNil @esh
      = x

-- | Appending suffixes of equal rank to the same prefix gives shapes of equal
-- rank. GHC cannot derive this on its own, so the proof is asserted with
-- 'unsafeCoerceRefl'. It holds because the rank of an append depends only on
-- the ranks of its parts.
lemRankAppRankEq :: Rank sh ~ Rank sh'
                 => Proxy esh -> Proxy sh -> Proxy sh'
                 -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
lemRankAppRankEq _ _ _ = unsafeCoerceRefl

-- | Special case of the above with @sh' = Replicate (Rank sh) Nothing@, whose
-- rank is that of @sh@. Asserted with 'unsafeCoerceRefl'.
lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
                      -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl

-- | Variant with @MapJust sh'@ as the second suffix. 'MapJust' preserves rank,
-- so this follows from @Rank sh ~ Rank sh'@. Asserted with 'unsafeCoerceRefl'.
lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
                        => Proxy esh -> Proxy sh -> Proxy sh'
                        -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl


-- * Special cases of array conversions

-- | Convert a mixed array to another shape type of the same rank. This is the
-- runtime shape compatibility check that 'rcastToMixed' and 'scastToMixed'
-- rely on. It delegates to 'mcastPartial' with an empty untouched suffix.
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast ssh2 arr
  | Refl <- lemAppNil @sh1
  , Refl <- lemAppNil @sh2
  = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr

-- | Forget the static shape information of a mixed array, keeping its rank
-- (via 'CastXR').
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked = castCastable CastXR

-- | Unwrap a ranked array into the underlying mixed array, whose dimensions
-- are all statically unknown.
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr

-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
-- compatibility check.
rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
rcastToMixed sshx rarr@(Ranked arr)
  | Refl <- lemRankReplicate (rrank rarr)
  = mcast sshx arr

-- | Convert a mixed array to a shaped array of the given shape, with a runtime
-- shape compatibility check (via 'CastXS'').
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => ShS sh' -> Mixed sh a -> Shaped sh' a
mcastToShaped targetsh = castCastable (CastXS' targetsh)

-- | Unwrap a shaped array into the underlying mixed array, whose dimensions
-- are all statically known.
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr

-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
-- compatibility check.
scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
             => StaticShX sh' -> Shaped sh a -> Mixed sh' a
scastToMixed sshx sarr@(Shaped arr)
  | Refl <- lemRankMapJust (sshape sarr)
  = mcast sshx arr

-- | Forget the static shape of a shaped array, keeping its rank.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
  | Refl <- lemRankMapJust (sshape sarr)
  = mtoRanked arr

-- | Convert a ranked array to a shaped array of the given shape, with a
-- runtime shape compatibility check (via 'mcastToShaped').
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
  | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh))
  , Refl <- lemRankMapJust targetsh
  = mcastToShaped targetsh arr