{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Foreign imports of the C arithmetic kernels, generated by Template Haskell.
--
-- The splice in this module emits one @foreign import ccall unsafe@
-- declaration per (element type, kernel) pair. For a kernel called @name@ the
-- imported C symbol is @"oxarop_" ++ name@ and the Haskell binding is
-- @"c_" ++ name@; @name@ embeds the string returned by 'atCName' for the
-- element type in question.
module Data.Array.Mixed.Internal.Arith.Foreign where

import Data.Int
import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH

import Data.Array.Mixed.Internal.Arith.Lists


$(do
  -- Kernel tables: each entry pairs the kernel name (without the "oxarop_" /
  -- "c_" prefix) with the Haskell type of its foreign import. 'elt' is the
  -- element type as a TH type, 'cname' the string spliced into the symbol
  -- name. What the individual arguments mean is decided by the C
  -- implementation, which is not visible from this module.
  let scalarKernels elt cname =
        [ ("binary_" ++ cname ++ "_vv", [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> Ptr $elt -> IO () |])
        , ("binary_" ++ cname ++ "_sv", [t| CInt -> Int64 -> Ptr $elt -> $elt -> Ptr $elt -> IO () |])
        , ("binary_" ++ cname ++ "_vs", [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> $elt -> IO () |])
        , ("unary_" ++ cname, [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> IO () |])
        , ("reduce1_" ++ cname, [t| CInt -> Int64 -> Ptr $elt -> Ptr Int64 -> Ptr Int64 -> Ptr $elt -> IO () |])
        , ("reducefull_" ++ cname, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elt -> IO $elt |])
        , ("extremum_min_" ++ cname, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elt -> IO () |])
        , ("extremum_max_" ++ cname, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elt -> IO () |])
        , ("dotprod_" ++ cname, [t| Int64 -> Ptr $elt -> Ptr $elt -> IO $elt |])
        , ("dotprod_" ++ cname ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $elt -> Int64 -> Int64 -> Ptr $elt -> IO $elt |])
        , ("dotprodinner_" ++ cname, [t| Int64 -> Ptr Int64 -> Ptr $elt -> Ptr Int64 -> Ptr $elt -> Ptr Int64 -> Ptr $elt -> IO () |])
        ]

  -- Kernels that are only instantiated for the entries of 'floatTypesList'.
  let floatKernels elt cname =
        [ ("fbinary_" ++ cname ++ "_vv", [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> Ptr $elt -> IO () |])
        , ("fbinary_" ++ cname ++ "_sv", [t| CInt -> Int64 -> Ptr $elt -> $elt -> Ptr $elt -> IO () |])
        , ("fbinary_" ++ cname ++ "_vs", [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> $elt -> IO () |])
        , ("funary_" ++ cname, [t| CInt -> Int64 -> Ptr $elt -> Ptr $elt -> IO () |])
        ]

  -- Turn one table entry into the corresponding foreign import declaration.
  let foreignImport (kernel, sigQ) =
        fmap
          (\sig -> ForeignD (ImportF CCall Unsafe ("oxarop_" ++ kernel) (mkName ("c_" ++ kernel)) sig))
          sigQ

  -- Instantiate a kernel table at every listed type, keeping the order
  -- "all kernels of the first type, then all kernels of the second, ...".
  let declsFor tys kernels =
        traverse foreignImport
          (concatMap (\ty -> kernels (conT (atType ty)) (atCName ty)) tys)

  scalarDecls <- declsFor typesList scalarKernels
  floatDecls <- declsFor floatTypesList floatKernels
  pure (scalarDecls ++ floatDecls))