diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-18 22:32:16 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-18 22:32:16 +0100 | 
| commit | cb758277b3fa2d74551c45340b8ff0539713078c (patch) | |
| tree | 4adb951118b70613b49f638c539282a8d28da2f0 | |
| parent | 27c2823387b21e8ed801e4d8eeb0b3e5588a2920 (diff) | |
Arith statistics collection from C
| -rw-r--r-- | bench/Main.hs | 11 | ||||
| -rw-r--r-- | cbits/arith.c | 209 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 | 
3 files changed, 233 insertions, 1 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 7f1cbad..6e83270 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -3,6 +3,7 @@  {-# LANGUAGE TypeApplications #-}  module Main where +import Control.Exception (bracket)  import Data.Array.RankedS qualified as RS  import Data.Vector.Storable qualified as VS  import Numeric.LinearAlgebra qualified as LA @@ -11,6 +12,7 @@ import Test.Tasty.Bench  import Data.Array.Nested  import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2)  import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2) +import qualified Data.Array.Mixed.Internal.Arith as Arith  enableMisc :: Bool @@ -22,7 +24,14 @@ bgroupIf False = \name _ -> bgroup name []  main :: IO () -main = defaultMain +main = +  bracket (Arith.statisticsEnable False) +          (\() -> do Arith.statisticsEnable False +                     Arith.statisticsPrintAll) +          (\() -> main_tests) + +main_tests :: IO () +main_tests = defaultMain    [bgroup "Num"      [bench "sum(+) Double [1e6]" $        let n = 1_000_000 diff --git a/cbits/arith.c b/cbits/arith.c index 6380776..c3e34ad 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -3,8 +3,11 @@  #include <inttypes.h>  #include <stdlib.h>  #include <stdbool.h> +#include <stdatomic.h>  #include <string.h>  #include <math.h> +#include <threads.h> +#include <sys/time.h>  // These are the wrapper macros used in arith_lists.h. Preset them to empty to  // avoid having to touch macros unrelated to the particular operation set below. @@ -20,6 +23,210 @@  typedef int32_t i32;  typedef int64_t i64; + +/***************************************************************************** + *                          Performance statistics                           * + *****************************************************************************/ + +// Each block holds a buffer with variable-length messages. Each message starts +// with a tag byte; the respective sublists below give the fields after that tag +// byte. +// - 1: unary operation performance measurement +//   - u8: some identifier +//   - i32: input rank +//   - i64[rank]: input shape +//   - i64[rank]: input strides +//   - f64: seconds taken +// - 2: binary operation performance measurement +//   - u8: a stats_binary_id +//   - i32: input rank +//   - i64[rank]: input shape +//   - i64[rank]: input 1 strides +//   - i64[rank]: input 2 strides +//   - f64: seconds taken +// The 'prev' and 'cap' fields are set only once on creation of a block, and can +// thus be read without restrictions. The 'len' field is potentially mutated +// from different threads and must be handled with care. +struct stats_block { +  struct stats_block *prev;  // backwards linked list; NULL if first block +  size_t cap;  // bytes capacity of buffer in this block +  atomic_size_t len;  // bytes filled in this buffer +  uint8_t buf[];  // trailing VLA +}; + +enum stats_binary_id : uint8_t { +  sbi_dotprod = 1, +}; + +// Atomic because blocks may be allocated from different threads. +static _Atomic(struct stats_block*) stats_current = NULL; +static atomic_bool stats_enabled = false; + +void oxarrays_stats_enable(i32 yes) { atomic_store(&stats_enabled, yes == 1); } + +static uint8_t* stats_alloc(size_t nbytes) { +try_again: ; +  struct stats_block *block = atomic_load(&stats_current); +  size_t curlen = block != NULL ? atomic_load(&block->len) : 0; +  size_t curcap = block != NULL ? block->cap : 0; + +  if (block == NULL || curlen + nbytes > curcap) { +    const size_t newcap = stats_current == NULL ? 4096 : 2 * stats_current->cap; +    struct stats_block *new = malloc(sizeof(struct stats_block) + newcap); +    new->prev = stats_current; +    curcap = new->cap = newcap; +    curlen = new->len = 0; +    if (!atomic_compare_exchange_strong(&stats_current, &block, new)) { +      // Race condition, simply free this memory block and try again +      free(new); +      goto try_again; +    } +    block = new; +  } + +  // Try to update the 'len' field of the block we captured at the start of the +  // function. Note that it doesn't matter if someone else already allocated a +  // new block in the meantime; we're still accessing the same block here, which +  // may succeed or fail independently. +  while (!atomic_compare_exchange_strong(&block->len, &curlen, curlen + nbytes)) { +    // curlen was updated to the actual value. +    // If the block got full in the meantime, try again from the start +    if (curlen + nbytes > curcap) goto try_again; +  } + +  return block->buf + curlen; +} + +__attribute__((unused)) +static void stats_record_unary(uint8_t id, i32 rank, const i64 *shape, const i64 *strides, double secs) { +  if (!atomic_load(&stats_enabled)) return; +  uint8_t *buf = stats_alloc(1 + 1 + 4 + 2*rank*8 + 8); +  *buf = 1; buf += 1; +  *buf = id; buf += 1; +  *(i32*)buf = rank; buf += 4; +  memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; +  memcpy((i64*)buf, strides, rank * 8); buf += rank * 8; +  *(double*)buf = secs; +} + +__attribute__((unused)) +static void stats_record_binary(uint8_t id, i32 rank, const i64 *shape, const i64 *strides1, const i64 *strides2, double secs) { +  if (!atomic_load(&stats_enabled)) return; +  uint8_t *buf = stats_alloc(1 + 1 + 4 + 3*rank*8 + 8); +  *buf = 2; buf += 1; +  *buf = id; buf += 1; +  *(i32*)buf = rank; buf += 4; +  memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; +  memcpy((i64*)buf, strides1, rank * 8); buf += rank * 8; +  memcpy((i64*)buf, strides2, rank * 8); buf += rank * 8; +  *(double*)buf = secs; +} + +#define TIME_START(varname_) \ +  struct timeval varname_ ## _start, varname_ ## _end; \ +  gettimeofday(&varname_ ## _start, NULL); +#define TIME_END(varname_) \ +  (gettimeofday(&varname_ ## _end, NULL), \ +   ((varname_ ## _end).tv_sec - (varname_ ## _start).tv_sec) + \ +      ((varname_ ## _end).tv_usec - (varname_ ## _start).tv_usec) / (double)1e6) + +static size_t stats_print_unary(uint8_t *buf) { +  uint8_t *orig_buf = buf; + +  uint8_t id = *buf; buf += 1; +  i32 rank = *(i32*)buf; buf += 4; +  i64 *shape = (i64*)buf; buf += rank * 8; +  i64 *strides = (i64*)buf; buf += rank * 8; +  double secs = *(double*)buf; buf += 8; + +  printf("unary %d sh=[", (int)id); +  for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } +  printf("] str=["); +  for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides[i]); } +  printf("] ms=%lf\n", secs * 1000); + +  return buf - orig_buf; +} + +static size_t stats_print_binary(uint8_t *buf) { +  uint8_t *orig_buf = buf; + +  uint8_t id = *buf; buf += 1; +  i32 rank = *(i32*)buf; buf += 4; +  i64 *shape = (i64*)buf; buf += rank * 8; +  i64 *strides1 = (i64*)buf; buf += rank * 8; +  i64 *strides2 = (i64*)buf; buf += rank * 8; +  double secs = *(double*)buf; buf += 8; + +  printf("binary %d sh=[", (int)id); +  for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } +  printf("] str1=["); +  for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides1[i]); } +  printf("] str2=["); +  for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides2[i]); } +  printf("] ms=%lf\n", secs * 1000); + +  return buf - orig_buf; +} + +// Also frees the printed log. +void oxarrays_stats_print_all(void) { +  printf("=== ox-arrays-arith-stats start ===\n"); + +  // Claim the entire chain and prevent new blocks from being added to it. +  // (This is technically slightly wrong because a value may still be in the +  // process of being recorded to some blocks in the chain while we're doing +  // this printing, but yolo) +  struct stats_block *last = atomic_exchange(&stats_current, NULL); + +  // Reverse the linked list; after this loop, the 'prev' pointers point to the +  // _next_ block, not the previous one. +  struct stats_block *block = last; +  if (last != NULL) { +    struct stats_block *next = NULL; +    //                     block    next +    //   #####  <-#####  <-#####    NULL +    while (block->prev != NULL) { +      struct stats_block *prev = block->prev; +      //          prev     block    next +      // #####  <-#####  <-#####    ##... +      block->prev = next; +      //          prev     block    next +      // #####  <-#####    #####->  ##... +      next = block; +      //          prev     bl=nx +      // #####  <-#####    #####->  ##... +      block = prev; +      //          block    next +      // #####  <-#####    #####->  ##... +    } +    //            block    next +    //    NULL  <-#####    #####->  ##... +    block->prev = next; +    //            block    next +    //    NULL    #####->  #####->  ##... +  } + +  while (block != NULL) { +    for (size_t i = 0; i < block->len; ) { +      switch (block->buf[i]) { +        case 1: i += 1 + stats_print_unary(block->buf + i+1); break; +        case 2: i += 1 + stats_print_binary(block->buf + i+1); break; +        default: +          printf("# UNKNOWN ENTRY WITH ID %d, SKIPPING BLOCK\n", (int)block->buf[i]); +          i = block->len; +          break; +      } +    } +    struct stats_block *next = block->prev;  // remember, reversed! +    free(block); +    block = next; +  } + +  printf("=== ox-arrays-arith-stats end ===\n"); +} + +  /*****************************************************************************   *                         Additional math functions                         *   *****************************************************************************/ @@ -325,6 +532,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  // 'out' will be filled densely in linearisation order.  #define DOTPROD_INNER_OP(typ) \    void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ +    TIME_START(tm); \      if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \        TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \          out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \ @@ -339,6 +547,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {          out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], strides1[rank - 1], arr1 + arrlinidx1, strides2[rank - 1], arr2 + arrlinidx2); \        }); \      } \ +    stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \    } diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9c560d6..27ebb64 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -27,6 +27,7 @@ import Foreign.Storable (Storable(sizeOf), peek, poke)  import GHC.TypeLits  import GHC.TypeNats qualified as TypeNats  import Language.Haskell.TH +import System.IO (hFlush, stdout)  import System.IO.Unsafe  import Data.Array.Mixed.Internal.Arith.Foreign @@ -603,6 +604,19 @@ $(fmap concat . forM typesList $ \arithtype -> do               ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |]                   return $ FunD name [Clause [] (NormalB body) []]]) +foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO () +foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO () + +statisticsEnable :: Bool -> IO () +statisticsEnable b = c_stats_enable (if b then 1 else 0) + +-- | Consumes the log: one particular event will only ever be printed once, +-- even if statisticsPrintAll is called multiple times. +statisticsPrintAll :: IO () +statisticsPrintAll = do +  hFlush stdout  -- lower the chance of overlapping output +  c_stats_print_all +  -- This branch is ostensibly a runtime branch, but will (hopefully) be  -- constant-folded away by GHC.  intWidBranch1 :: forall i n. (FiniteBits i, Storable i)  | 
