aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-02 14:19:04 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-02 14:19:04 +0100
commita7b64fe342524e82194d73af852b5f2f1bc5bab3 (patch)
tree73fb2b9484d33e7400e92a67388227a75adf56c9 /src/Data/Array/Nested/Mixed.hs
parent9f47aa6a2bcd772388a5d5150ca7254e4eb95bc2 (diff)
Generalize also mgenerate to potentially avoid @fmap fromIntegral@mgenerate-integral
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs21
1 files changed, 11 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index d658ed3..cad5c4e 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -369,11 +369,11 @@ class Elt a where
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWrite :: Integral i => IShX sh -> IxX sh i -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartial :: Integral i => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
@@ -468,8 +468,8 @@ instance Storable a => Elt (Primitive a) where
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ :: forall sh' sh s i. Integral i
+ => IShX (sh ++ sh') -> IxX sh i -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
@@ -678,8 +678,8 @@ instance Elt a => Elt (Mixed sh' a) where
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mvecsWritePartial :: forall sh1 sh2 s i. Integral i
+ => IShX (sh1 ++ sh2) -> IxX sh1 i -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
@@ -728,8 +728,10 @@ msize = shxSize . mshape
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case shxEnum sh of
+{-# INLINEABLE mgenerate #-}
+mgenerate :: forall sh a i. (KnownElt a, Integral i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum' sh of
[] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
@@ -751,8 +753,7 @@ mgenerate sh f = case shxEnum sh of
-- | An optimized special case of `mgenerate', where the function results
-- are of a primitive type and so there's not need to verify the shapes
--- of them all are equal. This is also generalized to aribitrary @Num@ index
--- type compared to @mgenerate@.
+-- of them all are equal.
{-# INLINE mgeneratePrim #-}
mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
=> IShX sh -> (IxX sh i -> a) -> Mixed sh a