diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 00:18:17 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 00:18:17 +0200 |
commit | a0010622885dcb55a916bf3514c0e9040f6871e9 (patch) | |
tree | 9e10c18eaf5c873d50e1f88a3bf114179c151769 /src/Data/Array/Nested/Internal/Arith | |
parent | 4b74d1b1f7c46a4b3907838bee11f669060d3a23 (diff) |
Fast numeric operations for Num
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith')
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 33 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 47 |
2 files changed, 80 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..dbd9ddc --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -0,0 +1,33 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Nested.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Nested.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM binopsList $ \arithop -> do + let base = aboName arithop ++ "_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,guard (aboComm arithop == NonComm) >> + Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + forM unopsList $ \arithop -> do + let base = auoName arithop ++ "_" ++ atCName arithtype + ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs new file mode 100644 index 0000000..1b29770 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Internal.Arith.Lists where + +import Data.Int + +import Language.Haskell.TH + + +data Commutative = Comm | NonComm + deriving (Show, Eq) + +data ArithType = ArithType + { atType :: Name -- ''Int32 + , atCName :: String -- "i32" + } + +typesList :: [ArithType] +typesList = + [ArithType ''Int32 "i32" + ,ArithType ''Int64 "i64" + ,ArithType ''Float "float" + ,ArithType ''Double "double" + ] + +data ArithBOp = ArithBOp + { aboName :: String -- "add" + , aboComm :: Commutative -- Comm + , aboScalFun :: ArithType -> Name -- \_ -> '(+) + } + +binopsList :: [ArithBOp] +binopsList = + [ArithBOp "add" Comm (\_ -> '(+)) + ,ArithBOp "sub" NonComm (\_ -> '(-)) + ,ArithBOp "mul" Comm (\_ -> '(*)) + ] + +data ArithUOp = ArithUOp + { auoName :: String -- "neg" + } + +unopsList :: [ArithUOp] +unopsList = + [ArithUOp "neg" + ,ArithUOp "abs" + ,ArithUOp "signum" + ] |