diff options
| -rw-r--r-- | cbits/arith.c | 8 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 2 | 
3 files changed, 12 insertions, 12 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index 4d60228..ea06ac1 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -164,7 +164,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }  // Reduces along the innermost dimension.  // 'out' will be filled densely in linearisation order.  #define REDUCE1_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +  static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \      TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \        typ accum = arr[arrlinidx]; \        for (i64 i = 1; i < shape[rank - 1]; i++) { \ @@ -399,10 +399,10 @@ enum redop_tag_t {  };  #define ENTRY_REDUCE1_OPS(typ) \ -  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \      switch (tag) { \ -      case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ -      case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ +      case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \ +      case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \        default: wrong_op("reduce", tag); \      } \    } diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d547084..9f99c3b 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -170,7 +170,7 @@ vectorRedInnerOp :: forall a b n. (Num a, Storable a)                   -> (a -> b)                   -> (Ptr a -> Ptr b)                   -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                 -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel                   -> RS.Array (n + 1) a -> RS.Array n a  vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))    | null sh = error "unreachable" @@ -208,7 +208,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->                 VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->                   VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR -> -                   fred (fromIntegral ndimsF) pshF pstridesR (ptrconv poutvR) (ptrconv pvecR) +                   fred (fromIntegral ndimsF) (ptrconv poutvR) pshF pstridesR (ptrconv pvecR)             TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->               RS.stretch (init sh)  -- replicate to original shape                 . RS.reshape (init shOnes)  -- add 1-sized dimensions where the original was replicated @@ -307,7 +307,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))  vectorDotprodOp :: (Num a, Storable a)                  => (b -> a)                  -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel                  -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel                  -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ strided dotprod kernel                  -> RS.Array 1 a -> RS.Array 1 a -> a @@ -338,7 +338,7 @@ vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"  {-# NOINLINE dotScalarVector #-}  dotScalarVector :: forall a b. (Num a, Storable a)                  => Int -> (Ptr a -> Ptr b) -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel                  -> a -> VS.Vector a -> a  dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do    alloca @a $ \pout -> do @@ -347,7 +347,7 @@ dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do        alloca @Int64 $ \pstride -> do          poke pstride 1          VS.unsafeWith vec $ \pvec -> -          fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) +          fred 1 (ptrconv pout) pshape pstride (ptrconv pvec)      res <- peek pout      return (scalar * res) @@ -500,7 +500,7 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn  intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)                   => -- int32                      (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                 -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel                      -- int64                   -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant                   -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel @@ -535,7 +535,7 @@ intWidBranchExtr fextr32 fextr64  intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)                      => -- int32 -                       (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                       (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel                      -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32)  -- ^ dotprod kernel                      -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32)  -- ^ strided dotprod kernel                         -- int64 diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ca96093..ef8f3cd 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -53,7 +53,7 @@ $(fmap concat . forM typesList $ \arithtype -> do          basefull = "reducefull_" ++ atCName arithtype      sequence        [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$> -         [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |] +         [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]        ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>           [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]]) | 
