aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs9
-rw-r--r--src/Data/Array/Nested/Internal.hs37
2 files changed, 39 insertions, 7 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 23d20eb..cd2dde7 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -4,7 +4,8 @@ module Data.Array.Nested (
Ranked,
IxR(..),
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
- rtranspose, rappend, rscalar, rfromVector,
+ rtranspose, rappend, rscalar, rfromVector, runScalar,
+ rconstant,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
@@ -13,7 +14,8 @@ module Data.Array.Nested (
IxS(..),
KnownShape(..), SShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
- stranspose, sappend, sscalar, sfromVector,
+ stranspose, sappend, sscalar, sfromVector, sunScalar,
+ sconstant,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
@@ -21,7 +23,8 @@ module Data.Array.Nested (
Mixed,
IxX(..),
KnownShapeX(..), StaticShapeX(..),
- mgenerate, mtranspose, mappend, mfromVector,
+ mgenerate, mtranspose, mappend, mfromVector, munScalar,
+ mconstant,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index c85b1fc..4764165 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -377,6 +377,17 @@ mfromVector sh v
| otherwise =
M_Primitive (X.fromVector sh v)
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr IZX
+
+mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
+ => IxX sh -> a -> Mixed sh a
+mconstant sh x
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ | otherwise =
+ coerce (M_Primitive (X.constant sh x))
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -399,7 +410,7 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe
case X.ssxToShape' (knownShapeX @sh) of
Just sh -> M_Primitive (X.constant sh (fromInteger n))
Nothing -> error "Data.Array.Nested.fromIntegral: \
- \Unknown components in shape, use explicit replicate"
+ \Unknown components in shape, use explicit mconstant"
deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
@@ -732,6 +743,15 @@ rfromVector sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVector (ixCvtRX sh) v)
+runScalar :: Elt a => Ranked I0 a -> a
+runScalar arr = rindex arr IZR
+
+rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
+ => IxR n -> a -> Ranked n a
+rconstant sh x
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mconstant (ixCvtRX sh) x)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -795,10 +815,10 @@ sindexPartial (Shaped arr) idx =
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
-sgenerate sh f
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a
+sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
+ = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -834,3 +854,12 @@ sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped
sfromVector v
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr IZS
+
+sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
+ => a -> Shaped sh a
+sconstant x
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)