diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:27:02 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:27:02 +0200 | 
| commit | 690a74d571c61330978fdf5e4565ce0b8622030b (patch) | |
| tree | 92f716c175d7d201d7731c4ff50bc574d2f3926e | |
| parent | 754526c1ed56d3eb10106af3e9981863ef8c9d0b (diff) | |
unScalar, constant
| -rw-r--r-- | src/Data/Array/Nested.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 | 
2 files changed, 39 insertions, 7 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 23d20eb..cd2dde7 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,8 @@ module Data.Array.Nested (    Ranked,    IxR(..),    rshape, rindex, rindexPartial, rgenerate, rsumOuter1, -  rtranspose, rappend, rscalar, rfromVector, +  rtranspose, rappend, rscalar, rfromVector, runScalar, +  rconstant,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -13,7 +14,8 @@ module Data.Array.Nested (    IxS(..),    KnownShape(..), SShape(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1, -  stranspose, sappend, sscalar, sfromVector, +  stranspose, sappend, sscalar, sfromVector, sunScalar, +  sconstant,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -21,7 +23,8 @@ module Data.Array.Nested (    Mixed,    IxX(..),    KnownShapeX(..), StaticShapeX(..), -  mgenerate, mtranspose, mappend, mfromVector, +  mgenerate, mtranspose, mappend, mfromVector, munScalar, +  mconstant,    -- * Array elements    Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index c85b1fc..4764165 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -377,6 +377,17 @@ mfromVector sh v    | otherwise =        M_Primitive (X.fromVector sh v) +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr IZX + +mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) +          => IxX sh -> a -> Mixed sh a +mconstant sh x +  | not (checkBounds sh (knownShapeX @sh)) = +      error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) +  | otherwise = +      coerce (M_Primitive (X.constant sh x)) +  mliftPrim :: (KnownShapeX sh, Storable a)            => (a -> a)            -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -399,7 +410,7 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe      case X.ssxToShape' (knownShapeX @sh) of        Just sh -> M_Primitive (X.constant sh (fromInteger n))        Nothing -> error "Data.Array.Nested.fromIntegral: \ -                       \Unknown components in shape, use explicit replicate" +                       \Unknown components in shape, use explicit mconstant"  deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int)  deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) @@ -732,6 +743,15 @@ rfromVector sh v    | Dict <- lemKnownReplicate (Proxy @n)    = Ranked (mfromVector (ixCvtRX sh) v) +runScalar :: Elt a => Ranked I0 a -> a +runScalar arr = rindex arr IZR + +rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) +          => IxR n -> a -> Ranked n a +rconstant sh x +  | Dict <- lemKnownReplicate (Proxy @n) +  = Ranked (mconstant (ixCvtRX sh) x) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -795,10 +815,10 @@ sindexPartial (Shaped arr) idx =              (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)              (ixCvtSX idx)) -sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a -sgenerate sh f +sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a +sgenerate f    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) +  = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)        => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -834,3 +854,12 @@ sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped  sfromVector v    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped (mfromVector (ixCvtSX (cvtSShapeIxS (knownShape @sh))) v) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr IZS + +sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) +          => a -> Shaped sh a +sconstant x +  | Dict <- lemKnownMapJust (Proxy @sh) +  = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) | 
