aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
commit34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch)
treef2b2e34d830d66d23ae19909c71771e810c262d0 /src/Data/Array/Nested
parent85593969debadbf11ad3c159de71e7b480ca367c (diff)
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of ops increases
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs77
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs35
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs58
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs78
4 files changed, 168 insertions, 80 deletions
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4bfc043..7484455 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -182,14 +182,13 @@ class NumElt a where
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype
- c_ss = varE (aboScalFun arithop arithtype)
- c_sv = varE $ mkName (cnamebase ++ "_sv")
- c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs")
- | otherwise = [| flipOp $c_sv |]
- c_vv = varE $ mkName (cnamebase ++ "_vv")
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss = varE (aboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
@@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM unopsList $ \arithop -> do
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
+ c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
@@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM redopsList $ \redop -> do
- let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
- c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
@@ -297,21 +296,41 @@ instance NumElt Double where
numEltProduct1Inner = product1VectorDouble
instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
- numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv
- numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv
- numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
- numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
- numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
- numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
- numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index f84b1c5..49effa1 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where
import Control.Monad
import Data.Int
import Data.Maybe
+import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
@@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- fmap concat . forM binopsList $ \arithop -> do
- let base = aboName arithop ++ "_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,guard (aboComm arithop == NonComm) >>
- Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM unopsList $ \arithop -> do
- let base = auoName arithop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- forM redopsList $ \redop -> do
- let base = aroName redop ++ "_" ++ atCName arithtype
- ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 78fe24a..91e50ad 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -1,13 +1,13 @@
-{-# LANGUAGE TemplateHaskellQuotes #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Nested.Internal.Arith.Lists where
+import Data.Char
import Data.Int
-
import Language.Haskell.TH
+import Data.Array.Nested.Internal.Arith.Lists.TH
-data Commutative = Comm | NonComm
- deriving (Show, Eq)
data ArithType = ArithType
{ atType :: Name -- ''Int32
@@ -22,36 +22,30 @@ typesList =
,ArithType ''Double "double"
]
-data ArithBOp = ArithBOp
- { aboName :: String -- "add"
- , aboComm :: Commutative -- Comm
- , aboScalFun :: ArithType -> Name -- \_ -> '(+)
- }
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
-binopsList :: [ArithBOp]
-binopsList =
- [ArithBOp "add" Comm (\_ -> '(+))
- ,ArithBOp "sub" NonComm (\_ -> '(-))
- ,ArithBOp "mul" Comm (\_ -> '(*))
- ]
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
-data ArithUOp = ArithUOp
- { auoName :: String -- "neg"
- }
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
-unopsList :: [ArithUOp]
-unopsList =
- [ArithUOp "neg"
- ,ArithUOp "abs"
- ,ArithUOp "signum"
- ]
-data ArithRedOp = ArithRedOp
- { aroName :: String -- "sum"
- }
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
-redopsList :: [ArithRedOp]
-redopsList =
- [ArithRedOp "sum1"
- ,ArithRedOp "product1"
- ]
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
new file mode 100644
index 0000000..b748b97
--- /dev/null
+++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Internal.Arith.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Text.Read
+
+
+data OpKind = Binop | Unop | Redop
+ deriving (Show, Eq)
+
+readArithLists :: OpKind
+ -> (String -> Int -> String -> Q a)
+ -> ([a] -> Q r)
+ -> Q r
+readArithLists targetkind fop fcombine = do
+ lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+ mvals <- forM lns $ \line -> do
+ if null (dropWhile (== ' ') line)
+ then return Nothing
+ else do let (kind, name, num, aux) = parseLine line
+ if kind == targetkind
+ then Just <$> fop name num aux
+ else return Nothing
+
+ fcombine (catMaybes mvals)
+ where
+ parseLine s0
+ | ("LIST_", s1) <- splitAt 5 s0
+ , (kindstr, '(' : s2) <- break (== '(') s1
+ , (f1, ',' : s3) <- parseField s2
+ , (f2, ',' : s4) <- parseField s3
+ , (f3, ')' : _) <- parseField s4
+ , Just kind <- parseKind kindstr
+ , let name = f1
+ , Just num <- readMaybe f2
+ , let aux = f3
+ = (kind, name, num, aux)
+ | otherwise
+ = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+ parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+ parseKind "BINOP" = Just Binop
+ parseKind "UNOP" = Just Unop
+ parseKind "REDOP" = Just Redop
+ parseKind _ = Nothing
+
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname = do
+ cons <- readArithLists kind
+ (\name _num _ -> return $ NormalC (mkName name) [])
+ return
+ return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
+
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+ clauses <- readArithLists kind
+ (\name _num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (StringL (nametrans name))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
+ ,FunD (mkName funname) clauses]
+
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+ clauses <- readArithLists kind
+ (\name num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (IntegerL (fromIntegral num))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
+ ,FunD (mkName funname) clauses]