aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-14 16:01:02 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-14 16:02:26 +0200
commite001480cd6ac3a3b79c837c4a12645bf78200b98 (patch)
treea463d8679f731d9fcf4e059704dcdc4065ce86bf /src/Data/Array/Nested/Internal.hs
parent40dcdf2360c90437fd5c8f76f5f75c96203ef880 (diff)
scalar
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 209d594..1079e99 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -146,6 +146,7 @@ class Elt a where
mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
mindex :: Mixed sh a -> IxX sh -> a
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+ mscalar :: a -> Mixed '[] a
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -189,6 +190,7 @@ instance Storable a => Elt (Primitive a) where
mshape (M_Primitive a) = X.shape a
mindex (M_Primitive a) i = Primitive (X.index a i)
mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive (X.scalar x)
mlift :: forall sh1 sh2.
(Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -231,6 +233,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
@@ -265,6 +268,8 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ mscalar x = M_Nest x
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
@@ -444,6 +449,8 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where
= coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
mindexPartial arr i
+ mscalar (Ranked x) = M_Ranked (M_Nest x)
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
@@ -547,6 +554,8 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
= coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
mindexPartial arr i
+ mscalar (Shaped x) = M_Shaped (M_Nest x)
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
@@ -707,6 +716,9 @@ rappend :: forall n a. (KnownINat n, Elt a)
=> Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
+rscalar :: Elt a => a -> Ranked I0 a
+rscalar x = Ranked (mscalar x)
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -800,3 +812,6 @@ stranspose perm (Shaped arr)
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
=> Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
+
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)