summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 16:27:02 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 16:27:02 +0200
commit690a74d571c61330978fdf5e4565ce0b8622030b (patch)
tree92f716c175d7d201d7731c4ff50bc574d2f3926e /src/Data/Array/Nested
parent754526c1ed56d3eb10106af3e9981863ef8c9d0b (diff)
unScalar, constant
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs37
1 files changed, 33 insertions, 4 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index c85b1fc..4764165 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -377,6 +377,17 @@ mfromVector sh v
| otherwise =
M_Primitive (X.fromVector sh v)
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr IZX
+
+mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
+ => IxX sh -> a -> Mixed sh a
+mconstant sh x
+ | not (checkBounds sh (knownShapeX @sh)) =
+ error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
+ | otherwise =
+ coerce (M_Primitive (X.constant sh x))
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -399,7 +410,7 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe
case X.ssxToShape' (knownShapeX @sh) of
Just sh -> M_Primitive (X.constant sh (fromInteger n))
Nothing -> error "Data.Array.Nested.fromIntegral: \
- \Unknown components in shape, use explicit replicate"
+ \Unknown components in shape, use explicit mconstant"
deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)
deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double)
@@ -732,6 +743,15 @@ rfromVector sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVector (ixCvtRX sh) v)
+runScalar :: Elt a => Ranked I0 a -> a
+runScalar arr = rindex arr IZR
+
+rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
+ => IxR n -> a -> Ranked n a
+rconstant sh x
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mconstant (ixCvtRX sh) x)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -795,10 +815,10 @@ sindexPartial (Shaped arr) idx =
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
-sgenerate sh f
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a
+sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
+ = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
=> (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
@@ -834,3 +854,12 @@ sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped
sfromVector v
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr IZS
+
+sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
+ => a -> Shaped sh a
+sconstant x
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)