aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Arith
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith')
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs35
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs58
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs78
3 files changed, 120 insertions, 51 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index f84b1c5..49effa1 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where
import Control.Monad
import Data.Int
import Data.Maybe
+import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
@@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
- let base = aboName arithop ++ "_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,guard (aboComm arithop == NonComm) >>
- Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM unopsList $ \arithop -> do
- let base = auoName arithop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM redopsList $ \redop -> do
- let base = aroName redop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 78fe24a..91e50ad 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -1,13 +1,13 @@
-{-# LANGUAGE TemplateHaskellQuotes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Nested.Internal.Arith.Lists where
+import Data.Char
import Data.Int
-
import Language.Haskell.TH
+import Data.Array.Nested.Internal.Arith.Lists.TH
-data Commutative = Comm | NonComm
- deriving (Show, Eq)
data ArithType = ArithType
{ atType :: Name -- ''Int32
@@ -22,36 +22,30 @@ typesList =
,ArithType ''Double "double"
]
-data ArithBOp = ArithBOp
- { aboName :: String -- "add"
- , aboComm :: Commutative -- Comm
- , aboScalFun :: ArithType -> Name -- \_ -> '(+)
- }
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
-binopsList :: [ArithBOp]
-binopsList =
- [ArithBOp "add" Comm (\_ -> '(+))
- ,ArithBOp "sub" NonComm (\_ -> '(-))
- ,ArithBOp "mul" Comm (\_ -> '(*))
- ]
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
-data ArithUOp = ArithUOp
- { auoName :: String -- "neg"
- }
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
-unopsList :: [ArithUOp]
-unopsList =
- [ArithUOp "neg"
- ,ArithUOp "abs"
- ,ArithUOp "signum"
- ]
-data ArithRedOp = ArithRedOp
- { aroName :: String -- "sum"
- }
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
-redopsList :: [ArithRedOp]
-redopsList =
- [ArithRedOp "sum1"
- ,ArithRedOp "product1"
- ]
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
new file mode 100644
index 0000000..b748b97
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Text.Read
+
+
+data OpKind = Binop | Unop | Redop
+ deriving (Show, Eq)
+
+readArithLists :: OpKind
+ -> (String -> Int -> String -> Q a)
+ -> ([a] -> Q r)
+ -> Q r
+readArithLists targetkind fop fcombine = do
+ lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+ mvals <- forM lns $ \line -> do
+ if null (dropWhile (== ' ') line)
+ then return Nothing
+ else do let (kind, name, num, aux) = parseLine line
+ if kind == targetkind
+ then Just <$> fop name num aux
+ else return Nothing
+
+ fcombine (catMaybes mvals)
+ where
+ parseLine s0
+ | ("LIST_", s1) <- splitAt 5 s0
+ , (kindstr, '(' : s2) <- break (== '(') s1
+ , (f1, ',' : s3) <- parseField s2
+ , (f2, ',' : s4) <- parseField s3
+ , (f3, ')' : _) <- parseField s4
+ , Just kind <- parseKind kindstr
+ , let name = f1
+ , Just num <- readMaybe f2
+ , let aux = f3
+ = (kind, name, num, aux)
+ | otherwise
+ = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+ parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+ parseKind "BINOP" = Just Binop
+ parseKind "UNOP" = Just Unop
+ parseKind "REDOP" = Just Redop
+ parseKind _ = Nothing
+
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname = do
+ cons <- readArithLists kind
+ (\name _num _ -> return $ NormalC (mkName name) [])
+ return
+ return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
+
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+ clauses <- readArithLists kind
+ (\name _num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (StringL (nametrans name))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
+ ,FunD (mkName funname) clauses]
+
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+ clauses <- readArithLists kind
+ (\name num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (IntegerL (fromIntegral num))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
+ ,FunD (mkName funname) clauses]