diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:01:02 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:02:26 +0200 | 
| commit | e001480cd6ac3a3b79c837c4a12645bf78200b98 (patch) | |
| tree | a463d8679f731d9fcf4e059704dcdc4065ce86bf /src/Data/Array | |
| parent | 40dcdf2360c90437fd5c8f76f5f75c96203ef880 (diff) | |
scalar
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 15 | 
2 files changed, 18 insertions, 3 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 0a9408b..9e4c0e7 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,7 @@ module Data.Array.Nested (    Ranked,    IxR(..),    rshape, rindex, rindexPartial, rgenerate, rsumOuter1, -  rtranspose, +  rtranspose, rscalar,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -13,7 +13,7 @@ module Data.Array.Nested (    IxS(..),    KnownShape(..), SShape(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1, -  stranspose, +  stranspose, sscalar,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -24,7 +24,7 @@ module Data.Array.Nested (    mgenerate, mtranspose,    -- * Array elements -  Elt(mshape, mindex, mindexPartial, mlift), +  Elt(mshape, mindex, mindexPartial, mscalar, mlift),    Primitive(..),    -- * Inductive natural numbers diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 209d594..1079e99 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -146,6 +146,7 @@ class Elt a where    mshape :: KnownShapeX sh => Mixed sh a -> IxX sh    mindex :: Mixed sh a -> IxX sh -> a    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a +  mscalar :: a -> Mixed '[] a    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -189,6 +190,7 @@ instance Storable a => Elt (Primitive a) where    mshape (M_Primitive a) = X.shape a    mindex (M_Primitive a) i = Primitive (X.index a i)    mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) +  mscalar (Primitive x) = M_Primitive (X.scalar x)    mlift :: forall sh1 sh2.             (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -231,6 +233,7 @@ instance (Elt a, Elt b) => Elt (a, b) where    mshape (M_Tup2 a _) = mshape a    mindex (M_Tup2 a b) i = (mindex a i, mindex b i)    mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) +  mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)    mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)    mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -265,6 +268,8 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where      | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')      = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) +  mscalar x = M_Nest x +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)          -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -444,6 +449,8 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where      = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $          mindexPartial arr i +  mscalar (Ranked x) = M_Ranked (M_Nest x) +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -547,6 +554,8 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where      = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $          mindexPartial arr i +  mscalar (Shaped x) = M_Shaped (M_Nest x) +    mlift :: forall sh1 sh2. KnownShapeX sh2          => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)          -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -707,6 +716,9 @@ rappend :: forall n a. (KnownINat n, Elt a)          => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a  rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rscalar :: Elt a => a -> Ranked I0 a +rscalar x = Ranked (mscalar x) +  -- ====== API OF SHAPED ARRAYS ====== -- @@ -800,3 +812,6 @@ stranspose perm (Shaped arr)  sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)          => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a  sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) | 
