diff options
| -rw-r--r-- | cbits/arith.c | 103 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 | ||||
| -rw-r--r-- | test/Tests/C.hs | 14 | 
4 files changed, 113 insertions, 28 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index b9c86ab..f08e456 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -5,6 +5,7 @@  #include <stdio.h>  #include <stdint.h> +#include <inttypes.h>  #include <stdlib.h>  #include <stdbool.h>  #include <string.h> @@ -89,38 +90,23 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }  /***************************************************************************** - *                             Kernel functions                              * + *                             Helper functions                              *   *****************************************************************************/ -#define COMM_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ -    for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ -  } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ -    for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ -  } - -#define NONCOMM_OP(name, op, typ) \ -  COMM_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ -    for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ +__attribute__((used)) +static void print_shape(FILE *stream, i64 rank, const i64 *shape) { +  fputc('[', stream); +  for (i64 i = 0; i < rank; i++) { +    if (i != 0) fputc(',', stream); +    fprintf(stream, "%" PRIi64, shape[i]);    } +  fputc(']', stream); +} -#define PREFIX_BINOP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ -    for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \ -  } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ -    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \ -  } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ -    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \ -  } -#define UNARY_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ -    for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ -  } +/***************************************************************************** + *                                Skeletons                                  * + *****************************************************************************/  // Walk a orthotope-style strided array, except for the inner dimension. The  // body is run for every "inner vector". @@ -184,6 +170,55 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }      } \    } while (false) + +/***************************************************************************** + *                             Kernel functions                              * + *****************************************************************************/ + +#define COMM_OP(name, op, typ) \ +  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ +    for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ +  } \ +  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ +    for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ +  } + +#define NONCOMM_OP(name, op, typ) \ +  COMM_OP(name, op, typ) \ +  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ +    for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ +  } + +#define PREFIX_BINOP(name, op, typ) \ +  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ +    for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \ +  } \ +  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ +    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \ +  } \ +  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ +    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \ +  } + +#define UNARY_OP(name, op, typ) \ +  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ +    for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ +  } + +#define UNARY_OP_STRIDED(name, op, typ) \ +  static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ +    /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \ +    print_shape(stderr, rank, shape); \ +    fprintf(stderr, " strides="); \ +    print_shape(stderr, rank, strides); \ +    fprintf(stderr, "\n"); */ \ +    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +      for (i64 i = 0; i < shape[rank - 1]; i++) { \ +        out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \ +      } \ +    }); \ +  } +  // preconditions:  // - all strides are >0  // - shape is everywhere >0 @@ -408,6 +443,16 @@ enum unop_tag_t {      } \    } +#define ENTRY_UNARY_STRIDED_OPS(typ) \ +  void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ +    switch (tag) { \ +      case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \ +      case UO_SIGNUM: oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); break; \ +      default: wrong_op("unary_strided", tag); \ +    } \ +  } +  enum funop_tag_t {  #undef LIST_FUNOP  #define LIST_FUNOP(name, id, _) name = id, @@ -484,12 +529,16 @@ enum redop_tag_t {    UNARY_OP(neg, -, typ) \    UNARY_OP(abs, GEN_ABS, typ) \    UNARY_OP(signum, GEN_SIGNUM, typ) \ +  UNARY_OP_STRIDED(neg, -, typ) \ +  UNARY_OP_STRIDED(abs, GEN_ABS, typ) \ +  UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \    REDUCE1_OP(sum1, +, typ) \    REDUCE1_OP(product1, *, typ) \    REDUCEFULL_OP(sumfull, +, typ) \    REDUCEFULL_OP(productfull, *, typ) \    ENTRY_BINARY_OPS(typ) \    ENTRY_UNARY_OPS(typ) \ +  ENTRY_UNARY_STRIDED_OPS(typ) \    ENTRY_REDUCE1_OPS(typ) \    ENTRY_REDUCEFULL_OPS(typ) \    EXTREMUM_OP(min, <, typ) \ diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 734c7cd..123a4b5 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -49,6 +49,26 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))    | otherwise = RS.fromVector sh (f (RS.toVector arr))  -- TODO: test all the cases of this thing with various input strides +{-# NOINLINE liftOpEltwise1 #-} +liftOpEltwise1 :: (Storable a, Storable b) +               => SNat n +               -> (VS.Vector a -> VS.Vector b) +               -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr a -> IO ()) +               -> RS.Array n a -> RS.Array n b +liftOpEltwise1 sn@SNat f_vec cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f_vec (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +  | otherwise = unsafePerformIO $ do +      outv <- VSM.unsafeNew (product sh) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> +          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> +            VS.unsafeWith vec $ \pv -> +              cf_strided (fromIntegral (fromSNat sn)) poutv psh pstrides pv +      RS.fromVector sh <$> VS.unsafeFreeze outv + +-- TODO: test all the cases of this thing with various input strides  liftVEltwise2 :: (Storable a, Storable b, Storable c)                => SNat n                -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c) @@ -421,9 +441,10 @@ $(fmap concat . forM typesList $ \arithtype -> do      fmap concat . forM [minBound..maxBound] $ \arithop -> do        let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))            c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) +          c_op_strided = varE (mkName ("c_unary_" ++ atCName arithtype ++ "_strided")) `appE` litE (integerL (fromIntegral (auoEnum arithop)))        sequence [SigD name <$>                       [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +               ,do body <- [| \sn -> liftOpEltwise1 sn (vectorOp1 id $c_op) $c_op_strided |]                     return $ FunD name [Clause [] (NormalB body) []]])  $(fmap concat . forM floatTypesList $ \arithtype -> do diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ade7ce1..22c5b53 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -16,6 +16,7 @@ $(do          ,("binary_" ++ tyn ++ "_sv",       [t| CInt -> Int64 -> Ptr $ttyp ->     $ttyp -> Ptr $ttyp -> IO () |])          ,("binary_" ++ tyn ++ "_vs",       [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp ->     $ttyp -> IO () |])          ,("unary_" ++ tyn,                 [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +        ,("unary_" ++ tyn ++ "_strided",   [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reduce1_" ++ tyn,               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])          ,("reducefull_" ++ tyn,            [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])          ,("extremum_min_" ++ tyn,          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 0530f53..97b425f 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -97,6 +97,17 @@ prop_sum_replicated doTranspose = property $      let rarr = rfromOrthotope inrank2 arrTrans      almostEq 1e-8 (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arrTrans) +prop_negate_normalised :: Property +prop_negate_normalised = property $ +  genRank $ \rank@(SNat @n) -> do +    sh <- forAll $ genShR rank +    guard (all (> 0) (toList sh)) +    arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$> +             genStorables (Range.singleton (product sh)) +                          (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) +    let rarr = rfromOrthotope rank arr +    rtoOrthotope (negate rarr) === OR.mapA negate arr +  tests :: TestTree  tests = testGroup "C"    [testGroup "sum" @@ -106,4 +117,7 @@ tests = testGroup "C"      ,testProperty "replicated" (prop_sum_replicated False)      ,testProperty "replicated_transposed" (prop_sum_replicated True)      ] +  ,testGroup "negate" +    [testProperty "normalised" prop_negate_normalised +    ]    ] | 
