diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith/Foreign.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 35 |
1 files changed, 16 insertions, 19 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index f84b1c5..49effa1 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where import Control.Monad import Data.Int import Data.Maybe +import Foreign.C.Types import Foreign.Ptr import Language.Haskell.TH @@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM binopsList $ \arithop -> do - let base = aboName arithop ++ "_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,guard (aboComm arithop == NonComm) >> - Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) + let base = "binary_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - forM unopsList $ \arithop -> do - let base = auoName arithop ++ "_" ++ atCName arithtype - ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + let base = "unary_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - forM redopsList $ \redop -> do - let base = aroName redop ++ "_" ++ atCName arithtype - ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + let base = "reduce_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) |