aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-16 00:30:25 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-16 00:30:25 +0100
commitc14017f4bc28951be7e298d01769b5b49384a7c3 (patch)
treedd7ea8e90b28e37ac46251d11be2eb6c0ffc699b
parentb0fae0894f4440c6cd9cd74b5a3515baa8bd8c35 (diff)
arith: Unary int ops on strided arrays without normalisation
-rw-r--r--cbits/arith.c103
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs23
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs1
-rw-r--r--test/Tests/C.hs14
4 files changed, 113 insertions, 28 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index b9c86ab..f08e456 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -5,6 +5,7 @@
#include <stdio.h>
#include <stdint.h>
+#include <inttypes.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
@@ -89,38 +90,23 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
/*****************************************************************************
- * Kernel functions *
+ * Helper functions *
*****************************************************************************/
-#define COMM_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
- } \
- static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
- }
-
-#define NONCOMM_OP(name, op, typ) \
- COMM_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
- for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
+// Debugging aid: writes the `rank` entries of `shape` to `stream` as a
+// comma-separated list in square brackets, e.g. "[2,3,4]"; rank 0 gives "[]".
+// In the code visible here it is only referenced from commented-out debug
+// output, hence the `used` attribute on an otherwise unreferenced static.
+__attribute__((used))
+static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
+  const char *sep = "";  // empty before the first entry, "," afterwards
+  fputc('[', stream);
+  for (i64 dim = 0; dim < rank; dim++) {
+    fprintf(stream, "%s%" PRIi64, sep, shape[dim]);
+    sep = ",";
   }
+  fputc(']', stream);
+}
-#define PREFIX_BINOP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \
- } \
- static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \
- } \
- static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
- }
-#define UNARY_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
- }
+/*****************************************************************************
+ * Skeletons *
+ *****************************************************************************/
// Walk an orthotope-style strided array, except for the inner dimension. The
// body is run for every "inner vector".
@@ -184,6 +170,55 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
} \
} while (false)
+
+/*****************************************************************************
+ * Kernel functions *
+ *****************************************************************************/
+
+// Element-wise kernels for the infix operator `op` on element type `typ`:
+//   _sv: out[k] = x op y[k]      (scalar left operand, vector right operand)
+//   _vv: out[k] = x[k] op y[k]   (vector operands on both sides)
+// Each kernel processes exactly n elements. No _vs variant is generated here;
+// NONCOMM_OP adds one.
+#define COMM_OP(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = x op y[k]; \
+    } \
+  } \
+  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = x[k] op y[k]; \
+    } \
+  }
+
+// Everything COMM_OP generates (_sv and _vv), plus a _vs kernel computing
+// out[k] = x[k] op y for n elements. Presumably meant for operators where the
+// operand order matters, so the scalar cannot simply be passed on the left.
+#define NONCOMM_OP(name, op, typ) \
+  COMM_OP(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = x[k] op y; \
+    } \
+  }
+
+// Element-wise kernels for a binary operation written in call syntax,
+// op(a, b), on element type `typ`. Three variants over n elements:
+//   _sv: out[k] = op(x, y[k])
+//   _vv: out[k] = op(x[k], y[k])
+//   _vs: out[k] = op(x[k], y)
+#define PREFIX_BINOP(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = op(x, y[k]); \
+    } \
+  } \
+  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = op(x[k], y[k]); \
+    } \
+  } \
+  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = op(x[k], y); \
+    } \
+  }
+
+// Element-wise kernel for a unary operation on a contiguous vector of `typ`:
+// out[k] = op(x[k]) for k in [0, n). `op` is applied in call syntax, so it may
+// be a function-like macro or a prefix operator such as `-`.
+#define UNARY_OP(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
+    for (i64 k = 0; k < n; k++) { \
+      out[k] = op(x[k]); \
+    } \
+  }
+
+// Strided counterpart of UNARY_OP: applies `op` to every element of a strided
+// input array and stores the results in `out`.
+//   rank, shape, strides: layout of `arr`; strides are used directly as
+//     element offsets into `arr`
+//   out: each inner vector is written as a run of shape[rank-1] results
+//     starting at index outlinidx * shape[rank-1]
+// The walk over the outer dimensions is done by TARRAY_WALK_NOINNER_CASE1;
+// outlinidx and arrlinidx are not declared here, so they come from that macro.
+// The disabled debug output uses PRIi64 for the i64 rank, matching print_shape.
+#define UNARY_OP_STRIDED(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
+    /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%" PRIi64 " shape=", rank); \
+    print_shape(stderr, rank, shape); \
+    fprintf(stderr, " strides="); \
+    print_shape(stderr, rank, strides); \
+    fprintf(stderr, "\n"); */ \
+    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+      /* Invariant across the inner loop: read once per inner vector. */ \
+      const i64 inner_len = shape[rank - 1]; \
+      const i64 inner_stride = strides[rank - 1]; \
+      for (i64 i = 0; i < inner_len; i++) { \
+        out[outlinidx * inner_len + i] = op(arr[arrlinidx + inner_stride * i]); \
+      } \
+    }); \
+  }
+
// preconditions:
// - all strides are >0
// - shape is everywhere >0
@@ -408,6 +443,16 @@ enum unop_tag_t {
} \
}
+#define ENTRY_UNARY_STRIDED_OPS(typ) /* defines the non-static entry point oxarop_unary_<typ>_strided */ \
+ void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \
+ switch (tag) { /* select the strided kernel for this tag; all other arguments are forwarded unchanged */ \
+ case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case UO_SIGNUM: oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ default: wrong_op("unary_strided", tag); /* any other tag value is handed to wrong_op */ \
+ } \
+ }
+
enum funop_tag_t {
#undef LIST_FUNOP
#define LIST_FUNOP(name, id, _) name = id,
@@ -484,12 +529,16 @@ enum redop_tag_t {
UNARY_OP(neg, -, typ) \
UNARY_OP(abs, GEN_ABS, typ) \
UNARY_OP(signum, GEN_SIGNUM, typ) \
+ UNARY_OP_STRIDED(neg, -, typ) \
+ UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
+ UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
REDUCE1_OP(sum1, +, typ) \
REDUCE1_OP(product1, *, typ) \
REDUCEFULL_OP(sumfull, +, typ) \
REDUCEFULL_OP(productfull, *, typ) \
ENTRY_BINARY_OPS(typ) \
ENTRY_UNARY_OPS(typ) \
+ ENTRY_UNARY_STRIDED_OPS(typ) \
ENTRY_REDUCE1_OPS(typ) \
ENTRY_REDUCEFULL_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 734c7cd..123a4b5 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-- TODO: test all the cases of this thing with various input strides
+-- | Apply a unary elementwise operation to a strided array.
+--
+-- If 'stridesDense' reports that the array's elements form one contiguous
+-- block of the underlying vector, @f_vec@ is run on just that block and the
+-- result keeps the input strides, with the offset rebased to the block.
+--
+-- Otherwise @cf_strided@ is called with the rank, an output buffer of
+-- @product sh@ elements, the shape, the strides and a pointer to the input
+-- elements; the filled buffer is then wrapped with 'RS.fromVector'.
+{-# NOINLINE liftOpEltwise1 #-}
+liftOpEltwise1 :: (Storable a, Storable b)
+               => SNat n
+               -> (VS.Vector a -> VS.Vector b)
+               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ())
+               -> RS.Array n a -> RS.Array n b
+liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
+  | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+      let vec' = f_vec (VS.slice blockOff blockSz vec)
+      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
+  | otherwise = unsafePerformIO $ do
+      outv <- VSM.unsafeNew (product sh)
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
+            -- The strided kernel receives no offset argument and indexes
+            -- relative to the input pointer, so that pointer has to refer to
+            -- element 'offset' of 'vec' rather than to the start of the buffer.
+            VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
+              cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv
+      RS.fromVector sh <$> VS.unsafeFreeze outv
+
+-- TODO: test all the cases of this thing with various input strides
liftVEltwise2 :: (Storable a, Storable b, Storable c)
=> SNat n
-> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
@@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM floatTypesList $ \arithtype -> do
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ade7ce1..22c5b53 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -16,6 +16,7 @@ $(do
,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
,("unary_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 0530f53..97b425f 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -97,6 +97,17 @@ prop_sum_replicated doTranspose = property $
let rarr = rfromOrthotope inrank2 arrTrans
almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans)
+-- | For a freshly built orthotope array with an all-positive shape, negating
+-- its 'rfromOrthotope' conversion and converting back gives the same result
+-- as mapping 'negate' over the orthotope array directly.
+prop_negate_normalised :: Property
+prop_negate_normalised = property $
+  genRank $ \rank@(SNat @n) -> do
+    sh <- forAll $ genShR rank
+    let dims = toList sh
+    guard (all (> 0) dims)
+    -- Scale each generated Word64 by maxBound to obtain the element value.
+    let toUnit w = fromIntegral w / fromIntegral (maxBound :: Word64)
+        genElems = genStorables (Range.singleton (product sh)) toUnit
+    arr <- forAllT $ fmap (OR.fromVector @Double @n dims) genElems
+    let actual = rtoOrthotope (negate (rfromOrthotope rank arr))
+        expected = OR.mapA negate arr
+    actual === expected
+
tests :: TestTree
tests = testGroup "C"
[testGroup "sum"
@@ -106,4 +117,7 @@ tests = testGroup "C"
,testProperty "replicated" (prop_sum_replicated False)
,testProperty "replicated_transposed" (prop_sum_replicated True)
]
+ ,testGroup "negate"
+ [testProperty "normalised" prop_negate_normalised
+ ]
]