From 8344417a8ed9ad59337ce3b880e2fadf89bec964 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 2 Jun 2025 23:22:23 +0200 Subject: WIP simplify Castable --- src/Data/Array/Nested/Convert.hs | 116 +++++++++++++++++++++++++++++++++++ src/Data/Array/Nested/Mixed.hs | 13 +++- src/Data/Array/Nested/Ranked/Base.hs | 2 + src/Data/Array/Nested/Shaped/Base.hs | 2 + 4 files changed, 131 insertions(+), 2 deletions(-) diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index fe590d1..4c9c9a1 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,10 +1,12 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Data.Array.Nested.Convert ( -- * Shape/index/list casting functions @@ -140,6 +142,120 @@ instance Category Castable where id = CastId (.) = CastCmp +data Rec a b where + RecId :: Rec a a + RecRR :: Castable a b -> Rec (Ranked n a) (Ranked n b) + RecSS :: Castable a b -> Rec (Shaped sh a) (Shaped sh b) + RecXX :: Castable a b -> Rec (Mixed sh a) (Mixed sh b) + +data RecEq t f b where + RecEq :: RecEq (f a) f b + +recEq :: Rec t (f b) -> RecEq t f b +recEq RecId = RecEq +recEq RecRR{} = RecEq +recEq RecSS{} = RecEq +recEq RecXX{} = RecEq + +recRX :: Rec (f a) (Ranked n b) -> Rec (Mixed sh a) (Mixed sh b) +recRX RecId = RecId +recRX (RecRR c) = RecXX c + +recSX :: Rec (f a) (Shaped sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recSX RecId = RecId +recSX (RecSS c) = RecXX c + +recXR :: Rec (f a) (Mixed sh b) -> Rec (Ranked n a) (Ranked n b) +recXR RecId = RecId +recXR (RecXX c) = RecRR c + +recXS :: Rec (f a) (Mixed sh b) -> Rec (Shaped sh' a) (Shaped sh' b) +recXS RecId = RecId +recXS (RecXX c) = RecSS c + +recXX :: Rec (f a) (Mixed sh b) -> Rec (Mixed sh' a) (Mixed sh' b) +recXX RecId = RecId +recXX (RecXX c) = RecXX c + +recCmp :: Rec b c -> Rec a b -> Rec a c +recCmp RecId r = r +recCmp r RecId = r +recCmp (RecRR c) (RecRR c') = RecRR (CastCmp c c') +recCmp (RecSS c) (RecSS c') = RecSS (CastCmp c c') +recCmp (RecXX c) (RecXX c') = RecXX (CastCmp c c') + +type family IsArray t where + IsArray (Ranked n a) = True + IsArray (Shaped sh a) = True + IsArray (Mixed sh a) = True + IsArray _ = False + +data RSplitCastable a b where + RsplitCastableId + :: RSplitCastable a a + RSplitCastable + :: (IsArray b ~ True, IsArray c ~ True, IsArray d ~ True, IsArray e ~ True + ,Elt c, Elt d, Elt e) + => Rec d e -- possibly a recursive call + -> Castable c d -- middle stuff + -> Castable b c -- right endpoint (no Cmp) + -> RSplitCastable b e + RSplitCastableT2 + :: Castable a a' + -> Castable b b' + -> RSplitCastable (a, b) (a', b') + +rsplitCastable :: Elt a => Castable a b -> RSplitCastable a b +rsplitCastable = \case + CastCmp (CastCmp c1 c2) c3 -> rsplitCastable (CastCmp c1 (CastCmp c2 c3)) + + CastCmp c1 c2 -> case rsplitCastable c2 of + RSplitCastable rec mid right -> case c1 of + CastId -> RSplitCastable rec mid right + CastRX | RecEq <- recEq rec -> RSplitCastable (recRX rec) (CastRX `CastCmp` mid) right + CastSX | RecEq <- recEq rec -> RSplitCastable (recSX rec) (CastSX `CastCmp` mid) right + CastXR | RecEq <- recEq rec -> RSplitCastable (recXR rec) (CastXR `CastCmp` mid) right + CastXS | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS `CastCmp` mid) right + CastXS' sh | RecEq <- recEq rec -> RSplitCastable (recXS rec) (CastXS' sh `CastCmp` mid) right + CastXX' ssh | RecEq <- recEq rec -> RSplitCastable (recXX rec) (CastXX' ssh `CastCmp` mid) right + CastRR c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecRR c') rec) mid right + CastSS c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecSS c') rec) mid right + CastXX c' | Dict <- transferElt c' -> RSplitCastable (recCmp (RecXX c') rec) mid right + + CastId -> RsplitCastableId + c@CastRX -> RSplitCastable RecId c CastId + c@CastSX -> RSplitCastable RecId c CastId + c@CastXR -> RSplitCastable RecId c CastId + c@CastXS -> RSplitCastable RecId c CastId + c@CastXS'{} -> RSplitCastable RecId c CastId + c@CastXX'{} -> RSplitCastable RecId c CastId + + CastRR c | Dict <- transferElt c -> RSplitCastable (RecRR c) CastId CastId + CastSS c | Dict <- transferElt c -> RSplitCastable (RecSS c) CastId CastId + CastXX c | Dict <- transferElt c -> RSplitCastable (RecXX c) CastId CastId + + CastT2 c1 c2 -> RSplitCastableT2 c1 c2 + + Cast0X -> _ + CastX0 -> _ + +transferElt :: Elt a => Castable a b -> Dict Elt b +transferElt = \case + CastId -> Dict + CastCmp c1 c2 | Dict <- transferElt c2, Dict <- transferElt c1 -> Dict + CastRX -> Dict + CastSX -> Dict + CastXR -> Dict + CastXS -> Dict + CastXS' _ -> Dict + CastXX' _ -> Dict + CastRR c | Dict <- transferElt c -> Dict + CastSS c | Dict <- transferElt c -> Dict + CastXX c | Dict <- transferElt c -> Dict + CastT2 c1 c2 | Dict <- transferElt c1, Dict <- transferElt c2 -> Dict + Cast0X -> Dict + CastX0 -> Dict + castCastable :: (Elt a, Elt b) => Castable a b -> a -> b castCastable = \c x -> munScalar (go c (mscalar x)) where diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 221393f..8d25bc3 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -16,6 +16,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE ViewPatterns #-} module Data.Array.Nested.Mixed where @@ -29,7 +30,7 @@ import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) import Data.Int -import Data.Kind (Type) +import Data.Kind (Type, Constraint) import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty qualified as NE import Data.Proxy @@ -290,7 +291,9 @@ matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' -- a@; see the documentation for 'Primitive' for more details. -class Elt a where +class EltC a => Elt a where + type EltC a :: Constraint + -- ====== PUBLIC METHODS ====== -- mshape :: Mixed sh a -> IShX sh @@ -396,6 +399,8 @@ class Elt a => KnownElt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + type EltC (Primitive a) = Storable a + mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) @@ -500,6 +505,8 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + type EltC (a, b) = (Elt a, Elt b) + mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) @@ -549,6 +556,8 @@ instance (KnownElt a, KnownElt b) => KnownElt (a, b) where -- Arrays of arrays are just arrays, but with more dimensions. instance Elt a => Elt (Mixed sh' a) where + type EltC (Mixed sh' a) = Elt a + -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index babc809..e4305e5 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -87,6 +87,8 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. instance Elt a => Elt (Ranked n a) where + type EltC (Ranked n a) = Elt a + mshape (M_Ranked arr) = mshape arr mindex (M_Ranked arr) i = Ranked (mindex arr i) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index ddd44bf..5c45abd 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -80,6 +80,8 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) instance Elt a => Elt (Shaped sh a) where + type EltC (Shaped sh a) = Elt a + mshape (M_Shaped arr) = mshape arr mindex (M_Shaped arr) i = Shaped (mindex arr i) -- cgit v1.2.3-70-g09d2