diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 121 |
1 files changed, 121 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index cea2489..184e6ae 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,10 +1,12 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Data.Array.Nested.Convert ( -- * Shape/index/list casting functions @@ -15,6 +17,7 @@ module Data.Array.Nested.Convert ( -- * Array conversions castCastable, Castable(..), + transferEltCastable, -- * Special cases of array conversions -- @@ -155,6 +158,124 @@ instance Category Castable where id = CastId (.) = CastCmp +data Rec a b where + RecId :: Rec a a + RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b) + RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b) + RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b) + +data RecEq t f b where + RecEq :: RecEq (f a) f b + +recEq :: Rec t (f b) -> RecEq t f b +recEq RecId = RecEq +recEq RecRR{} = RecEq +recEq RecSS{} = RecEq +recEq RecXX{} = RecEq + +recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b) +recRX RecId = RecId +recRX (RecRR c) = RecXX c + +recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recSX RecId = RecId +recSX (RecSS c) = RecXX c + +recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b) +recXR RecId = RecId +recXR (RecXX c) = RecRR c + +recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b) +recXS RecId = RecId +recXS (RecXX c) = RecSS c + +recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recXX RecId = RecId +recXX (RecXX c) = RecXX c + +recCmp :: Rec b c -> Rec a b -> Rec a c +recCmp RecId r = r +recCmp r RecId = r +recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c') +recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c') +recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c') + +type family IsArray t where + IsArray (Ranked n a) = True + IsArray (Shaped sh a) = True + IsArray (Mixed sh a) = True + IsArray _ = False + +data RSplitCastable a b where + RsplitCastableId + :: RSplitCastable a a + RSplitCastable + :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True + ,Elt c, Elt d, Elt e) + => Rec d e -- possibly a recursive call + -> Castable c d -- middle stuff + -> Castable b c -- right endpoint (no Cmp) + -> RSplitCastable b e + RSplitCastableT2 + :: Castable a a' + -> Castable b b' + -> RSplitCastable (a, b) (a', b') + +rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b +rsplitCastable = \case + CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3)) + + CastCmp c1 c2 -> case rsplitCastable c2 of + RSplitCastable rec mid right -> case c1 of + CastId -> RSplitCastable rec mid right + CastRX | RecEq <- recEq rec -> RSplitCastable (recRX rec) (CastRX `CastCmp` mid) right + CastSX | RecEq <- recEq rec -> RSplitCastable (recSX rec) (CastSX `CastCmp` mid) right + CastXR | RecEq <- recEq rec -> RSplitCastable (recXR rec) (CastXR `CastCmp` mid) right + CastXS | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS `CastCmp` mid) right + CastXS' sh | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS' sh `CastCmp` mid) right + CastXX' ssh | RecEq <- recEq rec -> RSplitCastable (recXX rec) (CastXX' ssh `CastCmp` mid) right + CastRR c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecRR c') rec) mid right + CastSS c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecSS c') rec) mid right + CastXX c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecXX c') rec) mid right + + CastId -> RsplitCastableId + c@CastRX -> RSplitCastable RecId c CastId + c@CastSX -> RSplitCastable RecId c CastId + c@CastXR -> RSplitCastable RecId c CastId + c@CastXS -> RSplitCastable RecId c CastId + c@CastXS'{} -> RSplitCastable RecId c CastId + c@CastXX'{} -> RSplitCastable RecId c CastId + + CastRR c | Dict <- transferEltCastable c -> RSplitCastable (RecRR c) CastId CastId + CastSS c | Dict <- transferEltCastable c -> RSplitCastable (RecSS c) CastId CastId + CastXX c | Dict <- transferEltCastable c -> RSplitCastable (RecXX c) CastId CastId + + CastT2 c1 c2 -> RSplitCastableT2 c1 c2 + + Cast0X -> undefined + CastX0 -> undefined + +transferEltCastable :: Elt a => Castable a b -> Dict Elt b +transferEltCastable = \case + CastId -> Dict + CastCmp c1 c2 | Dict <- transferEltCastable c2, Dict <- transferEltCastable c1 -> Dict + CastRX -> Dict + CastSX -> Dict + CastXR -> Dict + CastXS -> Dict + CastXS' _ -> Dict + CastXX' _ -> Dict + CastRR c | Dict <- transferEltCastable c -> Dict + CastSS c | Dict <- transferEltCastable c -> Dict + CastXX c | Dict <- transferEltCastable c -> Dict + CastT2 c1 c2 | Dict <- transferEltCastable c1, Dict <- transferEltCastable c2 -> Dict + Cast0X -> Dict + CastX0 -> Dict + CastNest _ -> Dict + CastUnnest -> Dict + CastZip -> Dict + CastUnzip -> Dict + castCastable :: (Elt a, Elt b) => Castable a b -> a -> b castCastable = \c x -> munScalar (go c (mscalar x)) where |