From 34a9ac8e4497e776c3ca499c41ef749f4edf8383 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 26 May 2024 00:11:00 +0200 Subject: Refactor C interface to pass operation as enum This is hmatrix style, less proliferation of functions as the number of ops increases --- src/Data/Array/Nested/Internal/Arith.hs | 77 ++++++++++++++--------- src/Data/Array/Nested/Internal/Arith/Foreign.hs | 35 +++++------ src/Data/Array/Nested/Internal/Arith/Lists.hs | 58 ++++++++---------- src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 78 ++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 80 deletions(-) create mode 100644 src/Data/Array/Nested/Internal/Arith/Lists/TH.hs (limited to 'src') diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4bfc043..7484455 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -182,14 +182,13 @@ class NumElt a where $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM binopsList $ \arithop -> do + fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype - c_ss = varE (aboScalFun arithop arithtype) - c_sv = varE $ mkName (cnamebase ++ "_sv") - c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs") - | otherwise = [| flipOp $c_sv |] - c_vv = varE $ mkName (cnamebase ++ "_vv") + cnamebase = "c_binary_" ++ atCName arithtype + c_ss = varE (aboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] @@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM unopsList $ \arithop -> do + fmap concat . forM [minBound..maxBound] $ \arithop -> do let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) + c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] @@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM redopsList $ \redop -> do - let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) - c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) sequence [SigD name <$> [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] @@ -297,21 +296,41 @@ instance NumElt Double where numEltProduct1Inner = product1VectorDouble instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv - numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv - numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv - numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 - numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 - numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 - numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 - numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 + numEltAdd = intWidBranch2 @Int (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @Int (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @Int (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv - numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv - numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv - numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 - numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 - numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 - numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 - numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 + numEltAdd = intWidBranch2 @CInt (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @CInt (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @CInt (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index f84b1c5..49effa1 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where import Control.Monad import Data.Int import Data.Maybe +import Foreign.C.Types import Foreign.Ptr import Language.Haskell.TH @@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - fmap concat . forM binopsList $ \arithop -> do - let base = aboName arithop ++ "_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,guard (aboComm arithop == NonComm) >> - Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) + let base = "binary_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - forM unopsList $ \arithop -> do - let base = auoName arithop ++ "_" ++ atCName arithtype - ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + let base = "unary_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) $(fmap concat . forM typesList $ \arithtype -> do let ttyp = conT (atType arithtype) - forM redopsList $ \redop -> do - let base = aroName redop ++ "_" ++ atCName arithtype - ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + let base = "reduce_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 78fe24a..91e50ad 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -1,13 +1,13 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} module Data.Array.Nested.Internal.Arith.Lists where +import Data.Char import Data.Int - import Language.Haskell.TH +import Data.Array.Nested.Internal.Arith.Lists.TH -data Commutative = Comm | NonComm - deriving (Show, Eq) data ArithType = ArithType { atType :: Name -- ''Int32 @@ -22,36 +22,30 @@ typesList = ,ArithType ''Double "double" ] -data ArithBOp = ArithBOp - { aboName :: String -- "add" - , aboComm :: Commutative -- Comm - , aboScalFun :: ArithType -> Name -- \_ -> '(+) - } +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") -binopsList :: [ArithBOp] -binopsList = - [ArithBOp "add" Comm (\_ -> '(+)) - ,ArithBOp "sub" NonComm (\_ -> '(-)) - ,ArithBOp "mul" Comm (\_ -> '(*)) - ] +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") -data ArithUOp = ArithUOp - { auoName :: String -- "neg" - } +$(do clauses <- readArithLists Binop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] + ,return $ FunD (mkName "aboNumOp") clauses]) -unopsList :: [ArithUOp] -unopsList = - [ArithUOp "neg" - ,ArithUOp "abs" - ,ArithUOp "signum" - ] -data ArithRedOp = ArithRedOp - { aroName :: String -- "sum" - } +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") -redopsList :: [ArithRedOp] -redopsList = - [ArithRedOp "sum1" - ,ArithRedOp "product1" - ] +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..b748b97 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Text.Read + + +data OpKind = Binop | Unop | Redop + deriving (Show, Eq) + +readArithLists :: OpKind + -> (String -> Int -> String -> Q a) + -> ([a] -> Q r) + -> Q r +readArithLists targetkind fop fcombine = do + lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + + mvals <- forM lns $ \line -> do + if null (dropWhile (== ' ') line) + then return Nothing + else do let (kind, name, num, aux) = parseLine line + if kind == targetkind + then Just <$> fop name num aux + else return Nothing + + fcombine (catMaybes mvals) + where + parseLine s0 + | ("LIST_", s1) <- splitAt 5 s0 + , (kindstr, '(' : s2) <- break (== '(') s1 + , (f1, ',' : s3) <- parseField s2 + , (f2, ',' : s4) <- parseField s3 + , (f3, ')' : _) <- parseField s4 + , Just kind <- parseKind kindstr + , let name = f1 + , Just num <- readMaybe f2 + , let aux = f3 + = (kind, name, num, aux) + | otherwise + = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + + parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + + parseKind "BINOP" = Just Binop + parseKind "UNOP" = Just Unop + parseKind "REDOP" = Just Redop + parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do + cons <- readArithLists kind + (\name _num _ -> return $ NormalC (mkName name) []) + return + return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do + clauses <- readArithLists kind + (\name _num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (StringL (nametrans name)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) + ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do + clauses <- readArithLists kind + (\name num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (IntegerL (fromIntegral num)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) + ,FunD (mkName funname) clauses] -- cgit v1.2.3-70-g09d2