diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 |
commit | 4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch) | |
tree | 2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Nested/Internal.hs | |
parent | 827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff) |
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f3f8f7d..118612f 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -884,13 +884,13 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, Num a) +msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) -msumOuter1 :: forall sh n a. (Num a, PrimElt a) +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive @@ -1466,13 +1466,13 @@ rlift2 :: forall n1 n2 n3 a. Elt a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) rsumOuter1P :: forall n a. - (Storable a, Num a) + (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked (msumOuter1P arr) -rsumOuter1 :: forall n a. (Num a, PrimElt a) +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive @@ -1748,11 +1748,11 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) -ssumOuter1P :: forall sh n a. (Storable a, Num a) +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) -ssumOuter1 :: forall sh n a. (Num a, PrimElt a) +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive |