aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
commit4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch)
tree2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Nested/Internal.hs
parent827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff)
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index f3f8f7d..118612f 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -884,13 +884,13 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, Num a)
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr)
-msumOuter1 :: forall sh n a. (Num a, PrimElt a)
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
@@ -1466,13 +1466,13 @@ rlift2 :: forall n1 n2 n3 a. Elt a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
rsumOuter1P :: forall n a.
- (Storable a, Num a)
+ (Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked (msumOuter1P arr)
-rsumOuter1 :: forall n a. (Num a, PrimElt a)
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
@@ -1748,11 +1748,11 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, Num a)
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-ssumOuter1 :: forall sh n a. (Num a, PrimElt a)
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive