aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith/Foreign.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
commit34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch)
treef2b2e34d830d66d23ae19909c71771e810c262d0 /src/Data/Array/Nested/Internal/Arith/Foreign.hs
parent85593969debadbf11ad3c159de71e7b480ca367c (diff)
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of ops increases
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith/Foreign.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs35
1 files changed, 16 insertions, 19 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index f84b1c5..49effa1 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where
import Control.Monad
import Data.Int
import Data.Maybe
+import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
@@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
- let base = aboName arithop ++ "_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,guard (aboComm arithop == NonComm) >>
- Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM unopsList $ \arithop -> do
- let base = auoName arithop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM redopsList $ \redop -> do
- let base = aroName redop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])