diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 87 | 
1 files changed, 27 insertions, 60 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ef8f3cd..c1c0070 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -2,7 +2,6 @@  {-# LANGUAGE TemplateHaskell #-}  module Data.Array.Mixed.Internal.Arith.Foreign where -import Control.Monad  import Data.Int  import Foreign.C.Types  import Foreign.Ptr @@ -11,64 +10,32 @@ import Language.Haskell.TH  import Data.Array.Mixed.Internal.Arith.Lists -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "binary_" ++ atCName arithtype -    sequence -      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |] -      ]) +$(do +  let importsScal ttyp tyn = +        [("binary_" ++ tyn ++ "_vv",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |]) +        ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) +        ,("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) +        ,("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("extremum_max_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) +        ,("dotprod_" ++ tyn,               [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]) +        ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) +        ] -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "fbinary_" ++ atCName arithtype -    sequence -      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |] -      ]) +  let importsFloat ttyp tyn = +        [("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |]) +        ,("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |]) +        ,("funary_" ++ tyn,                [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ] -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "unary_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "funary_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base1 = "reduce1_" ++ atCName arithtype -        basefull = "reducefull_" ++ atCName arithtype -    sequence -      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> -         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$> -         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]]) - -$(fmap concat . forM typesList $ \arithtype -> -    fmap concat . forM ["min", "max"] $ \fname -> do -      let ttyp = conT (atType arithtype) -      let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype -      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -        [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "dotprod_" ++ atCName arithtype -    sequence -      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -         [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |] -      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$> -         [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]]) +  let generate types imports = +        sequence +          [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ +          | arithtype <- types +          , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] +  decs1 <- generate typesList importsScal +  decs2 <- generate floatTypesList importsFloat +  return (decs1 ++ decs2)) | 
