From e001480cd6ac3a3b79c837c4a12645bf78200b98 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 14 Apr 2024 16:01:02 +0200 Subject: scalar --- src/Data/Array/Nested.hs | 6 +++--- src/Data/Array/Nested/Internal.hs | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 0a9408b..9e4c0e7 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -4,7 +4,7 @@ module Data.Array.Nested ( Ranked, IxR(..), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, - rtranspose, + rtranspose, rscalar, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -13,7 +13,7 @@ module Data.Array.Nested ( IxS(..), KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, - stranspose, + stranspose, sscalar, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -24,7 +24,7 @@ module Data.Array.Nested ( mgenerate, mtranspose, -- * Array elements - Elt(mshape, mindex, mindexPartial, mlift), + Elt(mshape, mindex, mindexPartial, mscalar, mlift), Primitive(..), -- * Inductive natural numbers diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 209d594..1079e99 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -146,6 +146,7 @@ class Elt a where mshape :: KnownShapeX sh => Mixed sh a -> IxX sh mindex :: Mixed sh a -> IxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -189,6 +190,7 @@ instance Storable a => Elt (Primitive a) where mshape (M_Primitive a) = X.shape a mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive (X.scalar x) mlift :: forall sh1 sh2. (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -231,6 +233,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -265,6 +268,8 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + mscalar x = M_Nest x + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -444,6 +449,8 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ mindexPartial arr i + mscalar (Ranked x) = M_Ranked (M_Nest x) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -547,6 +554,8 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mindexPartial arr i + mscalar (Shaped x) = M_Shaped (M_Nest x) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -707,6 +716,9 @@ rappend :: forall n a. (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend +rscalar :: Elt a => a -> Ranked I0 a +rscalar x = Ranked (mscalar x) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -800,3 +812,6 @@ stranspose perm (Shaped arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) -- cgit v1.2.3-70-g09d2