From 831af1d49c9bb7d17794d259c99f92b2513496c5 Mon Sep 17 00:00:00 2001
From: Tom Smeding
Date: Sun, 16 May 2021 19:13:05 +0200
Subject: server: WIP utf8 validation implementation

---
 test/hashtable.c      |   5 +-
 test/main.c           |  32 +++++++-
 test/test_framework.c |  20 +++++
 test/test_framework.h |   4 +
 test/utf8.c           | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 276 insertions(+), 4 deletions(-)
 create mode 100644 test/utf8.c

(limited to 'test')

diff --git a/test/hashtable.c b/test/hashtable.c
index ea6217c..a147d61 100644
--- a/test/hashtable.c
+++ b/test/hashtable.c
@@ -2,12 +2,11 @@
 #include
 #include
 #include "test_framework.h"
-#include "../global.h"
 #include "../hashtable.h"
 
 
-static const int NUM_RUNS = 1000;
-static const int NUM_OPERATIONS_PER_RUN = 100000;
+static const int NUM_RUNS = 100;
+static const int NUM_OPERATIONS_PER_RUN = 10000;
 
 
 struct pair {
diff --git a/test/main.c b/test/main.c
index 5c0656c..8c2fffe 100644
--- a/test/main.c
+++ b/test/main.c
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
 #include "test_framework.h"
@@ -13,7 +14,7 @@ atomic_flag test_framework_assertion_failed = ATOMIC_FLAG_INIT;
 
 
 static void report(const char *clr, const char *label, const char *name, clock_t taken) {
-	printf("\x1B[%sm[%s] %s (%lfs)\n",
+	printf("\x1B[%sm[%s] %s (%lfs)\x1B[0m\n",  // trailing ESC[0m resets the colour selected by the leading ESC[<clr>m
 			clr, label, name, (double)taken / CLOCKS_PER_SEC);
 }
 
@@ -41,10 +42,39 @@ static void report_failure(const char *name, clock_t taken) {
 static int run_tests(void) {
 	RUN_TEST(hashtable_unit1);
 	RUN_TEST(hashtable);
+	RUN_TEST(utf8_unit1);
+	RUN_TEST(utf8_random);
+	RUN_TEST(utf8_random_valid);
+	RUN_TEST(utf8_exhaustive_1);
 	return 0;
 }
 
+static unsigned int random_seed_from_device(void) {  // Reads sizeof(unsigned int) bytes from /dev/urandom and returns them as a seed; exits with status 1 on failure.
+	FILE *f = fopen("/dev/urandom", "r");
+	if (!f) {
+		fprintf(stderr, "Cannot open /dev/urandom\n");
+		exit(1);
+	}
+	unsigned int seed;
+	int nread = fread(&seed, 1, sizeof seed, f);  // element size is 1, so the result counts bytes; NOTE(review): fread returns size_t, narrowed to int here
+	if (nread < (int)sizeof seed) {  // a short read leaves seed partly uninitialised, so treat it as fatal
+		fprintf(stderr, "Cannot read from /dev/urandom\n");
+		exit(1);
+	}
+	fclose(f);
+	return seed;
+}
 
 int main(void) {
+	unsigned int seed;
+	const char *seed_env_var = getenv("SEED");  // a non-empty SEED environment variable overrides the device-provided seed
+	if (seed_env_var && seed_env_var[0]) {
+		seed = strtol(seed_env_var, NULL, 10);  // NOTE(review): no endptr/errno check, and a signed long is converted to unsigned; strtoul with validation may be intended
+	} else {
+		seed = random_seed_from_device();
+	}
+	fprintf(stderr, "seed = %u\n", seed);  // presumably printed so that a run can be repeated via SEED
+	srandom(seed);  // seeds the generator behind random()
+
 	return run_tests();
 }
diff --git a/test/test_framework.c b/test/test_framework.c
index 8cd53ac..d566944 100644
--- a/test/test_framework.c
+++ b/test/test_framework.c
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 #include "test_framework.h"
 
 
@@ -12,3 +13,22 @@ void test_report_error(
 			fname, lineno, type, condition);
 	atomic_flag_test_and_set(&test_framework_assertion_failed);
 }
+
+void print_buffer(FILE *stream, const void *buffer_, size_t length) {  // Writes length bytes of buffer_ to stream as a double-quoted, escaped string followed by "[length]".
+	const uint8_t *buffer = (const uint8_t*)buffer_;  // view the data as unsigned bytes so the comparisons and shifts below are well-defined
+	fputc('"', stream);
+
+	for (size_t i = 0; i < length; i++) {
+		const uint8_t b = buffer[i];
+		if (b == '"') fprintf(stream, "\\\"");  // double quote is written as \"
+		else if (b == '\\') fprintf(stream, "\\\\");  // backslash is written as two backslashes
+		else if (32 <= b && b < 127) fputc(b, stream);  // remaining printable ASCII is written verbatim
+		else {
+			fprintf(stream, "\\x%c%c",  // everything else as \x plus two lowercase hex digits; NOTE(review): not paste-safe as a C literal, a following hex-digit character would be absorbed into the escape
+					"0123456789abcdef"[b >> 4],
+					"0123456789abcdef"[b & 0xf]);
+		}
+	}
+
+	fprintf(stream, "\"[%zu]", length);  // closing quote, then the byte count in brackets
+}
diff --git a/test/test_framework.h b/test/test_framework.h
index 327de00..3323190 100644
--- a/test/test_framework.h
+++ b/test/test_framework.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 
 #define TEST(name) testfn__ ## name
 
@@ -23,3 +25,5 @@ void test_report_error(
 			return (ret_); \
 		} \
 	} while (0)
+
+void print_buffer(FILE *stream, const void *buffer, size_t length);  // writes the bytes as a quoted, escaped string followed by "[length]"; defined in test_framework.c
diff --git a/test/utf8.c b/test/utf8.c
new file mode 100644
index 0000000..afc7383
--- /dev/null
+++ b/test/utf8.c
@@ -0,0 +1,219 @@
+#include
+#include
+#include
+#include "test_framework.h"
+#include "../global.h"
+#include "../utf8.h"
+
+
+// Returns the number of bytes in the utf8 unit, or -1 on invalid input.
+// If the parse is successful, puts the bits that are part of the unit being
+// parsed in *unit.
+static int parse_utf8_prefix_byte(uint8_t b, int64_t *unit) {
+	// Try the lead-byte patterns 0xxxxxxx, 110xxxxx, 1110xxxx and 11110xxx in
+	// turn; the payload bits (the x's) become the initial value of *unit.
+	if ((b & 0b10000000) == 0b00000000) {*unit = b & 0b01111111; return 1;}
+	if ((b & 0b11100000) == 0b11000000) {*unit = b & 0b00011111; return 2;}
+	if ((b & 0b11110000) == 0b11100000) {*unit = b & 0b00001111; return 3;}
+	if ((b & 0b11111000) == 0b11110000) {*unit = b & 0b00000111; return 4;}
+	// Anything else (10xxxxxx or 11111xxx) cannot start a sequence; *unit is
+	// left untouched in that case.
+	return -1;
+}
+
+// Reference parser for a single utf8 unit at the start of buf[0..length).
+// Returns the length of the unit in bytes (1 to 4) and puts the parsed value
+// in *unitp. Returns -1, leaving *unitp untouched, if the buffer is empty, the
+// prefix byte is invalid, the unit is cut off by the end of the buffer, a
+// continuation byte is invalid, or the encoding is overlong. If debug is true,
+// the reason for a rejection is printed to stderr.
+// No range checking (surrogates, values above 0x10FFFF) is done here.
+static int parse_utf8_unit(const uint8_t *buf, size_t length, int64_t *unitp, bool debug) {
+	if (length == 0) {
+		if (debug) fprintf(stderr, "[utf8ref] unit at EOS\n");
+		return -1;
+	}
+
+	int64_t unit;
+	const int num_bytes = parse_utf8_prefix_byte(buf[0], &unit);
+	assert(num_bytes == -1 || (1 <= num_bytes && num_bytes <= 4));
+	if (num_bytes == -1) {
+		if (debug) fprintf(stderr, "[utf8ref] invalid prefix byte %x\n", (unsigned)buf[0]);
+		return -1;
+	}
+	assert(unit >= 0);
+	// Check the available length before touching any continuation byte, so
+	// that the loop below never reads past the end of the buffer.
+	if (length < (size_t)num_bytes) {
+		if (debug) fprintf(stderr, "[utf8ref] prefix byte %x specifies length %d, but EOS\n", (unsigned)buf[0], num_bytes);
+		return -1;
+	}
+
+	// Every continuation byte must look like 10xxxxxx and contributes its low
+	// six bits to the value.
+	for (int i = 1; i < num_bytes; i++) {
+		if ((buf[i] & 0b11000000) != 0b10000000) {
+			if (debug) fprintf(stderr, "[utf8ref] invalid continuation byte %x\n", (unsigned)buf[i]);
+			return -1;
+		}
+		unit = (unit << 6) | (buf[i] & 0b00111111);
+	}
+
+	// check for overlong encodings: the value must not have fit in a shorter
+	// sequence (1 byte holds up to 0x7F, 2 bytes up to 0x7FF, 3 up to 0xFFFF)
+	if ((num_bytes >= 2 && unit <= 0x7F) ||
+			(num_bytes >= 3 && unit <= 0x7FF) ||
+			(num_bytes >= 4 && unit <= 0xFFFF)) {
+		if (debug) fprintf(stderr, "[utf8ref] overlong encoding with prefix byte %x\n", (unsigned)buf[0]);
+		return -1;
+	}
+
+	*unitp = unit;
+	return num_bytes;
+}
+
+// Straightforward reference validator that the implementation under test
+// (validate_utf8) is compared against. Returns true iff buf_[0..length)
+// consists entirely of units accepted by parse_utf8_unit that are neither
+// surrogate code points nor larger than 0x10FFFF. An empty buffer is valid.
+// If debug is true, the reason for a rejection is printed to stderr.
+static bool validate_utf8_reference(const char *buf_, size_t length, bool debug) {
+	const uint8_t *buf = (const uint8_t*)buf_;
+
+	size_t cursor = 0;
+	while (cursor < length) {
+		int64_t unit;
+		int len = parse_utf8_unit(buf + cursor, length - cursor, &unit, debug);
+		assert(len == -1 || (1 <= len && len <= 4));
+		if (len == -1) return false;
+		assert(unit >= 0);
+
+		// The casts to unsigned long long below make the argument match %llx
+		// on every platform; int64_t is not necessarily 'long'.
+
+		// Surrogate code point
+		if (0xD800 <= unit && unit <= 0xDFFF) {
+			if (debug) fprintf(stderr, "[utf8ref] surrogate code point %llx (prefix byte %x)\n", (unsigned long long)unit, (unsigned)buf[cursor]);
+			return false;
+		}
+		// Maximal unicode value
+		if (unit > 0x10FFFF) {
+			if (debug) fprintf(stderr, "[utf8ref] out of range code point %llx (prefix byte %x)\n", (unsigned long long)unit, (unsigned)buf[cursor]);
+			return false;
+		}
+
+		cursor += len;
+	}
+
+	return true;
+}
+
+// Writes the utf8 encoding of unit to buf_, using the shortest form.
+// Requires that the buffer has space for at least 4 bytes.
+// Returns the number of bytes written (1 to 4). A unit above 0x10FFFF trips an
+// assertion; if assertions are compiled out, -1 is returned and nothing is
+// written. Surrogates are not rejected here; callers filter them out.
+static int utf8_serialise(char *buf_, int64_t unit) {
+	uint8_t *buf = (uint8_t*)buf_;
+
+// Stores the low six bits of unit as a continuation byte at index idx_ and
+// shifts them out, so bytes must be placed from last to first.
+#define PLACE_CONTINUATION_BYTE(idx_) \
+	do {buf[(idx_)] = 0x80 | (unit & 0x3F); unit >>= 6;} while (0)
+
+	if (unit <= 0x7F) {
+		buf[0] = unit;
+		return 1;
+	}
+	if (unit <= 0x7FF) {
+		PLACE_CONTINUATION_BYTE(1);
+		buf[0] = 0xC0 | (unit & 0x1F);
+		return 2;
+	}
+	if (unit <= 0xFFFF) {
+		PLACE_CONTINUATION_BYTE(2);
+		PLACE_CONTINUATION_BYTE(1);
+		buf[0] = 0xE0 | (unit & 0x0F);
+		return 3;
+	}
+	if (unit <= 0x10FFFF) {
+		PLACE_CONTINUATION_BYTE(3);
+		PLACE_CONTINUATION_BYTE(2);
+		PLACE_CONTINUATION_BYTE(1);
+		buf[0] = 0xF0 | (unit & 0x07);
+		return 4;
+	}
+	assert(false && "Invalid unit in utf8_serialise");
+	// Only reached when assertions are disabled; do not fall off the end of a
+	// non-void function.
+	return -1;
+
+#undef PLACE_CONTINUATION_BYTE
+}
+
+// Fills buf[0..length) with pseudo-random bytes drawn from random().
+// random() yields only 31 bits per call, so storing a whole 'long' per call
+// (as an earlier version did) leaves the upper bytes of every word zero.
+// Instead, take three full bytes out of each call and discard the rest.
+static void fill_random_buffer(char *buf, size_t length) {
+	size_t i = 0;
+	while (i < length) {
+		unsigned long bits = (unsigned long)random();
+		for (int k = 0; k < 3 && i < length; k++) {
+			buf[i++] = (char)(bits & 0xFF);
+			bits >>= 8;
+		}
+	}
+}
+
+// Hand-picked inputs, each checked against both the implementation and the
+// reference validator.
+DEFINE_TEST(utf8_unit1) {
+	EXPECT(validate_utf8("hello", 5));
+	EXPECT(validate_utf8_reference("hello", 5, true));
+	const char *str = "hello 🧀🇳🇱";
+	EXPECT(validate_utf8(str, strlen(str)));
+	EXPECT(validate_utf8_reference(str, strlen(str), true));
+	EXPECT(validate_utf8("\xe0\xad\xbc`j", 5));
+	EXPECT(validate_utf8_reference("\xe0\xad\xbc`j", 5, true));
+	// The literal is split after the last NUL on purpose: a hex escape absorbs
+	// every following hex digit, so "\x001" would be the single byte 0x01
+	// rather than a NUL followed by '1'.
+	EXPECT(validate_utf8("\xd3\xb0\\i\x00\x00\x00\x00" "1\xc7\xaa_", 12));
+	EXPECT(validate_utf8("\xc7\xaa_", 3));
+	EXPECT(validate_utf8_reference("\xd3\xb0\\i\x00\x00\x00\x00" "1\xc7\xaa_", 12, true));
+	EXPECT(!validate_utf8("\xf2\x98\xbcx", 4));
+	EXPECT(!validate_utf8_reference("\xf2\x98\xbcx", 4, false));
+	return 0;
+}
+
+// Differential test on random byte strings: the implementation must agree with
+// the reference validator on every buffer. On a mismatch the offending buffer
+// and both verdicts are printed before the test fails.
+DEFINE_TEST(utf8_random) {
+	const int max_length = 100;
+	const int num_tests = 10000000;
+
+	// NOTE(review): the two-argument malloc is not the standard function;
+	// presumably a (count, type) allocation macro from ../global.h - confirm.
+	char *buffer = malloc(max_length + 1, char);
+	for (int test = 0; test < num_tests; test++) {
+		const int length = random() % max_length;
+		fill_random_buffer(buffer, length);
+		const bool ret_ref = validate_utf8_reference(buffer, length, false);
+		const bool ret_impl = validate_utf8(buffer, length);
+		if (ret_ref != ret_impl) {
+			fprintf(stderr, "buffer: ");
+			print_buffer(stderr, buffer, length);
+			fprintf(stderr, "\n");
+			fprintf(stderr, "reference -> %d, implementation -> %d\n", ret_ref, ret_impl);
+			// Release the buffer first: EXPECTRET with a false condition is
+			// expected to return from the test right here.
+			free(buffer);
+			EXPECTRET(1, false && "validate_utf8_reference == validate_utf8");
+		}
+	}
+	free(buffer);
+
+	return 0;
+}
+
+// Generates strings that are valid by construction (random non-surrogate code
+// points up to 0x10FFFF, serialised with utf8_serialise) and checks that both
+// the reference and the implementation accept them.
+DEFINE_TEST(utf8_random_valid) {
+	const int max_length = 100;
+	const int num_tests = 3000000;
+
+	// NOTE(review): see utf8_random about the two-argument malloc.
+	char *buffer = malloc(max_length + 1, char);
+	for (int test = 0; test < num_tests; test++) {
+		int length = random() % max_length;
+
+		// Keep appending while a maximal (4-byte) unit is guaranteed to fit.
+		int cursor = 0;
+		while (cursor + 4 <= length) {
+			const int64_t unit = random() % 0x110000;
+			if (0xD800 <= unit && unit <= 0xDFFF) continue;  // surrogate
+			cursor += utf8_serialise(buffer + cursor, unit);
+		}
+		// Only the part that was actually written is validated.
+		length = cursor;
+
+		const bool ret_ref = validate_utf8_reference(buffer, length, true);
+		if (!ret_ref) {
+			fprintf(stderr, "buffer: ");
+			print_buffer(stderr, buffer, length);
+			fprintf(stderr, "\n");
+			free(buffer);  // EXPECTRET below is expected to return
+			EXPECTRET(1, false && "validate_utf8_reference on valid string");
+		}
+
+		const bool ret_impl = validate_utf8(buffer, length);
+		if (!ret_impl) {
+			fprintf(stderr, "buffer: ");
+			print_buffer(stderr, buffer, length);
+			fprintf(stderr, "\n");
+			free(buffer);  // EXPECTRET below is expected to return
+			EXPECTRET(1, false && "validate_utf8 on valid string");
+		}
+	}
+	free(buffer);
+
+	return 0;
+}
+
+// Compares implementation and reference on every possible 4-byte string.
+DEFINE_TEST(utf8_exhaustive_1) {
+	for (int64_t number = 0; number < 0x100000000LL; number++) {
+		// Spell the four bytes out explicitly instead of reading the first
+		// four bytes of 'number' in memory: on a big-endian machine those are
+		// always zero, which would make the test vacuous.
+		const uint8_t bytes[4] = {
+			(uint8_t)(number & 0xFF),
+			(uint8_t)((number >> 8) & 0xFF),
+			(uint8_t)((number >> 16) & 0xFF),
+			(uint8_t)((number >> 24) & 0xFF),
+		};
+		EXPECT(
+			validate_utf8_reference((const char*)bytes, 4, false)
+				== validate_utf8((const char*)bytes, 4)
+		);
+	}
+	return 0;
+}
-- 
cgit v1.2.3-70-g09d2