| field | value | date |
|---|---|---|
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-06 00:08:40 +0100 |
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-06 00:08:40 +0100 |
| commit | e07bb2985e3befae6a093491de96965a87f0986f (patch) | |
| tree | b3f14b1f7092f9b7be302245e27b9d9691986e15 | |
| parent | eff6b7ba64fbe4e6e260ce3266109fd9fee27ae2 (diff) | |
WIP binary ops without normalisation (branch: no-normalise-binop)
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | cbits/arith.c | 138 |
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 158 |
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 30 |
3 files changed, 210 insertions, 116 deletions
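For orientation before reading the diff: the exported binary-op entry points stop taking a flat element count `n` and instead take a rank, a shape vector, and one stride vector per array argument, so strided (non-normalised) inputs can be consumed in place while the output stays densely packed. The prototypes below sketch that shift for one element type; they are transcribed from the diff but are *not* a compilable excerpt of `cbits/arith.c` — the `i64` typedef and the numeric tag values are assumptions made here purely for illustration (the real tags come from the `LIST_BINOP` X-macro).

```c
#include <stdint.h>

typedef int64_t i64;  /* assumed alias, matching the i64 used in cbits/arith.c */
enum binop_tag_t { BO_ADD = 1, BO_SUB, BO_MUL };  /* placeholder values for illustration */

/* Old shape: every argument is one contiguous block of n elements. */
void oxarop_binary_i64_vv(enum binop_tag_t tag, i64 n,
                          i64 *out, const i64 *x, const i64 *y);

/* New shape: rank/shape plus a stride vector per input; out is written densely. */
void oxarop_binary_i64_vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape,
                                  i64 *out,
                                  const i64 *strides1, const i64 *x,
                                  const i64 *strides2, const i64 *y);
```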
```diff
diff --git a/cbits/arith.c b/cbits/arith.c
index 2788e41..9aed3b4 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -175,29 +175,53 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
  * Kernel functions *
  *****************************************************************************/
 
-#define COMM_OP(name, op, typ) \
-  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
-    for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
+#define COMM_OP_STRIDED(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
+      } \
+    }); \
   } \
-  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
-    for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
+  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+    TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
+      } \
+    }); \
   }
 
-#define NONCOMM_OP(name, op, typ) \
-  COMM_OP(name, op, typ) \
-  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
-    for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
+#define NONCOMM_OP_STRIDED(name, op, typ) \
+  COMM_OP_STRIDED(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
+      } \
+    }); \
   }
 
-#define PREFIX_BINOP(name, op, typ) \
-  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
-    for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \
+#define PREFIX_BINOP_STRIDED(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
+      } \
+    }); \
   } \
-  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
-    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \
+  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+    TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
+      } \
+    }); \
   } \
-  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
-    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
+  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+      for (i64 i = 0; i < shape[rank - 1]; i++) { \
+        out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
+      } \
+    }); \
   }
 
 #define UNARY_OP_STRIDED(name, op, typ) \
@@ -360,29 +384,29 @@ enum binop_tag_t {
 #define LIST_BINOP(name, id, hsop)
 };
 
-#define ENTRY_BINARY_OPS(typ) \
-  void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+#define ENTRY_BINARY_STRIDED_OPS(typ) \
+  void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
     switch (tag) { \
-      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \
-      case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \
-      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \
-      default: wrong_op("binary_sv", tag); \
+      case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      default: wrong_op("binary_sv_strided", tag); \
     } \
   } \
-  void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+  void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
     switch (tag) { \
-      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \
-      case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \
-      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \
-      default: wrong_op("binary_vs", tag); \
+      case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+      default: wrong_op("binary_vs_strided", tag); \
    } \
  } \
-  void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+  void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    switch (tag) { \
-      case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \
-      case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \
-      case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \
-      default: wrong_op("binary_vv", tag); \
+      case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      case BO_MUL: oxarop_op_mul_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      default: wrong_op("binary_vv_strided", tag); \
    } \
   }
 
@@ -394,29 +418,29 @@ enum fbinop_tag_t {
 #define LIST_FBINOP(name, id, hsop)
 };
 
-#define ENTRY_FBINARY_OPS(typ) \
-  void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+#define ENTRY_FBINARY_STRIDED_OPS(typ) \
+  void oxarop_fbinary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
    switch (tag) { \
-      case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \
-      case FB_POW: oxarop_op_pow_ ## typ ## _sv(n, out, x, y); break; \
-      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv(n, out, x, y); break; \
-      default: wrong_op("binary_sv", tag); \
+      case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+      default: wrong_op("fbinary_sv_strided", tag); \
    } \
  } \
-  void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+  void oxarop_fbinary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
    switch (tag) { \
-      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \
-      case FB_POW: oxarop_op_pow_ ## typ ## _vs(n, out, x, y); break; \
-      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs(n, out, x, y); break; \
-      default: wrong_op("binary_vs", tag); \
+      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+      default: wrong_op("fbinary_vs_strided", tag); \
    } \
  } \
-  void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+  void oxarop_fbinary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    switch (tag) { \
-      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \
-      case FB_POW: oxarop_op_pow_ ## typ ## _vv(n, out, x, y); break; \
-      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv(n, out, x, y); break; \
-      default: wrong_op("binary_vv", tag); \
+      case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+      default: wrong_op("fbinary_vv_strided", tag); \
    } \
   }
 
@@ -469,7 +493,7 @@ enum funop_tag_t {
       case FU_EXPM1: oxarop_op_expm1_ ## typ ## _strided(rank, out, shape, strides, x); break; \
       case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
       case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
-      default: wrong_op("unary", tag); \
+      default: wrong_op("funary_strided", tag); \
    } \
   }
 
@@ -508,9 +532,9 @@ enum redop_tag_t {
 #define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST
 
 #define X(typ) \
-  COMM_OP(add, +, typ) \
-  NONCOMM_OP(sub, -, typ) \
-  COMM_OP(mul, *, typ) \
+  COMM_OP_STRIDED(add, +, typ) \
+  NONCOMM_OP_STRIDED(sub, -, typ) \
+  COMM_OP_STRIDED(mul, *, typ) \
   UNARY_OP_STRIDED(neg, -, typ) \
   UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
   UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
@@ -518,7 +542,7 @@ enum redop_tag_t {
   REDUCE1_OP(product1, *, typ) \
   REDUCEFULL_OP(sumfull, +, typ) \
   REDUCEFULL_OP(productfull, *, typ) \
-  ENTRY_BINARY_OPS(typ) \
+  ENTRY_BINARY_STRIDED_OPS(typ) \
   ENTRY_UNARY_STRIDED_OPS(typ) \
   ENTRY_REDUCE1_OPS(typ) \
   ENTRY_REDUCEFULL_OPS(typ) \
@@ -530,9 +554,9 @@ NUM_TYPES_XLIST
 #undef X
 
 #define X(typ) \
-  NONCOMM_OP(fdiv, /, typ) \
-  PREFIX_BINOP(pow, GEN_POW, typ) \
-  PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
+  NONCOMM_OP_STRIDED(fdiv, /, typ) \
+  PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
+  PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
   UNARY_OP_STRIDED(recip, 1.0/, typ) \
   UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
   UNARY_OP_STRIDED(log, GEN_LOG, typ) \
@@ -553,7 +577,7 @@ NUM_TYPES_XLIST
   UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \
   UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \
   UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \
-  ENTRY_FBINARY_OPS(typ) \
+  ENTRY_FBINARY_STRIDED_OPS(typ) \
   ENTRY_FUNARY_STRIDED_OPS(typ)
 FLOAT_TYPES_XLIST
 #undef X
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 11ee3fe..6253ae0 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -67,7 +67,7 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides
       VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
         VS.unsafeWith (VS.singleton 1) $ \pstrides ->
           VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
-            cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
+            cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
     RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
   | otherwise = unsafePerformIO $ do
     outv <- VSM.unsafeNew (product sh)
@@ -79,33 +79,52 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides
     RS.fromVector sh <$> VS.unsafeFreeze outv
 
 -- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: (Storable a, Storable b, Storable c)
+liftVEltwise2 :: Storable a
               => SNat n
-              -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
-              -> RS.Array n a -> RS.Array n b -> RS.Array n c
-liftVEltwise2 SNat f
+              -> (a -> b)
+              -> (Ptr a -> Ptr b)
+              -> (a -> a -> a)
+              -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ sv
+              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())  -- ^ vs
+              -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ vv
+              -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
   arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
   arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
   | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
   | product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
   | otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
      (Just (_, 1), Just (_, 1)) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
-        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+        let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
        in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+
      (Just (_, 1), Just (blockOff, blockSz)) ->  -- scalar * dense
-        RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff)
-                (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2)))))
+        let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2)
+            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+        in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec))
+
+      (Just (_, 1), Nothing) ->  -- scalar * array
+        wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
      (Just (blockOff, blockSz), Just (_, 1)) ->  -- dense * scalar
-        RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff)
-                (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2)))))
+        let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1)
+            RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS sn valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+        in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec))
+
+      (Nothing, Just (_, 1)) ->  -- array * scalar
+        wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
      (Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
-        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the below
+        | blockSz1 == blockSz2  -- not sure if this check is necessary, might be implied by the strides check
        , strides1 == strides2
        ->  -- dense * dense but the strides match
-          RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1)
-                  (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2)))))
+          let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+              arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+              RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV sn ptrconv f_vv arr1' arr2'
+          in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec))
+
      (_, _) ->  -- fallback case
-        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+        wrapBinaryVV sn ptrconv f_vv arr1 arr2
 
 -- | Given shape vector, offset and stride vector, check whether this virtual
 -- vector uses a dense subarray of its backing array. If so, the first index
@@ -141,6 +160,57 @@ stridesDense sh offsetNeg stridesNeg =
         in second ((-s) :) (flipReverseds sh' off' str')
     flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
 
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+             -> a -> RS.Array n a
+             -> RS.Array n a
+wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) =
+  unsafePerformIO $ do
+    outv <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outv $ \poutv ->
+      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
+          VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
+            cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv)
+    RS.fromVector sh <$> VS.unsafeFreeze outv
+
+wrapBinaryVS :: Storable a
+             => SNat n
+             -> (a -> b)
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+             -> RS.Array n a -> a
+             -> RS.Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+  wrapBinarySV sn valconv ptrconv
+               (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
+
+-- | This function assumes that the two shapes are equal.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: Storable a
+             => SNat n
+             -> (Ptr a -> Ptr b)
+             -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+             -> RS.Array n a -> RS.Array n a
+             -> RS.Array n a
+wrapBinaryVV sn@SNat ptrconv cf_strided
+             (RS.A (RG.A sh (OI.T strides1 offset1 vec1)))
+             (RS.A (RG.A _ (OI.T strides2 offset2 vec2))) =
+  unsafePerformIO $ do
+    outv <- VSM.unsafeNew (product sh)
+    VSM.unsafeWith outv $ \poutv ->
+      VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
+          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
+            VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 ->
+              VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 ->
+                cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2)
+    RS.fromVector sh <$> VS.unsafeFreeze outv
+
 {-# NOINLINE vectorOp1 #-}
 vectorOp1 :: forall a b. Storable a
           => (Ptr a -> Ptr b)
@@ -423,13 +493,13 @@ $(fmap concat . forM typesList $ \arithtype -> do
   fmap concat . forM [minBound..maxBound] $ \arithop -> do
     let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
         cnamebase = "c_binary_" ++ atCName arithtype
-        c_ss = varE (aboNumOp arithop)
-        c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
-        c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
-        c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+        c_ss_str = varE (aboNumOp arithop)
+        c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+        c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+        c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
    sequence [SigD name <$>
                [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
-             ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+             ,do body <- [| \sn -> liftVEltwise2 sn $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
                 return $ FunD name [Clause [] (NormalB body) []]])
 
 $(fmap concat . forM floatTypesList $ \arithtype -> do
@@ -437,13 +507,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
   fmap concat . forM [minBound..maxBound] $ \arithop -> do
    let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
        cnamebase = "c_fbinary_" ++ atCName arithtype
-        c_ss = varE (afboNumOp arithop)
-        c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
-        c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
-        c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+        c_ss_str = varE (afboNumOp arithop)
+        c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+        c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+        c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
    sequence [SigD name <$>
                [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
-             ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+             ,do body <- [| \sn -> liftVEltwise2 sn $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
                 return $ FunD name [Clause [] (NormalB body) []]])
 
 $(fmap concat . forM typesList $ \arithtype -> do
@@ -526,17 +596,17 @@ intWidBranch1 f32 f64 sn
 intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
               => (i -> i -> i)  -- ss
              -- int32
-              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv
-              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs
-              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv
+              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- sv
+              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ())  -- vs
+              -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- vv
              -- int64
-              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv
-              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs
-              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- sv
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs
+              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv
              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
 intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
-  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
-  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
  | otherwise = error "Unsupported Int width"
 
 intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -667,14 +737,14 @@ instance NumElt Double where
 
 instance NumElt Int where
   numEltAdd = intWidBranch2 @Int (+)
-                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
-                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
   numEltSub = intWidBranch2 @Int (-)
-                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
-                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
   numEltMul = intWidBranch2 @Int (*)
-                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
-                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
   numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
   numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
   numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
@@ -693,14 +763,14 @@ instance NumElt Int where
 
 instance NumElt CInt where
   numEltAdd = intWidBranch2 @CInt (+)
-                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
-                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+                (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+                (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
   numEltSub = intWidBranch2 @CInt (-)
-                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
-                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+                (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+                (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
   numEltMul = intWidBranch2 @CInt (*)
-                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
-                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+                (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+                (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
   numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
   numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
   numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index a60b717..fa89766 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -12,24 +12,24 @@ import Data.Array.Mixed.Internal.Arith.Lists
 
 $(do
   let importsScal ttyp tyn =
-        [("binary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-        ,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
-        ,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
-        ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-        ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-        ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
-        ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-        ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-        ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
-        ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
-        ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+        ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
+        ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
+        ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
+        ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        ]
 
   let importsFloat ttyp tyn =
-        [("fbinary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-        ,("fbinary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
-        ,("fbinary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
-        ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+        ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+        ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
        ]
 
  let generate types imports =
```
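To make the indexing used by the new `*_strided` kernels concrete, here is a minimal standalone sketch of the same scheme for a rank-2 `vv` addition. It deliberately avoids the repository's `TARRAY_WALK*` macros (their definitions are outside this diff) and hand-rolls the outer walk, which is presumably what `TARRAY_WALK2_NOINNER` generalises to arbitrary rank; the function and variable names are illustrative only. As in the kernels above, the output is written densely (`outbase + i`) while each input is read through its own strides, so a transposed or otherwise non-contiguous view can be consumed without normalising it first.

```c
#include <stdint.h>
#include <stdio.h>

typedef int64_t i64;  /* assumed alias, as in cbits/arith.c */

/* Illustrative only: rank-2 elementwise add with a dense output and two
 * independently strided inputs, mirroring the inner loop of the new
 * COMM_OP_STRIDED "vv" kernels. */
static void strided_add_vv_rank2(const i64 shape[2], i64 *out,
                                 const i64 strides1[2], const i64 *x,
                                 const i64 strides2[2], const i64 *y) {
  for (i64 j = 0; j < shape[0]; j++) {
    i64 outbase = j * shape[1];    /* output is row-major and contiguous */
    i64 base1 = j * strides1[0];   /* inputs follow their own strides */
    i64 base2 = j * strides2[0];
    for (i64 i = 0; i < shape[1]; i++)
      out[outbase + i] = x[base1 + strides1[1] * i] + y[base2 + strides2[1] * i];
  }
}

int main(void) {
  i64 shape[2] = {2, 3};
  i64 x[6] = {1, 2, 3, 4, 5, 6};           /* 2x3, row-major: strides {3, 1} */
  i64 xs[2] = {3, 1};
  i64 ybuf[6] = {10, 40, 20, 50, 30, 60};  /* 3x2 buffer, used via its transposed view */
  i64 ys[2] = {1, 2};                      /* 2x3 transposed view of ybuf: strides {1, 2} */
  i64 out[6];
  strided_add_vv_rank2(shape, out, xs, x, ys, ybuf);
  for (int k = 0; k < 6; k++)
    printf("%lld ", (long long)out[k]);    /* prints: 11 22 33 44 55 66 */
  printf("\n");
  return 0;
}
```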