diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-02 23:22:23 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-02 23:22:39 +0200 |
commit | 8344417a8ed9ad59337ce3b880e2fadf89bec964 (patch) | |
tree | f0e4fdc9d7a272f4f76dded3aae9871d4c3b7c66 /src/Data/Array/Nested/Convert.hs | |
parent | 75ee1572b75b45dcdc50e3af82ed50259ca77df0 (diff) |
WIP simplify Castable
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index fe590d1..4c9c9a1 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,10 +1,12 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Data.Array.Nested.Convert ( -- * Shape/index/list casting functions @@ -140,6 +142,120 @@ instance Category Castable where id = CastId (.) = CastCmp +data Rec a b where + RecId :: Rec a a + RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b) + RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b) + RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b) + +data RecEq t f b where + RecEq :: RecEq (f a) f b + +recEq :: Rec t (f b) -> RecEq t f b +recEq RecId = RecEq +recEq RecRR{} = RecEq +recEq RecSS{} = RecEq +recEq RecXX{} = RecEq + +recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b) +recRX RecId = RecId +recRX (RecRR c) = RecXX c + +recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recSX RecId = RecId +recSX (RecSS c) = RecXX c + +recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b) +recXR RecId = RecId +recXR (RecXX c) = RecRR c + +recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b) +recXS RecId = RecId +recXS (RecXX c) = RecSS c + +recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recXX RecId = RecId +recXX (RecXX c) = RecXX c + +recCmp :: Rec b c -> Rec a b -> Rec a c +recCmp RecId r = r +recCmp r RecId = r +recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c') +recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c') +recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c') + +type family IsArray t where + IsArray (Ranked n a) = True + IsArray (Shaped sh a) = True + IsArray (Mixed sh a) = True + IsArray _ = False + +data RSplitCastable a b where + RsplitCastableId + :: RSplitCastable a a + RSplitCastable + :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True + ,Elt c, Elt d, Elt e) + => Rec d e -- possibly a recursive call + -> Castable c d -- middle stuff + -> Castable b c -- right endpoint (no Cmp) + -> RSplitCastable b e + RSplitCastableT2 + :: Castable a a' + -> Castable b b' + -> RSplitCastable (a, b) (a', b') + +rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b +rsplitCastable = \case + CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3)) + + CastCmp c1 c2 -> case rsplitCastable c2 of + RSplitCastable rec mid right -> case c1 of + CastId -> RSplitCastable rec mid right + CastRX | RecEq <- recEq rec -> RSplitCastable (recRX rec) (CastRX `CastCmp` mid) right + CastSX | RecEq <- recEq rec -> RSplitCastable (recSX rec) (CastSX `CastCmp` mid) right + CastXR | RecEq <- recEq rec -> RSplitCastable (recXR rec) (CastXR `CastCmp` mid) right + CastXS | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS `CastCmp` mid) right + CastXS' sh | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS' sh `CastCmp` mid) right + CastXX' ssh | RecEq <- recEq rec -> RSplitCastable (recXX rec) (CastXX' ssh `CastCmp` mid) right + CastRR c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecRR c') rec) mid right + CastSS c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecSS c') rec) mid right + CastXX c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecXX c') rec) mid right + + CastId -> RsplitCastableId + c@CastRX -> RSplitCastable RecId c CastId + c@CastSX -> RSplitCastable RecId c CastId + c@CastXR -> RSplitCastable RecId c CastId + c@CastXS -> RSplitCastable RecId c CastId + c@CastXS'{} -> RSplitCastable RecId c CastId + c@CastXX'{} -> RSplitCastable RecId c CastId + + CastRR c | Dict <- transferElt c -> RSplitCastable (RecRR c) CastId CastId + CastSS c | Dict <- transferElt c -> RSplitCastable (RecSS c) CastId CastId + CastXX c | Dict <- transferElt c -> RSplitCastable (RecXX c) CastId CastId + + CastT2 c1 c2 -> RSplitCastableT2 c1 c2 + + Cast0X -> _ + CastX0 -> _ + +transferElt :: Elt a => Castable a b -> Dict Elt b +transferElt = \case + CastId -> Dict + CastCmp c1 c2 | Dict <- transferElt c2, Dict <- transferElt c1 -> Dict + CastRX -> Dict + CastSX -> Dict + CastXR -> Dict + CastXS -> Dict + CastXS' _ -> Dict + CastXX' _ -> Dict + CastRR c | Dict <- transferElt c -> Dict + CastSS c | Dict <- transferElt c -> Dict + CastXX c | Dict <- transferElt c -> Dict + CastT2 c1 c2 | Dict <- transferElt c1, Dict <- transferElt c2 -> Dict + Cast0X -> Dict + CastX0 -> Dict + castCastable :: (Elt a, Elt b) => Castable a b -> a -> b castCastable = \c x -> munScalar (go c (mscalar x)) where |