From 827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 23 May 2024 09:30:18 +0200 Subject: msumOuter1 --- src/Data/Array/Nested/Internal.hs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'src/Data/Array/Nested/Internal.hs') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 3ff7967..f3f8f7d 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -884,6 +884,16 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs +msumOuter1P :: forall sh n a. (Storable a, Num a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) + +msumOuter1 :: forall sh n a. (Num a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 @@ -1458,10 +1468,9 @@ rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f ar rsumOuter1P :: forall n a. (Storable a, Num a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked (M_Primitive sh arr)) +rsumOuter1P (Ranked arr) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - , _ :$% shT <- sh - = Ranked (M_Primitive shT (X.sumOuter (SUnknown () :!% ZKX) (X.staticShapeFrom shT) arr)) + = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (Num a, PrimElt a) => Ranked (n + 1) a -> Ranked n a @@ -1741,8 +1750,7 @@ slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (sh ssumOuter1P :: forall sh n a. (Storable a, Num a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) = - Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr)) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) ssumOuter1 :: forall sh n a. (Num a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a -- cgit v1.2.3-70-g09d2