aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
commit4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch)
tree2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Nested/Internal/Arith
parent827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff)
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith')
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs10
2 files changed, 18 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index dbd9ddc..f84b1c5 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -22,7 +22,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
[t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
,guard (aboComm arithop == NonComm) >>
Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
])
$(fmap concat . forM typesList $ \arithtype -> do
@@ -31,3 +31,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
let base = auoName arithop ++ "_" ++ atCName arithtype
ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ forM redopsList $ \redop -> do
+ let base = aroName redop ++ "_" ++ atCName arithtype
+ ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 1b29770..78fe24a 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -45,3 +45,13 @@ unopsList =
,ArithUOp "abs"
,ArithUOp "signum"
]
+
+data ArithRedOp = ArithRedOp
+ { aroName :: String -- "sum"
+ }
+
+redopsList :: [ArithRedOp]
+redopsList =
+ [ArithRedOp "sum1"
+ ,ArithRedOp "product1"
+ ]