1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Mixed.Internal.Arith.Foreign where
import Data.Int
import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
import Data.Array.Mixed.Internal.Arith.Lists
-- Top-level declaration splice that emits one @foreign import ccall unsafe@
-- per (arithmetic type, kernel) combination. For a kernel called @name@ the
-- C symbol is @"oxarop_" ++ name@ and the Haskell binding is @"c_" ++ name@.
$(do
  let -- Kernels instantiated for every entry of 'typesList'. Each pair holds
      -- the kernel name (without the "oxarop_" / "c_" prefix) and a quoted
      -- Haskell type for the import. 'elemT' is the quoted element type and
      -- 'cname' the type's suffix as used in the symbol name.
      scalarKernels elemT cname =
        [ ("binary_" ++ cname ++ "_vv", [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> Ptr $elemT -> IO () |])
        , ("binary_" ++ cname ++ "_sv", [t| CInt -> Int64 -> Ptr $elemT -> $elemT -> Ptr $elemT -> IO () |])
        , ("binary_" ++ cname ++ "_vs", [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> $elemT -> IO () |])
        , ("unary_" ++ cname, [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> IO () |])
        , ("reduce1_" ++ cname, [t| CInt -> Int64 -> Ptr $elemT -> Ptr Int64 -> Ptr Int64 -> Ptr $elemT -> IO () |])
        , ("reducefull_" ++ cname, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elemT -> IO $elemT |])
        , ("extremum_min_" ++ cname, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elemT -> IO () |])
        , ("extremum_max_" ++ cname, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $elemT -> IO () |])
        , ("dotprod_" ++ cname, [t| Int64 -> Ptr $elemT -> Ptr $elemT -> IO $elemT |])
        , ("dotprod_" ++ cname ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $elemT -> Int64 -> Int64 -> Ptr $elemT -> IO $elemT |])
        ]

      -- Additional kernels instantiated only for entries of 'floatTypesList'.
      floatKernels elemT cname =
        [ ("fbinary_" ++ cname ++ "_vv", [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> Ptr $elemT -> IO () |])
        , ("fbinary_" ++ cname ++ "_sv", [t| CInt -> Int64 -> Ptr $elemT -> $elemT -> Ptr $elemT -> IO () |])
        , ("fbinary_" ++ cname ++ "_vs", [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> $elemT -> IO () |])
        , ("funary_" ++ cname, [t| CInt -> Int64 -> Ptr $elemT -> Ptr $elemT -> IO () |])
        ]

      -- Turn a (kernel name, quoted type) pair into its foreign import
      -- declaration, using the ccall convention with 'Unsafe' safety.
      foreignImport (kernel, quotedTy) =
        ForeignD . ImportF CCall Unsafe ("oxarop_" ++ kernel) (mkName ("c_" ++ kernel))
          <$> quotedTy

      -- All imports for the given types: types vary in the outer position and
      -- kernels in the inner one, so the declarations come out grouped by type.
      importsFor arithTypes kernelsOf =
        traverse foreignImport
          [ entry
          | arithTy <- arithTypes
          , entry <- kernelsOf (conT (atType arithTy)) (atCName arithTy)
          ]

  -- Scalar declarations first, followed by the float-only ones.
  (++) <$> importsFor typesList scalarKernels
       <*> importsFor floatTypesList floatKernels)
|