diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 |
commit | 55036a5ea4a6e590d0404638b2823c6a4aec3fba (patch) | |
tree | 484bc377229d3edff36bd9a2a80f999bbcd2e889 /ops/Data/Array/Strided/Arith/Internal/Foreign.hs | |
parent | 5414434df62b2b196354b9748b265093c168601b (diff) |
Separate arith routines into a library
The point is that this separate library does not depend on orthotope.
Diffstat (limited to 'ops/Data/Array/Strided/Arith/Internal/Foreign.hs')
-rw-r--r-- | ops/Data/Array/Strided/Arith/Internal/Foreign.hs | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal/Foreign.hs b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs new file mode 100644 index 0000000..dad65f9 --- /dev/null +++ b/ops/Data/Array/Strided/Arith/Internal/Foreign.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Strided.Arith.Internal.Foreign where + +import Data.Int +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Strided.Arith.Internal.Lists + + +$(do + let importsScal ttyp tyn = + [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) + ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let importsInt ttyp tyn = + [("ibinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("ibinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ] + + let importsFloat ttyp tyn = + [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ] + + let generate types imports = + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ + | arithtype <- types + , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] + decs1 <- generate typesList importsScal + decs2 <- generate intTypesList importsInt + decs3 <- generate floatTypesList importsFloat + return (decs1 ++ decs2 ++ decs3)) |