aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 09:30:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 09:30:18 +0200
commit827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (patch)
tree42b82306adda422ecba69e1d61cbb6ae9cfe085f /src/Data/Array
parent6af2f87bfc0fe7b11193c346115da70cddc244c5 (diff)
msumOuter1
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal.hs18
2 files changed, 14 insertions, 7 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9a291e6..292eba4 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -38,8 +38,7 @@ module Data.Array.Nested (
Mixed,
IxX(..), IIxX,
KnownShX(..), StaticShX(..),
- -- TODO: missing msumOuter1?
- mshape, mindex, mindexPartial, mgenerate,
+ mshape, mindex, mindexPartial, mgenerate, msumOuter1,
mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mfromListOuter, mfromList1, mtoListOuter, mtoList1,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 3ff7967..f3f8f7d 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -884,6 +884,16 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
+msumOuter1P :: forall sh n a. (Storable a, Num a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr)
+
+msumOuter1 :: forall sh n a. (Num a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
@@ -1458,10 +1468,9 @@ rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f ar
rsumOuter1P :: forall n a.
(Storable a, Num a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked (M_Primitive sh arr))
+rsumOuter1P (Ranked arr)
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- , _ :$% shT <- sh
- = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr))
+ = Ranked (msumOuter1P arr)
rsumOuter1 :: forall n a. (Num a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
@@ -1741,8 +1750,7 @@ slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (sh
ssumOuter1P :: forall sh n a. (Storable a, Num a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) =
- Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr))
+ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
ssumOuter1 :: forall sh n a. (Num a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a