From c65320ad151cb5b92051866d17dcda49c7174e57 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 18 Jun 2024 21:55:08 +0200
Subject: More sensible argument order to reduce1 C op

---
 src/Data/Array/Mixed/Internal/Arith.hs         | 14 +++++++-------
 src/Data/Array/Mixed/Internal/Arith/Foreign.hs |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

(limited to 'src/Data/Array/Mixed')

diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index d547084..9f99c3b 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)
                  -> (a -> b)
                  -> (Ptr a -> Ptr b)
                  -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant
-                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel
+                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                  -> RS.Array (n + 1) a -> RS.Array n a
 vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
   | null sh = error "unreachable"
@@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
              VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
                VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
                  VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
-                   fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR)
+                   fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
            TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
              RS.stretch (init sh)  -- replicate to original shape
                . RS.reshape (init shOnes)  -- add 1-sized dimensions where the original was replicated
@@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
 vectorDotprodOp :: (Num a, Storable a)
                 => (b -> a)
                 -> (Ptr a -> Ptr b)
-                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel
+                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                 -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel
                 -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ strided dotprod kernel
                 -> RS.Array 1 a -> RS.Array 1 a -> a
@@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
 {-# NOINLINE dotScalarVector #-}
 dotScalarVector :: forall a b. (Num a, Storable a)
                 => Int -> (Ptr a -> Ptr b)
-                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel
+                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
                 -> a -> VS.Vector a -> a
 dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
   alloca @a $ \pout -> do
@@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
       alloca @Int64 $ \pstride -> do
         poke pstride 1
         VS.unsafeWith vec $ \pvec ->
-          fred 1 pshape pstride (ptrconv pout) (ptrconv pvec)
+          fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)
     res <- peek pout
     return (scalar * res)
 
@@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
 intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
                  => -- int32
                     (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant
-                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
+                 -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
                     -- int64
                  -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant
                  -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel
@@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64
 
 intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
                     => -- int32
-                       (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
+                       (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
                     -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32)  -- ^ dotprod kernel
                     -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32)  -- ^ strided dotprod kernel
                        -- int64
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ca96093..ef8f3cd 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
         basefull = "reducefull_" ++ atCName arithtype
     sequence
       [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$>
-         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
+         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]
       ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>
          [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])
 
-- 
cgit v1.2.3-70-g09d2