From 97ab8502b9cd3f7d908160d13c7d85d23c99e203 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 18 Jun 2024 21:55:35 +0200
Subject: Clean up Foreign.hs

---
 src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 91 ++++++++------------------
 1 file changed, 29 insertions(+), 62 deletions(-)

diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ef8f3cd..c1c0070 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -2,7 +2,6 @@
 {-# LANGUAGE TemplateHaskell #-}
 module Data.Array.Mixed.Internal.Arith.Foreign where
 
-import Control.Monad
 import Data.Int
 import Foreign.C.Types
 import Foreign.Ptr
@@ -11,64 +10,32 @@ import Language.Haskell.TH
 import Data.Array.Mixed.Internal.Arith.Lists
 
 
-$(fmap concat . forM typesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base = "binary_" ++ atCName arithtype
-    sequence
-      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]
-      ])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base = "fbinary_" ++ atCName arithtype
-    sequence
-      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]
-      ])
-
-$(fmap concat . forM typesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base = "unary_" ++ atCName arithtype
-    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
-      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base = "funary_" ++ atCName arithtype
-    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
-      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-
-$(fmap concat . forM typesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base1 = "reduce1_" ++ atCName arithtype
-        basefull = "reducefull_" ++ atCName arithtype
-    sequence
-      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$>
-         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>
-         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])
-
-$(fmap concat . forM typesList $ \arithtype ->
-    fmap concat . forM ["min", "max"] $ \fname -> do
-      let ttyp = conT (atType arithtype)
-      let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype
-      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
-        [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-
-$(fmap concat . forM typesList $ \arithtype -> do
-    let ttyp = conT (atType arithtype)
-    let base = "dotprod_" ++ atCName arithtype
-    sequence
-      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
-         [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]
-      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$>
-         [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]])
+$(do
+  let importsScal ttyp tyn =
+        [("binary_" ++ tyn ++ "_vv",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+        ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])
+        ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])
+        ,("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+        ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
+        ,("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("extremum_max_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("dotprod_" ++ tyn,               [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
+        ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
+        ]
+
+  let importsFloat ttyp tyn =
+        [("fbinary_" ++ tyn ++ "_vv",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+        ,("fbinary_" ++ tyn ++ "_sv",      [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])
+        ,("fbinary_" ++ tyn ++ "_vs",      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])
+        ,("funary_" ++ tyn,                [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+        ]
+
+  let generate types imports =
+        sequence
+          [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ name) (mkName ("c_" ++ name)) <$> typ
+          | arithtype <- types
+          , (name, typ) <- imports (conT (atType arithtype)) (atCName arithtype)]
+  decs1 <- generate typesList importsScal
+  decs2 <- generate floatTypesList importsFloat
+  return (decs1 ++ decs2))
-- 
cgit v1.2.3-70-g09d2