aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs4
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs10
3 files changed, 17 insertions, 3 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 1dc6b58..ef036af 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -782,3 +782,7 @@ reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
, Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
= XArray (S.reshape (shapeLshape sh2 ++ drop (lengthShX sh2) (S.shapeL arr)) arr)
+
+-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
+iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
+iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)]))
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 8d9601e..9a291e6 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -10,7 +10,7 @@ module Data.Array.Nested (
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1,
- rslice, rrev1, rreshape,
+ rslice, rrev1, rreshape, riota,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
-- ** Conversions
@@ -27,7 +27,7 @@ module Data.Array.Nested (
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
srerank,
sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1,
- sslice, srev1, sreshape,
+ sslice, srev1, sreshape, siota,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
-- ** Conversions
@@ -43,7 +43,7 @@ module Data.Array.Nested (
mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mfromListOuter, mfromList1, mtoListOuter, mtoList1,
- mslice, mrev1, mreshape,
+ mslice, mrev1, mreshape, miota,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
-- ** Conversions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index e402f1e..308f8ce 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -55,6 +55,7 @@ import Foreign.Storable (Storable)
import GHC.IsList (IsList)
import qualified GHC.IsList as IsList
import GHC.TypeLits
+import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce
import Data.Array.Mixed
@@ -951,6 +952,9 @@ mreshape sh' arr =
(\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh')
arr
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
@@ -1489,6 +1493,9 @@ rreshape sh' rarr@(Ranked arr)
, Dict <- lemKnownReplicate (snatFromShR sh')
= Ranked (mreshape (shCvtRX sh') arr)
+riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
+riota n = TypeNats.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr)
@@ -1749,6 +1756,9 @@ srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)