diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
| commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
| tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed/Internal/Arith | |
| parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) | |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 55 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 78 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 82 | 
3 files changed, 215 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..6fc7229 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "binary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "fbinary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "unary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "funary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "reduce_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs new file mode 100644 index 0000000..a284bc1 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists.TH + + +data ArithType = ArithType +  { atType :: Name  -- ''Int32 +  , atCName :: String  -- "i32" +  } + +floatTypesList :: [ArithType] +floatTypesList = +  [ArithType ''Float "float" +  ,ArithType ''Double "double" +  ] + +typesList :: [ArithType] +typesList = +  [ArithType ''Int32 "i32" +  ,ArithType ''Int64 "i64" +  ] +  ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] +              ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] +              ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..8b7d05f --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Mixed.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | FBinop | Unop | FUnop | Redop +  deriving (Show, Eq) + +readArithLists :: OpKind +               -> (String -> Int -> String -> Q a) +               -> ([a] -> Q r) +               -> Q r +readArithLists targetkind fop fcombine = do +  addDependentFile "cbits/arith_lists.h" +  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + +  mvals <- forM lns $ \line -> do +    if null (dropWhile (== ' ') line) +      then return Nothing +      else do let (kind, name, num, aux) = parseLine line +              if kind == targetkind +                then Just <$> fop name num aux +                else return Nothing + +  fcombine (catMaybes mvals) +  where +    parseLine s0 +      | ("LIST_", s1) <- splitAt 5 s0 +      , (kindstr, '(' : s2) <- break (== '(') s1 +      , (f1, ',' : s3) <- parseField s2 +      , (f2, ',' : s4) <- parseField s3 +      , (f3, ')' : _) <- parseField s4 +      , Just kind <- parseKind kindstr +      , let name = f1 +      , Just num <- readMaybe f2 +      , let aux = f3 +      = (kind, name, num, aux) +      | otherwise +      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + +    parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + +    parseKind "BINOP" = Just Binop +    parseKind "FBINOP" = Just FBinop +    parseKind "UNOP" = Just Unop +    parseKind "FUNOP" = Just FUnop +    parseKind "REDOP" = Just Redop +    parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do +  cons <- readArithLists kind +            (\name _num _ -> return $ NormalC (mkName name) []) +            return +  return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do +  clauses <- readArithLists kind +               (\name _num _ -> return (Clause [ConP (mkName name) [] []] +                                               (NormalB (LitE (StringL (nametrans name)))) +                                               [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) +         ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do +  clauses <- readArithLists kind +               (\name num _ -> return (Clause [ConP (mkName name) [] []] +                                              (NormalB (LitE (IntegerL (fromIntegral num)))) +                                              [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) +         ,FunD (mkName funname) clauses] | 
