From 4087d405b51cf32363cb7507df6ffe1a170c0f7f Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Fri, 25 Apr 2025 23:45:48 +0200 Subject: Add mstrideTree and StrideTree --- src/Data/Array/Nested/Internal/Mixed.hs | 12 ++++++++++++ src/Data/Array/Nested/Internal/Ranked.hs | 2 ++ src/Data/Array/Nested/Internal/Shaped.hs | 2 ++ 3 files changed, 16 insertions(+) diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index fa4b4a1..5730354 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -24,6 +24,9 @@ import Control.DeepSeq (NFData(..)) import Control.Monad (forM_, when) import Control.Monad.ST import Data.Array.RankedS qualified as S +import Data.Array.Internal.RankedS qualified as Internal.RankedS +import Data.Array.Internal.RankedG qualified as Internal.RankedG +import Data.Array.Internal qualified as Array.Internal import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) @@ -272,6 +275,9 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) +data StrideTree = + StrideLeaf [Int] + | StrideNode StrideTree StrideTree -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' @@ -345,6 +351,8 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String + mstrideTree :: Mixed sh a -> StrideTree + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () @@ -441,6 +449,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" + mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal @@ -514,6 +523,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b) mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -646,6 +656,8 @@ instance Elt a => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mstrideTree (M_Nest _ arr) = mstrideTree arr + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 0a165bc..c501015 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -144,6 +144,8 @@ instance Elt a => Elt (Ranked n a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mstrideTree (M_Ranked arr) = mstrideTree arr + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs = mvecsWrite sh idx arr diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d7a8ece..eebf66a 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -142,6 +142,8 @@ instance Elt a => Elt (Shaped sh a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mstrideTree (M_Shaped arr) = mstrideTree arr + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs = mvecsWrite sh idx arr -- cgit v1.2.3-70-g09d2