From 690a74d571c61330978fdf5e4565ce0b8622030b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 16:27:02 +0200 Subject: unScalar, constant --- src/Data/Array/Nested/Internal.hs | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c85b1fc..4764165 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -377,6 +377,17 @@ mfromVector sh v | otherwise = M_Primitive (X.fromVector sh v) +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr IZX + +mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) + => IxX sh -> a -> Mixed sh a +mconstant sh x + | not (checkBounds sh (knownShapeX @sh)) = + error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) + | otherwise = + coerce (M_Primitive (X.constant sh x)) + mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -399,7 +410,7 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe case X.ssxToShape' (knownShapeX @sh) of Just sh -> M_Primitive (X.constant sh (fromInteger n)) Nothing -> error "Data.Array.Nested.fromIntegral: \ - \Unknown components in shape, use explicit replicate" + \Unknown components in shape, use explicit mconstant" deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) @@ -732,6 +743,15 @@ rfromVector sh v | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromVector (ixCvtRX sh) v) +runScalar :: Elt a => Ranked I0 a -> a +runScalar arr = rindex arr IZR + +rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) + => IxR n -> a -> Ranked n a +rconstant sh x + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mconstant (ixCvtRX sh) x) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -795,10 +815,10 @@ sindexPartial (Shaped arr) idx = (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a -sgenerate sh f +sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a +sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) + = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -834,3 +854,12 @@ sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sfromVector v | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr IZS + +sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) + => a -> Shaped sh a +sconstant x + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) -- cgit v1.2.3-70-g09d2