diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 |
commit | 55036a5ea4a6e590d0404638b2823c6a4aec3fba (patch) | |
tree | 484bc377229d3edff36bd9a2a80f999bbcd2e889 /ops/Data/Array/Strided/Arith/Internal | |
parent | 5414434df62b2b196354b9748b265093c168601b (diff) |
Separate arith routines into a library
The point is that this separate library does not depend on orthotope.
Diffstat (limited to 'ops/Data/Array/Strided/Arith/Internal')
-rw-r--r-- | ops/Data/Array/Strided/Arith/Internal/Foreign.hs | 47 | ||||
-rw-r--r-- | ops/Data/Array/Strided/Arith/Internal/Lists.hs | 95 | ||||
-rw-r--r-- | ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs | 83 |
3 files changed, 225 insertions, 0 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs new file mode 100644 index 0000000..dad65f9 --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Strided.Arith.Internal.Foreign where + +import Data.Int +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Strided.Arith.Internal.Lists + + +$(do + let importsScal ttyp tyn = + [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) + ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let importsInt ttyp tyn = + [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ] + + let importsFloat ttyp tyn = + [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let generate types imports = + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ + | arithtype <- types + , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] + decs1 <- generate typesList importsScal + decs2 <- generate intTypesList importsInt + decs3 <- generate floatTypesList importsFloat + return (decs1 ++ decs2 ++ decs3)) diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists.hs b/ops/Data/Array/Strided/Arith/Internal/Lists.hs new file mode 100644 index 0000000..910a77c --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Lists.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Strided.Arith.Internal.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Strided.Arith.Internal.Lists.TH + + +data ArithType = ArithType + { atType :: Name -- ''Int32 + , atCName :: String -- "i32" + } + +intTypesList :: [ArithType] +intTypesList = + [ArithType ''Int32 "i32" + ,ArithType ''Int64 "i64" + ] + +floatTypesList :: [ArithType] +floatTypesList = + [ArithType ''Float "float" + ,ArithType ''Double "double" + ] + +typesList :: [ArithType] +typesList = intTypesList ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] + ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithIBOp = IB_QUOT deriving (Show, Enum, Bounded) +$(genArithDataType IBinop "ArithIBOp") + +$(genArithNameFun IBinop ''ArithIBOp "aiboName" (map toLower . drop 3)) +$(genArithEnumFun IBinop ''ArithIBOp "aiboEnum") + +$(do clauses <- readArithLists IBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aiboNumOp") <$> [t| ArithIBOp -> Name |] + ,return $ FunD (mkName "aiboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] + ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs new file mode 100644 index 0000000..b8f6a3d --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Lists/TH.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Strided.Arith.Internal.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | IBinop | FBinop | Unop | FUnop | Redop + deriving (Show, Eq) + +readArithLists :: OpKind + -> (String -> Int -> String -> Q a) + -> ([a] -> Q r) + -> Q r +readArithLists targetkind fop fcombine = do + addDependentFile "cbits/arith_lists.h" + lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + + mvals <- forM lns $ \line -> do + if null (dropWhile (== ' ') line) + then return Nothing + else do let (kind, name, num, aux) = parseLine line + if kind == targetkind + then Just <$> fop name num aux + else return Nothing + + fcombine (catMaybes mvals) + where + parseLine s0 + | ("LIST_", s1) <- splitAt 5 s0 + , (kindstr, '(' : s2) <- break (== '(') s1 + , (f1, ',' : s3) <- parseField s2 + , (f2, ',' : s4) <- parseField s3 + , (f3, ')' : _) <- parseField s4 + , Just kind <- parseKind kindstr + , let name = f1 + , Just num <- readMaybe f2 + , let aux = f3 + = (kind, name, num, aux) + | otherwise + = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + + parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + + parseKind "BINOP" = Just Binop + parseKind "IBINOP" = Just IBinop + parseKind "FBINOP" = Just FBinop + parseKind "UNOP" = Just Unop + parseKind "FUNOP" = Just FUnop + parseKind "REDOP" = Just Redop + parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do + cons <- readArithLists kind + (\name _num _ -> return $ NormalC (mkName name) []) + return + return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do + clauses <- readArithLists kind + (\name _num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (StringL (nametrans name)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) + ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do + clauses <- readArithLists kind + (\name num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (IntegerL (fromIntegral num)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) + ,FunD (mkName funname) clauses] |