diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:27:02 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:27:02 +0200 |
commit | 690a74d571c61330978fdf5e4565ce0b8622030b (patch) | |
tree | 92f716c175d7d201d7731c4ff50bc574d2f3926e | |
parent | 754526c1ed56d3eb10106af3e9981863ef8c9d0b (diff) |
unScalar, constant
-rw-r--r-- | src/Data/Array/Nested.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 |
2 files changed, 39 insertions, 7 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 23d20eb..cd2dde7 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,8 @@ module Data.Array.Nested ( Ranked, IxR(..), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, - rtranspose, rappend, rscalar, rfromVector, + rtranspose, rappend, rscalar, rfromVector, runScalar, + rconstant, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -13,7 +14,8 @@ module Data.Array.Nested ( IxS(..), KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, - stranspose, sappend, sscalar, sfromVector, + stranspose, sappend, sscalar, sfromVector, sunScalar, + sconstant, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -21,7 +23,8 @@ module Data.Array.Nested ( Mixed, IxX(..), KnownShapeX(..), StaticShapeX(..), - mgenerate, mtranspose, mappend, mfromVector, + mgenerate, mtranspose, mappend, mfromVector, munScalar, + mconstant, -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c85b1fc..4764165 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -377,6 +377,17 @@ mfromVector sh v | otherwise = M_Primitive (X.fromVector sh v) +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr IZX + +mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) + => IxX sh -> a -> Mixed sh a +mconstant sh x + | not (checkBounds sh (knownShapeX @sh)) = + error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) + | otherwise = + coerce (M_Primitive (X.constant sh x)) + mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -399,7 +410,7 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe case X.ssxToShape' (knownShapeX @sh) of Just sh -> M_Primitive (X.constant sh (fromInteger n)) Nothing -> error "Data.Array.Nested.fromIntegral: \ - \Unknown components in shape, use explicit replicate" + \Unknown components in shape, use explicit mconstant" deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) @@ -732,6 +743,15 @@ rfromVector sh v | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromVector (ixCvtRX sh) v) +runScalar :: Elt a => Ranked I0 a -> a +runScalar arr = rindex arr IZR + +rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) + => IxR n -> a -> Ranked n a +rconstant sh x + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mconstant (ixCvtRX sh) x) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -795,10 +815,10 @@ sindexPartial (Shaped arr) idx = (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a -sgenerate sh f +sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a +sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) + = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -834,3 +854,12 @@ sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sfromVector v | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr IZS + +sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) + => a -> Shaped sh a +sconstant x + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) |