diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 00:11:00 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 00:11:00 +0200 | 
| commit | 34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch) | |
| tree | f2b2e34d830d66d23ae19909c71771e810c262d0 /src/Data | |
| parent | 85593969debadbf11ad3c159de71e7b480ca367c (diff) | |
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of
ops increases
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 77 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 35 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 58 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 78 | 
4 files changed, 168 insertions, 80 deletions
| diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs index 4bfc043..7484455 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Nested/Internal/Arith.hs @@ -182,14 +182,13 @@ class NumElt a where  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    fmap concat . forM binopsList $ \arithop -> do +    fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_" ++ aboName arithop ++ "_" ++ atCName arithtype -          c_ss = varE (aboScalFun arithop arithtype) -          c_sv = varE $ mkName (cnamebase ++ "_sv") -          c_vs | aboComm arithop == NonComm = varE $ mkName (cnamebase ++ "_vs") -               | otherwise = [| flipOp $c_sv |] -          c_vv = varE $ mkName (cnamebase ++ "_vv") +          cnamebase = "c_binary_" ++ atCName arithtype +          c_ss = varE (aboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]                 ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] @@ -197,9 +196,9 @@ $(fmap concat . forM typesList $ \arithtype -> do  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    fmap concat . forM unopsList $ \arithop -> do +    fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype) +          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]                 ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] @@ -207,10 +206,10 @@ $(fmap concat . forM typesList $ \arithtype -> do  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    fmap concat . forM redopsList $ \redop -> do -      let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype) -          c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv") +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) +          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]                 ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] @@ -297,21 +296,41 @@ instance NumElt Double where    numEltProduct1Inner = product1VectorDouble  instance NumElt Int where -  numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv -  numEltSub = intWidBranch2 @Int (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv -  numEltMul = intWidBranch2 @Int (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv -  numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64 -  numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64 -  numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64 -  numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 -  numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 +  numEltAdd = intWidBranch2 @Int (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @Int (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @Int (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @Int +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @Int +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))  instance NumElt CInt where -  numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv -  numEltSub = intWidBranch2 @CInt (-) c_sub_i32_sv (flipOp c_sub_i32_sv) c_sub_i32_vv c_sub_i64_sv (flipOp c_sub_i64_sv) c_sub_i64_vv -  numEltMul = intWidBranch2 @CInt (*) c_mul_i32_sv (flipOp c_mul_i32_sv) c_mul_i32_vv c_mul_i64_sv (flipOp c_mul_i64_sv) c_mul_i64_vv -  numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64 -  numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64 -  numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64 -  numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64 -  numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64 +  numEltAdd = intWidBranch2 @CInt (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @CInt (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @CInt (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @CInt +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @CInt +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs index f84b1c5..49effa1 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs @@ -5,6 +5,7 @@ module Data.Array.Nested.Internal.Arith.Foreign where  import Control.Monad  import Data.Int  import Data.Maybe +import Foreign.C.Types  import Foreign.Ptr  import Language.Haskell.TH @@ -13,28 +14,24 @@ import Data.Array.Nested.Internal.Arith.Lists  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    fmap concat . forM binopsList $ \arithop -> do -      let base = aboName arithop ++ "_" ++ atCName arithtype -      sequence $ catMaybes -        [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> -                 [t| Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) -        ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> -                 [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) -        ,guard (aboComm arithop == NonComm) >> -           Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -                   [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) -        ]) +    let base = "binary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ])  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    forM unopsList $ \arithop -> do -      let base = auoName arithop ++ "_" ++ atCName arithtype -      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -        [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +    let base = "unary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])  $(fmap concat . forM typesList $ \arithtype -> do      let ttyp = conT (atType arithtype) -    forM redopsList $ \redop -> do -      let base = aroName redop ++ "_" ++ atCName arithtype -      ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -        [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +    let base = "reduce_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs index 78fe24a..91e50ad 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs @@ -1,13 +1,13 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-}  module Data.Array.Nested.Internal.Arith.Lists where +import Data.Char  import Data.Int -  import Language.Haskell.TH +import Data.Array.Nested.Internal.Arith.Lists.TH -data Commutative = Comm | NonComm -  deriving (Show, Eq)  data ArithType = ArithType    { atType :: Name  -- ''Int32 @@ -22,36 +22,30 @@ typesList =    ,ArithType ''Double "double"    ] -data ArithBOp = ArithBOp -  { aboName :: String  -- "add" -  , aboComm :: Commutative  -- Comm -  , aboScalFun :: ArithType -> Name  -- \_ -> '(+) -  } +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") -binopsList :: [ArithBOp] -binopsList = -  [ArithBOp "add" Comm (\_ -> '(+)) -  ,ArithBOp "sub" NonComm (\_ -> '(-)) -  ,ArithBOp "mul" Comm (\_ -> '(*)) -  ] +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") -data ArithUOp = ArithUOp -  { auoName :: String  -- "neg" -  } +$(do clauses <- readArithLists Binop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] +              ,return $ FunD (mkName "aboNumOp") clauses]) -unopsList :: [ArithUOp] -unopsList = -  [ArithUOp "neg" -  ,ArithUOp "abs" -  ,ArithUOp "signum" -  ] -data ArithRedOp = ArithRedOp -  { aroName :: String  -- "sum" -  } +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") -redopsList :: [ArithRedOp] -redopsList = -  [ArithRedOp "sum1" -  ,ArithRedOp "product1" -  ] +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..b748b97 --- /dev/null +++ b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Text.Read + + +data OpKind = Binop | Unop | Redop +  deriving (Show, Eq) + +readArithLists :: OpKind +               -> (String -> Int -> String -> Q a) +               -> ([a] -> Q r) +               -> Q r +readArithLists targetkind fop fcombine = do +  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + +  mvals <- forM lns $ \line -> do +    if null (dropWhile (== ' ') line) +      then return Nothing +      else do let (kind, name, num, aux) = parseLine line +              if kind == targetkind +                then Just <$> fop name num aux +                else return Nothing + +  fcombine (catMaybes mvals) +  where +    parseLine s0 +      | ("LIST_", s1) <- splitAt 5 s0 +      , (kindstr, '(' : s2) <- break (== '(') s1 +      , (f1, ',' : s3) <- parseField s2 +      , (f2, ',' : s4) <- parseField s3 +      , (f3, ')' : _) <- parseField s4 +      , Just kind <- parseKind kindstr +      , let name = f1 +      , Just num <- readMaybe f2 +      , let aux = f3 +      = (kind, name, num, aux) +      | otherwise +      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + +    parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + +    parseKind "BINOP" = Just Binop +    parseKind "UNOP" = Just Unop +    parseKind "REDOP" = Just Redop +    parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do +  cons <- readArithLists kind +            (\name _num _ -> return $ NormalC (mkName name) []) +            return +  return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do +  clauses <- readArithLists kind +               (\name _num _ -> return (Clause [ConP (mkName name) [] []] +                                               (NormalB (LitE (StringL (nametrans name)))) +                                               [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) +         ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do +  clauses <- readArithLists kind +               (\name num _ -> return (Clause [ConP (mkName name) [] []] +                                              (NormalB (LitE (IntegerL (fromIntegral num)))) +                                              [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) +         ,FunD (mkName funname) clauses] | 
