diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 05:38:41 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 05:38:41 +0100 |
| commit | ba5a31c976f80421464af1af8d6ab1e2a154cd83 (patch) | |
| tree | 526de2796e98cd238e1dcefc53894780376577e2 /src/Data/Array/Nested/Mixed.hs | |
| parent | 88828bd004ccba13e227f732106ab30c3731837f (diff) | |
Define mgeneratePrim as a fast special case variant
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index e3aa7a1..515e867 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. + -- This is likely fine if @a@ is big, but if @a@ is a scalar + -- this array copying is inefficient so it's better to use + -- the @mgeneratePrim@ below. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ @@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs +-- | An optimized special case of `mgenerate', where the function results +-- are of a primitive type and so there's not need to verify the shapes +-- of them all are equal. +mgeneratePrim :: forall sh a. PrimElt a + => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = |
