aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 14:57:34 +0200
commite80b2593edc3d216905279ebcfa797593a1efbfc (patch)
tree5e5057e03f35369983f6600efc59c438c0cf2366 /src/Data/Array/Nested/Internal/Arith
parent2ac16efe59051e0cdeb37422ab579c8d354d562a (diff)
Fast Fractional ops via C code
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith')
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs18
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs31
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs4
3 files changed, 50 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index 49effa1..ac83188 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -24,12 +24,30 @@ $(fmap concat . forM typesList $ \arithtype -> do
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "fbinary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "unary_" ++ atCName arithtype
pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "funary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
let base = "reduce_" ++ atCName arithtype
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 91e50ad..ce2836d 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -14,13 +14,18 @@ data ArithType = ArithType
, atCName :: String -- "i32"
}
+floatTypesList :: [ArithType]
+floatTypesList =
+ [ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
typesList :: [ArithType]
typesList =
[ArithType ''Int32 "i32"
,ArithType ''Int64 "i64"
- ,ArithType ''Float "float"
- ,ArithType ''Double "double"
]
+ ++ floatTypesList
-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
$(genArithDataType Binop "ArithBOp")
@@ -37,6 +42,21 @@ $(do clauses <- readArithLists Binop
,return $ FunD (mkName "aboNumOp") clauses])
+-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
$(genArithDataType Unop "ArithUOp")
@@ -44,6 +64,13 @@ $(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
$(genArithDataType Redop "ArithRedOp")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
index b748b97..b40a066 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -9,7 +9,7 @@ import Language.Haskell.TH
import Text.Read
-data OpKind = Binop | Unop | Redop
+data OpKind = Binop | FBinop | Unop | FUnop | Redop
deriving (Show, Eq)
readArithLists :: OpKind
@@ -46,7 +46,9 @@ readArithLists targetkind fop fcombine = do
parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
parseKind "BINOP" = Just Binop
+ parseKind "FBINOP" = Just FBinop
parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
parseKind "REDOP" = Just Redop
parseKind _ = Nothing