diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:55:08 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:55:08 +0200 |
commit | c65320ad151cb5b92051866d17dcda49c7174e57 (patch) | |
tree | 104f4c69def294ebb7f2e5a1d49be166674fb8ab | |
parent | 4a0b2ef27a6e31250c56faef0efc0abf611a0cda (diff) |
More sensible argument order to reduce1 C op
-rw-r--r-- | cbits/arith.c | 8 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 2 |
3 files changed, 12 insertions, 12 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 4d60228..ea06ac1 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -164,7 +164,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. #define REDUCE1_OP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ + static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ typ accum = arr[arrlinidx]; \ for (i64 i = 1; i < shape[rank - 1]; i++) { \ @@ -399,10 +399,10 @@ enum redop_tag_t { }; #define ENTRY_REDUCE1_OPS(typ) \ - void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ + void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ switch (tag) { \ - case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ - case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ + case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \ + case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \ default: wrong_op("reduce", tag); \ } \ } diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d547084..9f99c3b 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a) -> (a -> b) -> (Ptr a -> Ptr b) -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> RS.Array (n + 1) a -> RS.Array n a vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) | null sh = error "unreachable" @@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR -> VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> - fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR) + fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR) TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> RS.stretch (init sh) -- replicate to original shape . RS.reshape (init shOnes) -- add 1-sized dimensions where the original was replicated @@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) vectorDotprodOp :: (Num a, Storable a) => (b -> a) -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel -> RS.Array 1 a -> RS.Array 1 a -> a @@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" {-# NOINLINE dotScalarVector #-} dotScalarVector :: forall a b. (Num a, Storable a) => Int -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ reduction kernel -> a -> VS.Vector a -> a dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do alloca @a $ \pout -> do @@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do alloca @Int64 $ \pstride -> do poke pstride 1 VS.unsafeWith vec $ \pvec -> - fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) + fred 1 (ptrconv pout) pshape pstride (ptrconv pvec) res <- peek pout return (scalar * res) @@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i) => -- int32 (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel -- int64 -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel @@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64 intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) => -- int32 - (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ reduction kernel -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel -- int64 diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ca96093..ef8f3cd 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do basefull = "reducefull_" ++ atCName arithtype sequence [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> - [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |] + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |] ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$> [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]]) |