diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index e3aa7a1..515e867 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. + -- This is likely fine if @a@ is big, but if @a@ is a scalar + -- this array copying is inefficient so it's better to use + -- the @mgeneratePrim@ below. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ @@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs +-- | An optimized special case of `mgenerate', where the function results +-- are of a primitive type and so there's not need to verify the shapes +-- of them all are equal. +mgeneratePrim :: forall sh a. PrimElt a + => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = |
