diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:01:02 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-14 16:02:26 +0200 |
commit | e001480cd6ac3a3b79c837c4a12645bf78200b98 (patch) | |
tree | a463d8679f731d9fcf4e059704dcdc4065ce86bf /src/Data/Array/Nested/Internal.hs | |
parent | 40dcdf2360c90437fd5c8f76f5f75c96203ef880 (diff) |
scalar
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 209d594..1079e99 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -146,6 +146,7 @@ class Elt a where mshape :: KnownShapeX sh => Mixed sh a -> IxX sh mindex :: Mixed sh a -> IxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -189,6 +190,7 @@ instance Storable a => Elt (Primitive a) where mshape (M_Primitive a) = X.shape a mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive (X.scalar x) mlift :: forall sh1 sh2. (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -231,6 +233,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -265,6 +268,8 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + mscalar x = M_Nest x + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -444,6 +449,8 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ mindexPartial arr i + mscalar (Ranked x) = M_Ranked (M_Nest x) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -547,6 +554,8 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mindexPartial arr i + mscalar (Shaped x) = M_Shaped (M_Nest x) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -707,6 +716,9 @@ rappend :: forall n a. (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rscalar :: Elt a => a -> Ranked I0 a +rscalar x = Ranked (mscalar x) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -800,3 +812,6 @@ stranspose perm (Shaped arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) |