aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs14
1 files changed, 12 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index e3aa7a1..515e867 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
+ -- This is likely fine if @a@ is big, but if @a@ is a scalar
+ -- this array copying is inefficient so it's better to use
+ -- the @mgeneratePrim@ below.
forM_ restidxs $ \idx -> do
let val = f idx
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
@@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
+-- | An optimized special case of `mgenerate', where the function results
+-- are of a primitive type and so there's not need to verify the shapes
+-- of them all are equal.
+mgeneratePrim :: forall sh a. PrimElt a
+ => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgeneratePrim sh f =
+ let g i = f (ixxFromLinear sh i)
+ in mfromVector sh $ VS.generate (shxSize sh) g
+
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =