diff options
author | Mikolaj Konarski <mikolaj.konarski@gmail.com> | 2025-04-25 23:45:48 +0200 |
---|---|---|
committer | Mikolaj Konarski <mikolaj.konarski@gmail.com> | 2025-04-25 23:45:48 +0200 |
commit | 4087d405b51cf32363cb7507df6ffe1a170c0f7f (patch) | |
tree | fa9ec2db7f8d02c1e5cb4e10978e1871632e0441 /src/Data/Array/Nested/Internal/Mixed.hs | |
parent | 0121c2480b2ea8d34c9d293941283e6c8e2e09dc (diff) |
Add mstrideTree and StrideTree
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index fa4b4a1..5730354 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -24,6 +24,9 @@ import Control.DeepSeq (NFData(..)) import Control.Monad (forM_, when) import Control.Monad.ST import Data.Array.RankedS qualified as S +import Data.Array.Internal.RankedS qualified as Internal.RankedS +import Data.Array.Internal.RankedG qualified as Internal.RankedG +import Data.Array.Internal qualified as Array.Internal import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) @@ -272,6 +275,9 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) +data StrideTree = + StrideLeaf [Int] + | StrideNode StrideTree StrideTree -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or -- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' @@ -345,6 +351,8 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String + mstrideTree :: Mixed sh a -> StrideTree + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () @@ -441,6 +449,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" + mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal @@ -514,6 +523,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b) mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -646,6 +656,8 @@ instance Elt a => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mstrideTree (M_Nest _ arr) = mstrideTree arr + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. |