diff options
| -rw-r--r-- | cbits/arith.c | 93 | 
1 files changed, 35 insertions, 58 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index 5d74c01..4d60228 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -145,6 +145,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }      } \    } while (false) +// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on +// strides[rank-1] == 1 and a fallback case. +#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \ +  do { \ +    if (strides[rank - 1] == 1) { \ +      TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \ +    } else { \ +      TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ +    } \ +  } while (false) +  // preconditions:  // - all strides are >0  // - shape is everywhere >0 @@ -154,23 +165,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }  // 'out' will be filled densely in linearisation order.  #define REDUCE1_OP(name, op, typ) \    static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ -    if (strides[rank - 1] == 1) { \ -      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ -        typ accum = arr[arrlinidx]; \ -        for (i64 i = 1; i < shape[rank - 1]; i++) { \ -          accum = accum op arr[arrlinidx + i]; \ -        } \ -        out[outlinidx] = accum; \ -      }); \ -    } else { \ -      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ -        typ accum = arr[arrlinidx]; \ -        for (i64 i = 1; i < shape[rank - 1]; i++) { \ -          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ -        } \ -        out[outlinidx] = accum; \ -      }); \ -    } \ +    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +      typ accum = arr[arrlinidx]; \ +      for (i64 i = 1; i < shape[rank - 1]; i++) { \ +        accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ +      } \ +      out[outlinidx] = accum; \ +    }); \    }  // preconditions @@ -180,23 +181,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }  #define REDUCEFULL_OP(name, op, typ) \    typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \      typ res = 0; \ -    if (strides[rank - 1] == 1) { \ -      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ -        typ accum = arr[arrlinidx]; \ -        for (i64 i = 1; i < shape[rank - 1]; i++) { \ -          accum = accum op arr[arrlinidx + i]; \ -        } \ -        res = res op accum; \ -      }); \ -    } else { \ -      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ -        typ accum = arr[arrlinidx]; \ -        for (i64 i = 1; i < shape[rank - 1]; i++) { \ -          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ -        } \ -        res = res op accum; \ -      }); \ -    } \ +    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +      typ accum = arr[arrlinidx]; \ +      for (i64 i = 1; i < shape[rank - 1]; i++) { \ +        accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ +      } \ +      res = res op accum; \ +    }); \      return res; \    } @@ -209,31 +200,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }    void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \      typ best = arr[0]; \      memset(outidx, 0, rank * sizeof(i64)); \ -    if (strides[rank - 1] == 1) { \ -      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ -        bool found = false; \ -        for (i64 i = 0; i < shape[rank - 1]; i++) { \ -          if (arr[arrlinidx + i] cmp best) { \ -            best = arr[arrlinidx + i]; \ -            found = true; \ -            outidx[rank - 1] = i; \ -          } \ -        } \ -        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ -      }); \ -    } else { \ -      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ -        bool found = false; \ -        for (i64 i = 0; i < shape[rank - 1]; i++) { \ -          if (arr[arrlinidx + i] cmp best) { \ -            best = arr[arrlinidx + strides[rank - 1] * i]; \ -            found = true; \ -            outidx[rank - 1] = i; \ -          } \ +    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +      bool found = false; \ +      for (i64 i = 0; i < shape[rank - 1]; i++) { \ +        if (arr[arrlinidx + i] cmp best) { \ +          best = arr[arrlinidx + strides[rank - 1] * i]; \ +          found = true; \ +          outidx[rank - 1] = i; \          } \ -        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ -      }); \ -    } \ +      } \ +      if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ +    }); \    }  #define DOTPROD_OP(typ) \ | 
