aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-02 23:22:23 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-02 23:22:39 +0200
commit8344417a8ed9ad59337ce3b880e2fadf89bec964 (patch)
treef0e4fdc9d7a272f4f76dded3aae9871d4c3b7c66 /src/Data/Array/Nested/Convert.hs
parent75ee1572b75b45dcdc50e3af82ed50259ca77df0 (diff)
WIP simplify Castable
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs116
1 files changed, 116 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index fe590d1..4c9c9a1 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -1,10 +1,12 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
-- * Shape/index/list casting functions
@@ -140,6 +142,120 @@ instance Category Castable where
id = CastId
(.) = CastCmp
+data Rec a b where
+ RecId :: Rec a a
+ RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b)
+ RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b)
+ RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b)
+
+data RecEq t f b where
+ RecEq :: RecEq (f a) f b
+
+recEq :: Rec t (f b) -> RecEq t f b
+recEq RecId = RecEq
+recEq RecRR{} = RecEq
+recEq RecSS{} = RecEq
+recEq RecXX{} = RecEq
+
+recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b)
+recRX RecId = RecId
+recRX (RecRR c) = RecXX c
+
+recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
+recSX RecId = RecId
+recSX (RecSS c) = RecXX c
+
+recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b)
+recXR RecId = RecId
+recXR (RecXX c) = RecRR c
+
+recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b)
+recXS RecId = RecId
+recXS (RecXX c) = RecSS c
+
+recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
+recXX RecId = RecId
+recXX (RecXX c) = RecXX c
+
+recCmp :: Rec b c -> Rec a b -> Rec a c
+recCmp RecId r = r
+recCmp r RecId = r
+recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c')
+recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c')
+recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c')
+
+type family IsArray t where
+ IsArray (Ranked n a) = True
+ IsArray (Shaped sh a) = True
+ IsArray (Mixed sh a) = True
+ IsArray _ = False
+
+data RSplitCastable a b where
+ RsplitCastableId
+ :: RSplitCastable a a
+ RSplitCastable
+ :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True
+ ,Elt c, Elt d, Elt e)
+ => Rec d e -- possibly a recursive call
+ -> Castable c d -- middle stuff
+ -> Castable b c -- right endpoint (no Cmp)
+ -> RSplitCastable b e
+ RSplitCastableT2
+ :: Castable a a'
+ -> Castable b b'
+ -> RSplitCastable (a, b) (a', b')
+
+rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b
+rsplitCastable = \case
+ CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3))
+
+ CastCmp c1 c2 -> case rsplitCastable c2 of
+ RSplitCastable rec mid right -> case c1 of
+ CastId -> RSplitCastable rec mid right
+ CastRX | RecEq <- recEq rec -> RSplitCastable (recRX rec) (CastRX `CastCmp` mid) right
+ CastSX | RecEq <- recEq rec -> RSplitCastable (recSX rec) (CastSX `CastCmp` mid) right
+ CastXR | RecEq <- recEq rec -> RSplitCastable (recXR rec) (CastXR `CastCmp` mid) right
+ CastXS | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS `CastCmp` mid) right
+ CastXS' sh | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS' sh `CastCmp` mid) right
+ CastXX' ssh | RecEq <- recEq rec -> RSplitCastable (recXX rec) (CastXX' ssh `CastCmp` mid) right
+ CastRR c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecRR c') rec) mid right
+ CastSS c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecSS c') rec) mid right
+ CastXX c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecXX c') rec) mid right
+
+ CastId -> RsplitCastableId
+ c@CastRX -> RSplitCastable RecId c CastId
+ c@CastSX -> RSplitCastable RecId c CastId
+ c@CastXR -> RSplitCastable RecId c CastId
+ c@CastXS -> RSplitCastable RecId c CastId
+ c@CastXS'{} -> RSplitCastable RecId c CastId
+ c@CastXX'{} -> RSplitCastable RecId c CastId
+
+ CastRR c | Dict <- transferElt c -> RSplitCastable (RecRR c) CastId CastId
+ CastSS c | Dict <- transferElt c -> RSplitCastable (RecSS c) CastId CastId
+ CastXX c | Dict <- transferElt c -> RSplitCastable (RecXX c) CastId CastId
+
+ CastT2 c1 c2 -> RSplitCastableT2 c1 c2
+
+ Cast0X -> _
+ CastX0 -> _
+
+transferElt :: Elt a => Castable a b -> Dict Elt b
+transferElt = \case
+ CastId -> Dict
+ CastCmp c1 c2 | Dict <- transferElt c2, Dict <- transferElt c1 -> Dict
+ CastRX -> Dict
+ CastSX -> Dict
+ CastXR -> Dict
+ CastXS -> Dict
+ CastXS' _ -> Dict
+ CastXX' _ -> Dict
+ CastRR c | Dict <- transferElt c -> Dict
+ CastSS c | Dict <- transferElt c -> Dict
+ CastXX c | Dict <- transferElt c -> Dict
+ CastT2 c1 c2 | Dict <- transferElt c1, Dict <- transferElt c2 -> Dict
+ Cast0X -> Dict
+ CastX0 -> Dict
+
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
where