From d8d8fc39c6d52b0960c89f38bfa8ec3969a8ca02 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 23:25:36 +0200 Subject: iota --- src/Data/Array/Mixed.hs | 4 ++++ src/Data/Array/Nested.hs | 6 +++--- src/Data/Array/Nested/Internal.hs | 10 ++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 1dc6b58..ef036af 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -782,3 +782,7 @@ reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthShX sh2) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 8d9601e..9a291e6 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -10,7 +10,7 @@ module Data.Array.Nested ( rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rrerank, rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1, - rslice, rrev1, rreshape, + rslice, rrev1, rreshape, riota, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, rlift2, -- ** Conversions @@ -27,7 +27,7 @@ module Data.Array.Nested ( stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, srerank, sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1, - sslice, srev1, sreshape, + sslice, srev1, sreshape, siota, -- ** Lifting orthotope operations to 'Shaped' arrays slift, slift2, -- ** Conversions @@ -43,7 +43,7 @@ module Data.Array.Nested ( mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar, mrerank, mreplicate, mfromListOuter, mfromList1, mtoListOuter, mtoList1, - mslice, mrev1, mreshape, + mslice, mrev1, mreshape, miota, -- ** Lifting orthotope operations to 'Mixed' arrays mlift, mlift2, -- ** Conversions diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index e402f1e..308f8ce 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -55,6 +55,7 @@ import Foreign.Storable (Storable) import GHC.IsList (IsList) import qualified GHC.IsList as IsList import GHC.TypeLits +import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce import Data.Array.Mixed @@ -951,6 +952,9 @@ mreshape sh' arr = (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') arr +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) masXArrayPrimP (M_Primitive sh arr) = (sh, arr) @@ -1489,6 +1493,9 @@ rreshape sh' rarr@(Ranked arr) , Dict <- lemKnownReplicate (snatFromShR sh') = Ranked (mreshape (shCvtRX sh') arr) +riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a +riota n = TypeNats.withSomeSNat (fromIntegral n) $ mtoRanked . miota + rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) @@ -1749,6 +1756,9 @@ srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) -- cgit v1.2.3-70-g09d2