aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 00:18:17 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 00:18:17 +0200
commita0010622885dcb55a916bf3514c0e9040f6871e9 (patch)
tree9e10c18eaf5c873d50e1f88a3bf114179c151769 /src/Data/Array/Nested/Internal/Arith
parent4b74d1b1f7c46a4b3907838bee11f669060d3a23 (diff)
Fast numeric operations for Num
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith')
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs33
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs47
2 files changed, 80 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
new file mode 100644
index 0000000..dbd9ddc
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -0,0 +1,33 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Nested.Internal.Arith.Foreign where
+
+import Control.Monad
+import Data.Int
+import Data.Maybe
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Nested.Internal.Arith.Lists
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM binopsList $ \arithop -> do
+ let base = aboName arithop ++ "_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,guard (aboComm arithop == NonComm) >>
+ Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ forM unopsList $ \arithop -> do
+ let base = auoName arithop ++ "_" ++ atCName arithtype
+ ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
new file mode 100644
index 0000000..1b29770
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists where
+
+import Data.Int
+
+import Language.Haskell.TH
+
+
+data Commutative = Comm | NonComm
+ deriving (Show, Eq)
+
+data ArithType = ArithType
+ { atType :: Name -- ''Int32
+ , atCName :: String -- "i32"
+ }
+
+typesList :: [ArithType]
+typesList =
+ [ArithType ''Int32 "i32"
+ ,ArithType ''Int64 "i64"
+ ,ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
+data ArithBOp = ArithBOp
+ { aboName :: String -- "add"
+ , aboComm :: Commutative -- Comm
+ , aboScalFun :: ArithType -> Name -- \_ -> '(+)
+ }
+
+binopsList :: [ArithBOp]
+binopsList =
+ [ArithBOp "add" Comm (\_ -> '(+))
+ ,ArithBOp "sub" NonComm (\_ -> '(-))
+ ,ArithBOp "mul" Comm (\_ -> '(*))
+ ]
+
+data ArithUOp = ArithUOp
+ { auoName :: String -- "neg"
+ }
+
+unopsList :: [ArithUOp]
+unopsList =
+ [ArithUOp "neg"
+ ,ArithUOp "abs"
+ ,ArithUOp "signum"
+ ]