aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
blob: ade7ce1c11259c7852dc908d351c3f195167f4c2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Mixed.Internal.Arith.Foreign where

import Data.Int
import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH

import Data.Array.Mixed.Internal.Arith.Lists


-- Foreign imports for the C arithmetic kernels, generated by Template Haskell.
--
-- For every element type in 'typesList' (and, for the float-only operations,
-- every element type in 'floatTypesList') and every entry of the matching
-- operation table below, this splice emits a declaration of the shape
--
-- >  foreign import ccall unsafe "oxarop_<op>" c_<op> :: <type>
--
-- where <op> is built from the operation name and the type's C name
-- ('atCName'), and <type> is the table's type instantiated at the Haskell
-- element type ('atType'). Declarations come out grouped by element type, in
-- table order within each group.
$(do
  let -- Operations generated for every type in 'typesList'. The _sv / _vs
      -- variants take one operand by value rather than through a Ptr
      -- (presumably scalar-vector / vector-scalar -- confirm against the C side).
      importsScal :: Q Type -> String -> [(String, Q Type)]
      importsScal ttyp tyn =
        [ ("binary_" ++ tyn ++ "_vv",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
        , ("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])
        , ("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])
        , ("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
        , ("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        , ("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
        , ("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        , ("extremum_max_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        , ("dotprod_" ++ tyn,               [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
        , ("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
        , ("dotprodinner_" ++ tyn,          [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        ]

      -- Operations generated only for the types in 'floatTypesList'; same
      -- argument shapes as the binary/unary entries above, with an "f" prefix.
      importsFloat :: Q Type -> String -> [(String, Q Type)]
      importsFloat ttyp tyn =
        [ ("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
        , ("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])
        , ("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])
        , ("funary_" ++ tyn,                [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
        ]

      -- One unsafe ccall import: C symbol "oxarop_<op>", Haskell name c_<op>.
      foreignImport (opName, qty) =
        ForeignD . ImportF CCall Unsafe ("oxarop_" ++ opName) (mkName ("c_" ++ opName)) <$> qty

      -- Instantiate an operation table at each element type, keeping the
      -- element-type-major ordering of the resulting declarations.
      generate types table =
        concat <$> traverse
          (\arith -> traverse foreignImport (table (conT (atType arith)) (atCName arith)))
          types

  scalarDecs <- generate typesList importsScal
  floatDecs  <- generate floatTypesList importsFloat
  pure (scalarDecs ++ floatDecs))