From 97ab8502b9cd3f7d908160d13c7d85d23c99e203 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 18 Jun 2024 21:55:35 +0200 Subject: Clean up Foreign.hs --- src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 91 ++++++++------------------ 1 file changed, 29 insertions(+), 62 deletions(-) (limited to 'src/Data/Array/Mixed/Internal/Arith') diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ef8f3cd..c1c0070 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -2,7 +2,6 @@ {-# LANGUAGE TemplateHaskell #-} module Data.Array.Mixed.Internal.Arith.Foreign where -import Control.Monad import Data.Int import Foreign.C.Types import Foreign.Ptr @@ -11,64 +10,32 @@ import Language.Haskell.TH import Data.Array.Mixed.Internal.Arith.Lists -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "binary_" ++ atCName arithtype - sequence - [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |] - ]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "fbinary_" ++ atCName arithtype - sequence - [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |] - ]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "unary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "funary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base1 = "reduce1_" ++ atCName arithtype - basefull = "reducefull_" ++ atCName arithtype - sequence - [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$> - [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]]) - -$(fmap concat . forM typesList $ \arithtype -> - fmap concat . forM ["min", "max"] $ \fname -> do - let ttyp = conT (atType arithtype) - let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "dotprod_" ++ atCName arithtype - sequence - [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |] - ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$> - [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]]) +$(do + let importsScal ttyp tyn = + [("binary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("unary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) + ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]) + ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) + ] + + let importsFloat ttyp tyn = + [("fbinary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,("fbinary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ,("funary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ] + + let generate types imports = + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ + | arithtype <- types + , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)] + decs1 <- generate typesList importsScal + decs2 <- generate floatTypesList importsFloat + return (decs1 ++ decs2)) -- cgit v1.2.3-70-g09d2