From 205a20fd581bb7c5728fd457a15e4f78fbee9e75 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Jun 2024 10:02:59 +0200 Subject: Dot product --- src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'src/Data/Array/Mixed/Internal/Arith') diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 0bd72e8..96a85d1 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -60,3 +60,12 @@ $(fmap concat . forM typesList $ \arithtype -> let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "dotprod_" ++ atCName arithtype + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |] + ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$> + [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]]) -- cgit v1.2.3-70-g09d2