aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-02 05:38:41 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-02 05:38:41 +0100
commitba5a31c976f80421464af1af8d6ab1e2a154cd83 (patch)
tree526de2796e98cd238e1dcefc53894780376577e2
parent88828bd004ccba13e227f732106ab30c3731837f (diff)
Define mgeneratePrim as a fast special case variant
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Mixed.hs14
-rw-r--r--src/Data/Array/Nested/Ranked.hs9
-rw-r--r--src/Data/Array/Nested/Shaped.hs3
-rw-r--r--src/Data/Array/Nested/Trace.hs2
5 files changed, 28 insertions, 6 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 0bb6003..f32266c 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,7 +6,7 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1Prim, rsumAllPrim,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
remptyArray,
rrerankPrim,
@@ -36,7 +36,7 @@ module Data.Array.Nested (
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1Prim, ssumAllPrim,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
semptyArray,
@@ -65,7 +65,7 @@ module Data.Array.Nested (
ShX(.., ZSX, (:$%)), KnownShX(..), IShX,
StaticShX(.., ZKX, (:!%)),
SMayNat(..),
- mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1Prim, msumAllPrim,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
memptyArray,
mrerankPrim,
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index e3aa7a1..515e867 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
+ -- This is likely fine if @a@ is big, but if @a@ is a scalar
+ -- this array copying is inefficient so it's better to use
+ -- the @mgeneratePrim@ below.
forM_ restidxs $ \idx -> do
let val = f idx
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
@@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
+-- | An optimized special case of `mgenerate', where the function results
+-- are of a primitive type and so there's not need to verify the shapes
+-- of them all are equal.
+mgeneratePrim :: forall sh a. PrimElt a
+ => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgeneratePrim sh f =
+ let g i = f (ixxFromLinear sh i)
+ in mfromVector sh $ VS.generate (shxSize sh) g
+
msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index bf35cc4..9504247 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -69,6 +69,15 @@ rgenerate sh f
, Refl <- lemRankReplicate sn
= Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX))
+-- TODO: this would be shorter and faster written with rfromVector,
+-- but unfortunately we don't have ixrFromLinear
+rgeneratePrim :: forall n a. PrimElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgeneratePrim sh f
+ | sn@SNat <- shrRank sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgeneratePrim (shxFromShR sh) (f . ixrFromIxX))
+
-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. Elt a
=> SNat n2
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 82dfc91..31a7706 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -72,6 +72,9 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh))
+sgeneratePrim :: forall sh a. PrimElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh))
+
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index f793774..66d2818 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -76,4 +76,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])