diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 05:38:41 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-02 05:38:41 +0100 |
| commit | ba5a31c976f80421464af1af8d6ab1e2a154cd83 (patch) | |
| tree | 526de2796e98cd238e1dcefc53894780376577e2 | |
| parent | 88828bd004ccba13e227f732106ab30c3731837f (diff) | |
Define mgeneratePrim as a fast special case variant
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Trace.hs | 2 |
5 files changed, 28 insertions, 6 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 0bb6003..f32266c 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -6,7 +6,7 @@ module Data.Array.Nested ( ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), IShR, - rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1Prim, rsumAllPrim, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rgeneratePrim, rsumOuter1Prim, rsumAllPrim, rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, remptyArray, rrerankPrim, @@ -36,7 +36,7 @@ module Data.Array.Nested ( ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), - sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1Prim, ssumAllPrim, + sshape, srank, ssize, sindex, sindexPartial, sgenerate, sgeneratePrim, ssumOuter1Prim, ssumAllPrim, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, -- TODO: sconcat? What should its type be? semptyArray, @@ -65,7 +65,7 @@ module Data.Array.Nested ( ShX(.., ZSX, (:$%)), KnownShX(..), IShX, StaticShX(.., ZKX, (:!%)), SMayNat(..), - mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1Prim, msumAllPrim, + mshape, mrank, msize, mindex, mindexPartial, mgenerate, mgeneratePrim, msumOuter1Prim, msumAllPrim, mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, memptyArray, mrerankPrim, diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index e3aa7a1..515e867 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -739,8 +739,9 @@ mgenerate sh f = case shxEnum sh of else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. + -- This is likely fine if @a@ is big, but if @a@ is a scalar + -- this array copying is inefficient so it's better to use + -- the @mgeneratePrim@ below. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ @@ -748,6 +749,15 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs +-- | An optimized special case of `mgenerate', where the function results +-- are of a primitive type and so there's not need to verify the shapes +-- of them all are equal. +mgeneratePrim :: forall sh a. PrimElt a + => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index bf35cc4..9504247 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -69,6 +69,15 @@ rgenerate sh f , Refl <- lemRankReplicate sn = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) +-- TODO: this would be shorter and faster written with rfromVector, +-- but unfortunately we don't have ixrFromLinear +rgeneratePrim :: forall n a. PrimElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgeneratePrim sh f + | sn@SNat <- shrRank sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgeneratePrim (shxFromShR sh) (f . ixrFromIxX)) + -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. Elt a => SNat n2 diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 82dfc91..31a7706 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -72,6 +72,9 @@ sindexPartial sarr@(Shaped arr) idx = sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +sgeneratePrim :: forall sh a. PrimElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh)) + -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. Elt a => ShS sh2 diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index f793774..66d2818 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -76,4 +76,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rgeneratePrim, 'rsumOuter1Prim, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerankPrim, 'rreplicate, 'rreplicatePrim, 'rfromListOuter, 'rfromListOuterN, 'rfromList1, 'rfromList1N, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'sgeneratePrim, 'ssumOuter1Prim, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerankPrim, 'sreplicate, 'sreplicatePrim, 'sfromListOuter, 'sfromList1, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'mgeneratePrim, 'msumOuter1Prim, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerankPrim, 'mreplicate, 'mreplicatePrim, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) |
