aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-18 21:55:08 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-18 21:55:08 +0200
commitc65320ad151cb5b92051866d17dcda49c7174e57 (patch)
tree104f4c69def294ebb7f2e5a1d49be166674fb8ab
parent4a0b2ef27a6e31250c56faef0efc0abf611a0cda (diff)
More sensible argument order to reduce1 C op
-rw-r--r--cbits/arith.c8
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs14
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs2
3 files changed, 12 insertions, 12 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 4d60228..ea06ac1 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -164,7 +164,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
typ accum = arr[arrlinidx]; \
for (i64 i = 1; i < shape[rank - 1]; i++) { \
@@ -399,10 +399,10 @@ enum redop_tag_t {
};
#define ENTRY_REDUCE1_OPS(typ) \
- void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
switch (tag) { \
- case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
- case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \
+ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \
default: wrong_op("reduce", tag); \
} \
}
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index d547084..9f99c3b 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)
-> (a -> b)
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> RS.Array (n + 1) a -> RS.Array n a
vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = error "unreachable"
@@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
- fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR)
+ fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)
TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
RS.stretch (init sh) -- replicate to original shape
. RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated
@@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
vectorDotprodOp :: (Num a, Storable a)
=> (b -> a)
-> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
-> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel
-> RS.Array 1 a -> RS.Array 1 a -> a
@@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
{-# NOINLINE dotScalarVector #-}
dotScalarVector :: forall a b. (Num a, Storable a)
=> Int -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel
-> a -> VS.Vector a -> a
dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
alloca @a $ \pout -> do
@@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
alloca @Int64 $ \pstride -> do
poke pstride 1
VS.unsafeWith vec $ \pvec ->
- fred 1 pshape pstride (ptrconv pout) (ptrconv pvec)
+ fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)
res <- peek pout
return (scalar * res)
@@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
=> -- int32
(Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
-- int64
-> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
@@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64
intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
=> -- int32
- (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
-> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel
-> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel
-- int64
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ca96093..ef8f3cd 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
basefull = "reducefull_" ++ atCName arithtype
sequence
[ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$>
- [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]
,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>
[t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])