From ba5a31c976f80421464af1af8d6ab1e2a154cd83 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Tue, 2 Dec 2025 05:38:41 +0100 Subject: Define mgeneratePrim as a fast special case variant --- src/Data/Array/Nested/Mixed.hs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'src/Data/Array/Nested/Mixed.hs') diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index e3aa7a1..515e867 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. + -- This is likely fine if @a@ is big, but if @a@ is a scalar + -- this array copying is inefficient so it's better to use + -- the @mgeneratePrim@ below. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ @@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs +-- | An optimized special case of `mgenerate', where the function results +-- are of a primitive type and so there's not need to verify the shapes +-- of them all are equal. +mgeneratePrim :: forall sh a. PrimElt a + => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = -- cgit v1.2.3-70-g09d2