diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-18 22:32:16 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-18 22:32:16 +0100 |
commit | cb758277b3fa2d74551c45340b8ff0539713078c (patch) | |
tree | 4adb951118b70613b49f638c539282a8d28da2f0 | |
parent | 27c2823387b21e8ed801e4d8eeb0b3e5588a2920 (diff) |
Arith statistics collection from C
-rw-r--r-- | bench/Main.hs | 11 | ||||
-rw-r--r-- | cbits/arith.c | 209 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 14 |
3 files changed, 233 insertions, 1 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 7f1cbad..6e83270 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -3,6 +3,7 @@ {-# LANGUAGE TypeApplications #-} module Main where +import Control.Exception (bracket) import Data.Array.RankedS qualified as RS import Data.Vector.Storable qualified as VS import Numeric.LinearAlgebra qualified as LA @@ -11,6 +12,7 @@ import Test.Tasty.Bench import Data.Array.Nested import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2) import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2) +import qualified Data.Array.Mixed.Internal.Arith as Arith enableMisc :: Bool @@ -22,7 +24,14 @@ bgroupIf False = \name _ -> bgroup name [] main :: IO () -main = defaultMain +main = + bracket (Arith.statisticsEnable False) + (\() -> do Arith.statisticsEnable False + Arith.statisticsPrintAll) + (\() -> main_tests) + +main_tests :: IO () +main_tests = defaultMain [bgroup "Num" [bench "sum(+) Double [1e6]" $ let n = 1_000_000 diff --git a/cbits/arith.c b/cbits/arith.c index 6380776..c3e34ad 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -3,8 +3,11 @@ #include <inttypes.h> #include <stdlib.h> #include <stdbool.h> +#include <stdatomic.h> #include <string.h> #include <math.h> +#include <threads.h> +#include <sys/time.h> // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. @@ -20,6 +23,210 @@ typedef int32_t i32; typedef int64_t i64; + +/***************************************************************************** + * Performance statistics * + *****************************************************************************/ + +// Each block holds a buffer with variable-length messages. Each message starts +// with a tag byte; the respective sublists below give the fields after that tag +// byte. +// - 1: unary operation performance measurement +// - u8: some identifier +// - i32: input rank +// - i64[rank]: input shape +// - i64[rank]: input strides +// - f64: seconds taken +// - 2: binary operation performance measurement +// - u8: a stats_binary_id +// - i32: input rank +// - i64[rank]: input shape +// - i64[rank]: input 1 strides +// - i64[rank]: input 2 strides +// - f64: seconds taken +// The 'prev' and 'cap' fields are set only once on creation of a block, and can +// thus be read without restrictions. The 'len' field is potentially mutated +// from different threads and must be handled with care. +struct stats_block { + struct stats_block *prev; // backwards linked list; NULL if first block + size_t cap; // bytes capacity of buffer in this block + atomic_size_t len; // bytes filled in this buffer + uint8_t buf[]; // trailing VLA +}; + +enum stats_binary_id : uint8_t { + sbi_dotprod = 1, +}; + +// Atomic because blocks may be allocated from different threads. +static _Atomic(struct stats_block*) stats_current = NULL; +static atomic_bool stats_enabled = false; + +void oxarrays_stats_enable(i32 yes) { atomic_store(&stats_enabled, yes == 1); } + +static uint8_t* stats_alloc(size_t nbytes) { +try_again: ; + struct stats_block *block = atomic_load(&stats_current); + size_t curlen = block != NULL ? atomic_load(&block->len) : 0; + size_t curcap = block != NULL ? block->cap : 0; + + if (block == NULL || curlen + nbytes > curcap) { + const size_t newcap = stats_current == NULL ? 4096 : 2 * stats_current->cap; + struct stats_block *new = malloc(sizeof(struct stats_block) + newcap); + new->prev = stats_current; + curcap = new->cap = newcap; + curlen = new->len = 0; + if (!atomic_compare_exchange_strong(&stats_current, &block, new)) { + // Race condition, simply free this memory block and try again + free(new); + goto try_again; + } + block = new; + } + + // Try to update the 'len' field of the block we captured at the start of the + // function. Note that it doesn't matter if someone else already allocated a + // new block in the meantime; we're still accessing the same block here, which + // may succeed or fail independently. + while (!atomic_compare_exchange_strong(&block->len, &curlen, curlen + nbytes)) { + // curlen was updated to the actual value. + // If the block got full in the meantime, try again from the start + if (curlen + nbytes > curcap) goto try_again; + } + + return block->buf + curlen; +} + +__attribute__((unused)) +static void stats_record_unary(uint8_t id, i32 rank, const i64 *shape, const i64 *strides, double secs) { + if (!atomic_load(&stats_enabled)) return; + uint8_t *buf = stats_alloc(1 + 1 + 4 + 2*rank*8 + 8); + *buf = 1; buf += 1; + *buf = id; buf += 1; + *(i32*)buf = rank; buf += 4; + memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides, rank * 8); buf += rank * 8; + *(double*)buf = secs; +} + +__attribute__((unused)) +static void stats_record_binary(uint8_t id, i32 rank, const i64 *shape, const i64 *strides1, const i64 *strides2, double secs) { + if (!atomic_load(&stats_enabled)) return; + uint8_t *buf = stats_alloc(1 + 1 + 4 + 3*rank*8 + 8); + *buf = 2; buf += 1; + *buf = id; buf += 1; + *(i32*)buf = rank; buf += 4; + memcpy((i64*)buf, shape, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides1, rank * 8); buf += rank * 8; + memcpy((i64*)buf, strides2, rank * 8); buf += rank * 8; + *(double*)buf = secs; +} + +#define TIME_START(varname_) \ + struct timeval varname_ ## _start, varname_ ## _end; \ + gettimeofday(&varname_ ## _start, NULL); +#define TIME_END(varname_) \ + (gettimeofday(&varname_ ## _end, NULL), \ + ((varname_ ## _end).tv_sec - (varname_ ## _start).tv_sec) + \ + ((varname_ ## _end).tv_usec - (varname_ ## _start).tv_usec) / (double)1e6) + +static size_t stats_print_unary(uint8_t *buf) { + uint8_t *orig_buf = buf; + + uint8_t id = *buf; buf += 1; + i32 rank = *(i32*)buf; buf += 4; + i64 *shape = (i64*)buf; buf += rank * 8; + i64 *strides = (i64*)buf; buf += rank * 8; + double secs = *(double*)buf; buf += 8; + + printf("unary %d sh=[", (int)id); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } + printf("] str=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides[i]); } + printf("] ms=%lf\n", secs * 1000); + + return buf - orig_buf; +} + +static size_t stats_print_binary(uint8_t *buf) { + uint8_t *orig_buf = buf; + + uint8_t id = *buf; buf += 1; + i32 rank = *(i32*)buf; buf += 4; + i64 *shape = (i64*)buf; buf += rank * 8; + i64 *strides1 = (i64*)buf; buf += rank * 8; + i64 *strides2 = (i64*)buf; buf += rank * 8; + double secs = *(double*)buf; buf += 8; + + printf("binary %d sh=[", (int)id); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, shape[i]); } + printf("] str1=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides1[i]); } + printf("] str2=["); + for (i32 i = 0; i < rank; i++) { if (i > 0) putchar(','); printf("%" PRIi64, strides2[i]); } + printf("] ms=%lf\n", secs * 1000); + + return buf - orig_buf; +} + +// Also frees the printed log. +void oxarrays_stats_print_all(void) { + printf("=== ox-arrays-arith-stats start ===\n"); + + // Claim the entire chain and prevent new blocks from being added to it. + // (This is technically slightly wrong because a value may still be in the + // process of being recorded to some blocks in the chain while we're doing + // this printing, but yolo) + struct stats_block *last = atomic_exchange(&stats_current, NULL); + + // Reverse the linked list; after this loop, the 'prev' pointers point to the + // _next_ block, not the previous one. + struct stats_block *block = last; + if (last != NULL) { + struct stats_block *next = NULL; + // block next + // ##### <-##### <-##### NULL + while (block->prev != NULL) { + struct stats_block *prev = block->prev; + // prev block next + // ##### <-##### <-##### ##... + block->prev = next; + // prev block next + // ##### <-##### #####-> ##... + next = block; + // prev bl=nx + // ##### <-##### #####-> ##... + block = prev; + // block next + // ##### <-##### #####-> ##... + } + // block next + // NULL <-##### #####-> ##... + block->prev = next; + // block next + // NULL #####-> #####-> ##... + } + + while (block != NULL) { + for (size_t i = 0; i < block->len; ) { + switch (block->buf[i]) { + case 1: i += 1 + stats_print_unary(block->buf + i+1); break; + case 2: i += 1 + stats_print_binary(block->buf + i+1); break; + default: + printf("# UNKNOWN ENTRY WITH ID %d, SKIPPING BLOCK\n", (int)block->buf[i]); + i = block->len; + break; + } + } + struct stats_block *next = block->prev; // remember, reversed! + free(block); + block = next; + } + + printf("=== ox-arrays-arith-stats end ===\n"); +} + + /***************************************************************************** * Additional math functions * *****************************************************************************/ @@ -325,6 +532,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // 'out' will be filled densely in linearisation order. #define DOTPROD_INNER_OP(typ) \ void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ + TIME_START(tm); \ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \ @@ -339,6 +547,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], strides1[rank - 1], arr1 + arrlinidx1, strides2[rank - 1], arr2 + arrlinidx2); \ }); \ } \ + stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \ } diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 9c560d6..27ebb64 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -27,6 +27,7 @@ import Foreign.Storable (Storable(sizeOf), peek, poke) import GHC.TypeLits import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH +import System.IO (hFlush, stdout) import System.IO.Unsafe import Data.Array.Mixed.Internal.Arith.Foreign @@ -603,6 +604,19 @@ $(fmap concat . forM typesList $ \arithtype -> do ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op (scaleFromSVStrided $c_scale_op) $c_red_op $c_op |] return $ FunD name [Clause [] (NormalB body) []]]) +foreign import ccall unsafe "oxarrays_stats_enable" c_stats_enable :: Int32 -> IO () +foreign import ccall unsafe "oxarrays_stats_print_all" c_stats_print_all :: IO () + +statisticsEnable :: Bool -> IO () +statisticsEnable b = c_stats_enable (if b then 1 else 0) + +-- | Consumes the log: one particular event will only ever be printed once, +-- even if statisticsPrintAll is called multiple times. +statisticsPrintAll :: IO () +statisticsPrintAll = do + hFlush stdout -- lower the chance of overlapping output + c_stats_print_all + -- This branch is ostensibly a runtime branch, but will (hopefully) be -- constant-folded away by GHC. intWidBranch1 :: forall i n. (FiniteBits i, Storable i) |