aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-02 23:22:23 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-04 22:57:32 +0200
commit3b4b5dae625a7781abb59b5d0b593fc077507cf2 (patch)
tree0dcafaf02f0f64e2ed534fae12f9d55c4f043517
parent21eb3ead0ced0aaca4fb400e98349b8acb99599a (diff)
WIP simplify Castable
-rw-r--r--src/Data/Array/Nested/Convert.hs121
-rw-r--r--src/Data/Array/Nested/Mixed.hs13
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs2
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs2
4 files changed, 136 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index cea2489..184e6ae 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -1,10 +1,12 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
-- * Shape/index/list casting functions
@@ -15,6 +17,7 @@ module Data.Array.Nested.Convert (
-- * Array conversions
castCastable,
Castable(..),
+ transferEltCastable,
-- * Special cases of array conversions
--
@@ -155,6 +158,124 @@ instance Category Castable where
id = CastId
(.) = CastCmp
+data Rec a b where
+ RecId :: Rec a a
+ RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b)
+ RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b)
+ RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b)
+
+data RecEq t f b where
+ RecEq :: RecEq (f a) f b
+
+recEq :: Rec t (f b) -> RecEq t f b
+recEq RecId = RecEq
+recEq RecRR{} = RecEq
+recEq RecSS{} = RecEq
+recEq RecXX{} = RecEq
+
+recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b)
+recRX RecId = RecId
+recRX (RecRR c) = RecXX c
+
+recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
+recSX RecId = RecId
+recSX (RecSS c) = RecXX c
+
+recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b)
+recXR RecId = RecId
+recXR (RecXX c) = RecRR c
+
+recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b)
+recXS RecId = RecId
+recXS (RecXX c) = RecSS c
+
+recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b)
+recXX RecId = RecId
+recXX (RecXX c) = RecXX c
+
+recCmp :: Rec b c -> Rec a b -> Rec a c
+recCmp RecId r = r
+recCmp r RecId = r
+recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c')
+recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c')
+recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c')
+
+type family IsArray t where
+ IsArray (Ranked n a) = True
+ IsArray (Shaped sh a) = True
+ IsArray (Mixed sh a) = True
+ IsArray _ = False
+
+data RSplitCastable a b where
+ RsplitCastableId
+ :: RSplitCastable a a
+ RSplitCastable
+ :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True
+ ,Elt c, Elt d, Elt e)
+ => Rec d e -- possibly a recursive call
+ -> Castable c d -- middle stuff
+ -> Castable b c -- right endpoint (no Cmp)
+ -> RSplitCastable b e
+ RSplitCastableT2
+ :: Castable a a'
+ -> Castable b b'
+ -> RSplitCastable (a, b) (a', b')
+
+rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b
+rsplitCastable = \case
+ CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3))
+
+ CastCmp c1 c2 -> case rsplitCastable c2 of
+ RSplitCastable rec mid right -> case c1 of
+ CastId -> RSplitCastable rec mid right
+ CastRX | RecEq <- recEq rec -> RSplitCastable (recRX rec) (CastRX `CastCmp` mid) right
+ CastSX | RecEq <- recEq rec -> RSplitCastable (recSX rec) (CastSX `CastCmp` mid) right
+ CastXR | RecEq <- recEq rec -> RSplitCastable (recXR rec) (CastXR `CastCmp` mid) right
+ CastXS | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS `CastCmp` mid) right
+ CastXS' sh | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS' sh `CastCmp` mid) right
+ CastXX' ssh | RecEq <- recEq rec -> RSplitCastable (recXX rec) (CastXX' ssh `CastCmp` mid) right
+ CastRR c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecRR c') rec) mid right
+ CastSS c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecSS c') rec) mid right
+ CastXX c' | Dict <- transferEltCastable c' -> RSplitCastable (recCmp (RecXX c') rec) mid right
+
+ CastId -> RsplitCastableId
+ c@CastRX -> RSplitCastable RecId c CastId
+ c@CastSX -> RSplitCastable RecId c CastId
+ c@CastXR -> RSplitCastable RecId c CastId
+ c@CastXS -> RSplitCastable RecId c CastId
+ c@CastXS'{} -> RSplitCastable RecId c CastId
+ c@CastXX'{} -> RSplitCastable RecId c CastId
+
+ CastRR c | Dict <- transferEltCastable c -> RSplitCastable (RecRR c) CastId CastId
+ CastSS c | Dict <- transferEltCastable c -> RSplitCastable (RecSS c) CastId CastId
+ CastXX c | Dict <- transferEltCastable c -> RSplitCastable (RecXX c) CastId CastId
+
+ CastT2 c1 c2 -> RSplitCastableT2 c1 c2
+
+ Cast0X -> undefined
+ CastX0 -> undefined
+
+transferEltCastable :: Elt a => Castable a b -> Dict Elt b
+transferEltCastable = \case
+ CastId -> Dict
+ CastCmp c1 c2 | Dict <- transferEltCastable c2, Dict <- transferEltCastable c1 -> Dict
+ CastRX -> Dict
+ CastSX -> Dict
+ CastXR -> Dict
+ CastXS -> Dict
+ CastXS' _ -> Dict
+ CastXX' _ -> Dict
+ CastRR c | Dict <- transferEltCastable c -> Dict
+ CastSS c | Dict <- transferEltCastable c -> Dict
+ CastXX c | Dict <- transferEltCastable c -> Dict
+ CastT2 c1 c2 | Dict <- transferEltCastable c1, Dict <- transferEltCastable c2 -> Dict
+ Cast0X -> Dict
+ CastX0 -> Dict
+ CastNest _ -> Dict
+ CastUnnest -> Dict
+ CastZip -> Dict
+ CastUnzip -> Dict
+
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
where
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54f8fe6..0a2fc17 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Nested.Mixed where
@@ -29,7 +30,7 @@ import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Type)
+import Data.Kind (Type, Constraint)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -290,7 +291,9 @@ matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
-- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
+class EltC a => Elt a where
+ type EltC a :: Constraint
+
-- ====== PUBLIC METHODS ====== --
mshape :: Mixed sh a -> IShX sh
@@ -396,6 +399,8 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
+ type EltC (Primitive a) = Storable a
+
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
@@ -500,6 +505,8 @@ deriving via Primitive () instance KnownElt ()
-- Arrays of pairs are pairs of arrays.
instance (Elt a, Elt b) => Elt (a, b) where
+ type EltC (a, b) = (Elt a, Elt b)
+
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
@@ -549,6 +556,8 @@ instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
-- Arrays of arrays are just arrays, but with more dimensions.
instance Elt a => Elt (Mixed sh' a) where
+ type EltC (Mixed sh' a) = Elt a
+
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index babc809..e4305e5 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -87,6 +87,8 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
instance Elt a => Elt (Ranked n a) where
+ type EltC (Ranked n a) = Elt a
+
mshape (M_Ranked arr) = mshape arr
mindex (M_Ranked arr) i = Ranked (mindex arr i)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index ddd44bf..5c45abd 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -80,6 +80,8 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
instance Elt a => Elt (Shaped sh a) where
+ type EltC (Shaped sh a) = Elt a
+
mshape (M_Shaped arr) = mshape arr
mindex (M_Shaped arr) i = Shaped (mindex arr i)