aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs2
3 files changed, 16 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index fa4b4a1..5730354 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -24,6 +24,9 @@ import Control.DeepSeq (NFData(..))
import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.RankedS qualified as S
+import Data.Array.Internal.RankedS qualified as Internal.RankedS
+import Data.Array.Internal.RankedG qualified as Internal.RankedG
+import Data.Array.Internal qualified as Array.Internal
import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
@@ -272,6 +275,9 @@ mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
+data StrideTree =
+ StrideLeaf [Int]
+ | StrideNode StrideTree StrideTree
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
@@ -345,6 +351,8 @@ class Elt a where
mshowShapeTree :: Proxy a -> ShapeTree a -> String
+ mstrideTree :: Mixed sh a -> StrideTree
+
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
@@ -441,6 +449,7 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
+ mstrideTree (M_Primitive _ (XArray (Internal.RankedS.A (Internal.RankedG.A _ (Array.Internal.T ss _ _))))) = StrideLeaf ss
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-- TODO: this use of toVector is suboptimal
@@ -514,6 +523,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
+ mstrideTree (M_Tup2 a b) = StrideNode (mstrideTree a) (mstrideTree b)
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
@@ -646,6 +656,8 @@ instance Elt a => Elt (Mixed sh' a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+ mstrideTree (M_Nest _ arr) = mstrideTree arr
+
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s.
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 0a165bc..c501015 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -144,6 +144,8 @@ instance Elt a => Elt (Ranked n a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+ mstrideTree (M_Ranked arr) = mstrideTree arr
+
mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs =
mvecsWrite sh idx arr
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d7a8ece..eebf66a 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -142,6 +142,8 @@ instance Elt a => Elt (Shaped sh a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+ mstrideTree (M_Shaped arr) = mstrideTree arr
+
mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs =
mvecsWrite sh idx arr