aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-18 21:55:08 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-18 21:55:08 +0200
commitc65320ad151cb5b92051866d17dcda49c7174e57 (patch)
tree104f4c69def294ebb7f2e5a1d49be166674fb8ab /src/Data/Array/Mixed/Internal
parent4a0b2ef27a6e31250c56faef0efc0abf611a0cda (diff)
More sensible argument order to reduce1 C op
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs14
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs2
2 files changed, 8 insertions, 8 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index d547084..9f99c3b 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)
-> (a -> b)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> RS.Array (n + 1) a -> RS.Array n a
vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = error "unreachable"
@@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR)
+ fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
RS.stretch (init sh) -- replicate to original shape
. RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
@@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
vectorDotprodOp :: (Num a, Storable a)
=> (b -> a)
-> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
-> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel
-> RS.Array 1 a -> RS.Array 1 a -> a
@@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
{-# NOINLINE dotScalarVector #-}
dotScalarVector :: forall a b. (Num a, Storable a)
=> Int -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> a -> VS.Vector a -> a
dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
alloca @a $ \pout -> do
@@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
alloca @Int64 $ \pstride -> do
poke pstride 1
VS.unsafeWith vec $ \pvec ->
- fred 1 pshape pstride (ptrconv pout) (ptrconv pvec)
+ fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)
res <- peek pout
return (scalar * res)
@@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
=> -- int32
(Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
-- int64
-> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
@@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64
intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
=> -- int32
- (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel
-> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel
-- int64
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ca96093..ef8f3cd 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
basefull = "reducefull_" ++ atCName arithtype
sequence
[ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$>
- [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]
,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>
[t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])