summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--Makefile13
-rw-r--r--board.c48
-rw-r--r--board.cpp81
-rw-r--r--board.h37
-rwxr-xr-xgenwinmasks.py2
-rw-r--r--listalloc.h132
-rw-r--r--main.c77
-rw-r--r--main.cpp26
-rw-r--r--minimax.cpp91
-rw-r--r--minimax.h13
-rw-r--r--pnsearch.cpp286
-rw-r--r--pnsearch.h6
13 files changed, 668 insertions, 145 deletions
diff --git a/.gitignore b/.gitignore
index d9b6f50..2a53474 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
ttt3d
+*.o
diff --git a/Makefile b/Makefile
index 3652a5d..59e28d0 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
-CC = gcc
-CFLAGS = -Wall -Wextra -std=c11 -O3 -fwrapv -flto
+CXX = g++
+CXXFLAGS = -Wall -Wextra -std=c++17 -O3 -fwrapv -flto
LDFLAGS =
TARGET = ttt3d
@@ -8,8 +8,11 @@ TARGET = ttt3d
all: $(TARGET)
clean:
- rm -f $(TARGET)
+ rm -f $(TARGET) *.o
-$(TARGET): $(wildcard *.c *.h)
- $(CC) $(CFLAGS) -o $@ $(filter %.c,$^) $(LDFLAGS)
+$(TARGET): $(patsubst %.cpp,%.o,$(wildcard *.cpp))
+ $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
+
+%.o: %.cpp $(wildcard *.h)
+ $(CXX) $(CXXFLAGS) -c -o $@ $<
diff --git a/board.c b/board.c
deleted file mode 100644
index 91cee2a..0000000
--- a/board.c
+++ /dev/null
@@ -1,48 +0,0 @@
-#include "board.h"
-
-
-void b_set(board_t B, cboard_t C) {
- B[0] = C[0];
- B[1] = C[1];
-}
-
-void b_drop(board_t B, int xy, int v) {
- u64 mask = 0xfULL << (4 * xy), stkbot = 0x1ULL << (4 * xy);
- u64 bit = ((B[0] | B[1]) & mask) + stkbot;
- B[v] = B[v] | bit;
-}
-
-bool b_stk_full(cboard_t B, int xy) {
- u64 mask = 0xfULL << (4 * xy);
- return ((B[0] | B[1]) & mask) == mask;
-}
-
-static u64 winmasks[61] = {
- 0xfULL, 0xf0ULL, 0xf00ULL, 0x1111ULL,
- 0x2222ULL, 0x4444ULL, 0x8421ULL, 0x8888ULL,
- 0xf000ULL, 0xf0000ULL, 0xf00000ULL, 0xf000000ULL,
- 0x11110000ULL, 0x22220000ULL, 0x44440000ULL, 0x84210000ULL,
- 0x88880000ULL, 0xf0000000ULL, 0xf00000000ULL, 0xf000000000ULL,
- 0xf0000000000ULL, 0x111100000000ULL, 0x222200000000ULL, 0x444400000000ULL,
- 0x842100000000ULL, 0x888800000000ULL, 0xf00000000000ULL, 0x1000100010001ULL,
- 0x2000200020002ULL, 0x4000400040004ULL, 0x8000400020001ULL, 0x8000800080008ULL,
- 0xf000000000000ULL, 0x10001000100010ULL, 0x20002000200020ULL, 0x40004000400040ULL,
- 0x80004000200010ULL, 0x80008000800080ULL, 0xf0000000000000ULL, 0x100010001000100ULL,
- 0x200020002000200ULL, 0x400040004000400ULL, 0x800040002000100ULL, 0x800080008000800ULL,
- 0xf00000000000000ULL, 0x1000010000100001ULL, 0x1000100010001000ULL, 0x1111000000000000ULL,
- 0x2000020000200002ULL, 0x2000200020002000ULL, 0x2222000000000000ULL, 0x4000040000400004ULL,
- 0x4000400040004000ULL, 0x4444000000000000ULL, 0x8000040000200001ULL, 0x8000080000800008ULL,
- 0x8000400020001000ULL, 0x8000800080008000ULL, 0x8421000000000000ULL, 0x8888000000000000ULL,
- 0xf000000000000000ULL,
-};
-
-enum win_t b_win(cboard_t B) {
- if ((B[0] | B[1]) == ~0ULL) return WIN_DRAW;
-
- for (int i = 0; i < (int)(sizeof(winmasks) / sizeof(winmasks[0])); i++) {
- if ((B[0] & winmasks[i]) == winmasks[i]) return WIN_P0;
- if ((B[1] & winmasks[i]) == winmasks[i]) return WIN_P1;
- }
-
- return WIN_NONE;
-}
diff --git a/board.cpp b/board.cpp
new file mode 100644
index 0000000..5afcea1
--- /dev/null
+++ b/board.cpp
@@ -0,0 +1,81 @@
+#include <array>
+#include "board.h"
+
+using namespace std;
+
+
+Board Board::empty() {
+ Board B;
+ B.bd[0] = B.bd[1] = 0;
+ return B;
+}
+
+Board Board::uninitialised() {
+ return Board();
+}
+
+void Board::drop(int xy, int v) {
+ uint64_t mask = 0xfULL << (4 * xy), stkbot = 0x1ULL << (4 * xy);
+ uint64_t bit = ((bd[0] | bd[1]) & mask) + stkbot;
+ bd[v] |= bit;
+}
+
+bool Board::stkFull(int xy) const {
+ uint64_t mask = 0xfULL << (4 * xy);
+ return ((bd[0] | bd[1]) & mask) == mask;
+}
+
+static array<uint64_t, 61> winmasks = {
+ 0xfULL, 0xf0ULL, 0xf00ULL, 0x1111ULL,
+ 0x2222ULL, 0x4444ULL, 0x8421ULL, 0x8888ULL,
+ 0xf000ULL, 0xf0000ULL, 0xf00000ULL, 0xf000000ULL,
+ 0x11110000ULL, 0x22220000ULL, 0x44440000ULL, 0x84210000ULL,
+ 0x88880000ULL, 0xf0000000ULL, 0xf00000000ULL, 0xf000000000ULL,
+ 0xf0000000000ULL, 0x111100000000ULL, 0x222200000000ULL, 0x444400000000ULL,
+ 0x842100000000ULL, 0x888800000000ULL, 0xf00000000000ULL, 0x1000100010001ULL,
+ 0x2000200020002ULL, 0x4000400040004ULL, 0x8000400020001ULL, 0x8000800080008ULL,
+ 0xf000000000000ULL, 0x10001000100010ULL, 0x20002000200020ULL, 0x40004000400040ULL,
+ 0x80004000200010ULL, 0x80008000800080ULL, 0xf0000000000000ULL, 0x100010001000100ULL,
+ 0x200020002000200ULL, 0x400040004000400ULL, 0x800040002000100ULL, 0x800080008000800ULL,
+ 0xf00000000000000ULL, 0x1000010000100001ULL, 0x1000100010001000ULL, 0x1111000000000000ULL,
+ 0x2000020000200002ULL, 0x2000200020002000ULL, 0x2222000000000000ULL, 0x4000040000400004ULL,
+ 0x4000400040004000ULL, 0x4444000000000000ULL, 0x8000040000200001ULL, 0x8000080000800008ULL,
+ 0x8000400020001000ULL, 0x8000800080008000ULL, 0x8421000000000000ULL, 0x8888000000000000ULL,
+ 0xf000000000000000ULL,
+};
+
+win_t Board::checkWin() const {
+ if ((bd[0] | bd[1]) == ~0ULL) return WIN_DRAW;
+
+ for (int i = 0; i < (int)winmasks.size(); i++) {
+ if ((bd[0] & winmasks[i]) == winmasks[i]) return WIN_P0;
+ if ((bd[1] & winmasks[i]) == winmasks[i]) return WIN_P1;
+ }
+
+ return WIN_NONE;
+}
+
+uint64_t Board::hash() const {
+ return bd[0] ^ ((bd[1] << 29) | (bd[1] >> (64 - 29)));
+}
+
+int Board::numFilled() const {
+ return __builtin_popcountll(bd[0] | bd[1]);
+}
+
+bool Board::isEmpty() const {
+ return bd[0] == 0 && bd[1] == 0;
+}
+
+int Board::numValidMoves() const {
+ uint64_t topsMask = 0x8888'8888'8888'8888ULL;
+ return 16 - __builtin_popcountll((bd[0] | bd[1]) & topsMask);
+}
+
+int Board::playerToMove() const {
+ return numFilled() % 2;
+}
+
+bool Board::operator==(const Board &other) const {
+ return bd[0] == other.bd[0] && bd[1] == other.bd[1];
+}
diff --git a/board.h b/board.h
index f795d03..f21403b 100644
--- a/board.h
+++ b/board.h
@@ -1,16 +1,7 @@
#pragma once
-#include <stdbool.h>
-#include <stdint.h>
+#include <cstdint>
-typedef uint64_t u64;
-typedef uint8_t u8;
-
-
-// as bit array: [16 * y + 4 * x + z]
-// (x,y) = floor position, z = height with z=0 on the bottom
-typedef u64 board_t[2];
-typedef const u64 cboard_t[2];
enum win_t {
WIN_NONE,
@@ -19,7 +10,25 @@ enum win_t {
WIN_DRAW,
};
-void b_set(board_t B, cboard_t C);
-void b_drop(board_t B, int xy, int v);
-bool b_stk_full(cboard_t B, int xy);
-enum win_t b_win(cboard_t B);
+struct Board {
+ // as bit array: [16 * y + 4 * x + z]
+ // (x,y) = floor position, z = height with z=0 on the bottom
+ uint64_t bd[2];
+
+ static Board empty();
+ static Board uninitialised();
+
+ void drop(int xy, int v);
+ bool stkFull(int xy) const;
+ win_t checkWin() const;
+ uint64_t hash() const;
+ int numFilled() const;
+ bool isEmpty() const;
+ int numValidMoves() const;
+ int playerToMove() const;
+
+ bool operator==(const Board &other) const;
+
+protected:
+ Board() = default;
+};
diff --git a/genwinmasks.py b/genwinmasks.py
index 45a9f58..9d118be 100755
--- a/genwinmasks.py
+++ b/genwinmasks.py
@@ -25,7 +25,7 @@ for x in range(4):
masks.sort()
-print("static u64 winmasks[{}] = {{".format(len(masks)))
+print("static array<uint64_t, 61> winmasks[{}] = {{".format(len(masks)))
for i in range((len(masks) + 3) // 4):
print("\t" + " ".join(hex(n) + "ULL," for n in masks[4*i:4*i+4]))
print("};")
diff --git a/listalloc.h b/listalloc.h
new file mode 100644
index 0000000..cdb512d
--- /dev/null
+++ b/listalloc.h
@@ -0,0 +1,132 @@
+#pragma once
+
+#include <stdexcept>
+#include <vector>
+#include <stack>
+#include <memory>
+
+using namespace std;
+
+
+template <typename T, typename Index>
+class ListAlloc {
+ static_assert(is_integral_v<Index> && is_unsigned_v<Index>);
+
+public:
+ const size_t capacity;
+
+ ListAlloc(size_t capacity);
+ ~ListAlloc();
+
+ template <typename... Args>
+ Index allocate(Args... args);
+ void deallocate(Index index);
+
+ T* at(Index index);
+
+private:
+ char *buffer;
+ stack<Index> freeStack;
+ size_t cursor = 0;
+};
+
+template <typename T, typename Index>
+class ListAllocRef {
+ static_assert(is_integral_v<Index> && is_unsigned_v<Index>);
+
+public:
+ ListAllocRef() : index(-1) {}
+ ListAllocRef(Index index) : index(index) {}
+
+ ~ListAllocRef() {
+ if (index != (Index)-1) {
+ assert(false && "Non-empty ListAllocRef upon destruction");
+ }
+ }
+
+ // No copying
+ ListAllocRef(const ListAllocRef<T, Index> &other) = delete;
+ ListAllocRef<T, Index>& operator=(const ListAllocRef<T, Index> &other) = delete;
+
+ operator bool() const {
+ return index != (Index)-1;
+ }
+
+ ListAllocRef<T, Index>& operator=(Index newIndex) {
+ if (index != (Index)-1) {
+ throw logic_error("operator= on non-empty ListAllocRef");
+ }
+
+ index = newIndex;
+ return *this;
+ }
+
+ T* get(ListAlloc<T, Index> &allocator) {
+ return allocator.at(index);
+ }
+
+ const T* get(ListAlloc<T, Index> &allocator) const {
+ return allocator.at(index);
+ }
+
+ void deallocate(ListAlloc<T, Index> &allocator) {
+ if (index != (Index)-1) {
+ allocator.deallocate(index);
+ index = -1;
+ }
+ }
+
+private:
+ Index index;
+};
+
+template <typename T, typename Index>
+ListAlloc<T, Index>::ListAlloc(size_t capacity)
+ : capacity(capacity)
+ , buffer(new char[capacity * sizeof(T)]) {
+
+ size_t largestIndex = capacity - 1;
+ if (capacity != 0 && (size_t)(Index)largestIndex != largestIndex) {
+ throw logic_error("Capacity too large for index type in ListAlloc");
+ }
+
+ fprintf(stderr, "ListAlloc with capacity=%zu, size=%zu\n", capacity, capacity * sizeof(T));
+}
+
+template <typename T, typename Index>
+ListAlloc<T, Index>::~ListAlloc() {
+ if (freeStack.size() != cursor) {
+ assert(false && "Not all entries deallocated in ~ListAlloc");
+ }
+
+ delete[] buffer;
+}
+
+template <typename T, typename Index>
+template <typename... Args>
+Index ListAlloc<T, Index>::allocate(Args... args) {
+ Index index;
+
+ if (!freeStack.empty()) {
+ index = freeStack.top();
+ freeStack.pop();
+ } else if (cursor < capacity) {
+ index = cursor++;
+ } else {
+ throw runtime_error("Out of memory in ListAlloc");
+ }
+
+ new(at(index)) T(args...);
+ return index;
+}
+
+template <typename T, typename Index>
+void ListAlloc<T, Index>::deallocate(Index index) {
+ at(index)->~T();
+ freeStack.push(index);
+}
+
+template <typename T, typename Index>
+T* ListAlloc<T, Index>::at(Index index) {
+ return (T*)&buffer[index * sizeof(T)];
+}
diff --git a/main.c b/main.c
deleted file mode 100644
index 6b5cac8..0000000
--- a/main.c
+++ /dev/null
@@ -1,77 +0,0 @@
-#include <stdio.h>
-#include <stdlib.h>
-#include "board.h"
-
-#define LARGE 100000
-
-
-static inline int min(int a, int b) {
- return a < b ? a : b;
-}
-
-static inline int max(int a, int b) {
- return a < b ? b : a;
-}
-
-
-static int win_score[4] = {
- [WIN_NONE] = 0,
- [WIN_P0] = 1000,
- [WIN_P1] = -1000,
- [WIN_DRAW] = 0,
-};
-
-static int evaluate(cboard_t B) {
- (void)B;
- return 0;
-}
-
-static int minimax_p0(cboard_t B, int alpha, int beta, int depth);
-
-static int minimax_p1(cboard_t B, int alpha, int beta, int depth) {
- if (depth == 0) return evaluate(B);
- enum win_t win = b_win(B);
- if (win != WIN_NONE) return win_score[win];
-
- int best = LARGE;
- for (int i = 0; i < 16; i++) {
- if (b_stk_full(B, i)) continue;
- board_t C; b_set(C, B);
- b_drop(C, i, 1);
- int sc = minimax_p0(C, alpha, beta, depth - 1);
- best = min(sc, best);
- beta = min(best, beta);
- if (beta <= alpha) break;
- }
-
- return best;
-}
-
-static int minimax_p0(cboard_t B, int alpha, int beta, int depth) {
- if (depth == 0) return evaluate(B);
- enum win_t win = b_win(B);
- if (win != WIN_NONE) return win_score[win];
-
- if (depth >= 36) {
- printf("depth=%d alpha=%d beta=%d\n", depth, alpha, beta);
- }
-
- int best = -LARGE;
- for (int i = 0; i < 16; i++) {
- if (b_stk_full(B, i)) continue;
- board_t C; b_set(C, B);
- b_drop(C, i, 0);
- int sc = minimax_p1(C, alpha, beta, depth - 1);
- best = max(sc, best);
- alpha = max(best, alpha);
- if (beta <= alpha) break;
- }
-
- return best;
-}
-
-
-int main(void) {
- board_t B = {0, 0};
- printf("%d\n", minimax_p0(B, -LARGE, LARGE, 64));
-}
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..0a9c2d2
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,26 @@
+#define USE_PN_SEARCH
+
+#include <array>
+#include <memory>
+#include <cstdio>
+#include <cstdlib>
+#include "board.h"
+
+#ifdef USE_PN_SEARCH
+#include "pnsearch.h"
+#else
+#include "minimax.h"
+#endif
+
+using namespace std;
+
+
+int main(void) {
+ Board B = Board::empty();
+
+#ifdef USE_PN_SEARCH
+ printf("%d\n", pnsearch(B));
+#else
+ printf("%d\n", minimax<0>(B, -MINIMAX_LARGE, MINIMAX_LARGE, 64));
+#endif
+}
diff --git a/minimax.cpp b/minimax.cpp
new file mode 100644
index 0000000..c4011da
--- /dev/null
+++ b/minimax.cpp
@@ -0,0 +1,91 @@
+#include <array>
+#include <memory>
+#include "minimax.h"
+
+using namespace std;
+
+#define LARGE MINIMAX_LARGE
+
+
+static int win_score[4] = {
+ [WIN_NONE] = 0,
+ [WIN_P0] = 1,
+ [WIN_P1] = -1,
+ [WIN_DRAW] = -1,
+};
+
+struct TransItem {
+ Board B = Board::empty();
+ int score = 0;
+};
+
+static unique_ptr<array<TransItem, 10000019>> transTable =
+ make_unique<decltype(transTable)::element_type>();
+
+static void transTableStore(uint64_t index, const Board &B, int score) {
+ TransItem &entry = (*transTable)[index];
+ // if (entry.B.isEmpty() || B.numFilled() < entry.B.numFilled()) {
+ entry.B = B;
+ entry.score = score;
+ // }
+}
+
+static uint64_t transHits = 0, transTotal = 0;
+static uint64_t transHitDepthSum = 0;
+
+template <int player>
+int minimax(const Board &B, int alpha, int beta, int depth) {
+ uint64_t transIndex = B.hash() % transTable->size();
+ transTotal++;
+ if (!B.isEmpty() && (*transTable)[transIndex].B == B) {
+ transHits++;
+ transHitDepthSum += depth;
+ return (*transTable)[transIndex].score;
+ }
+
+ win_t win = B.checkWin();
+ if (win != WIN_NONE) {
+ transTableStore(transIndex, B, win_score[win]);
+ return win_score[win];
+ }
+
+ if (depth == 0) return win_score[WIN_NONE];
+
+ if (depth >= 42) {
+ printf("depth=%d alpha=%d beta=%d transHitrate=%lf (avg depth %lf)\n",
+ depth, alpha, beta,
+ (double)transHits / transTotal, (double)transHitDepthSum / transTotal);
+ }
+
+ int best = player == 0 ? -LARGE : LARGE;
+
+ // Returns true if iteration should be stopped
+ auto tryMove = [&](int i) {
+ if (B.stkFull(i)) return false;
+
+ Board C(B);
+ C.drop(i, player);
+
+ int sc = minimax<1 - player>(C, alpha, beta, depth - 1);
+
+ if constexpr (player == 0) {
+ best = max(sc, best);
+ alpha = max(best, alpha);
+ } else {
+ best = min(sc, best);
+ beta = min(best, beta);
+ }
+
+ return beta <= alpha;
+ };
+
+ for (int i = 0; i < 16; i++) {
+ if (tryMove(i)) break;
+ }
+
+ transTableStore(transIndex, B, best);
+ return best;
+}
+
+template int minimax<0>(const Board &B, int alpha, int beta, int depth);
+template int minimax<1>(const Board &B, int alpha, int beta, int depth);
diff --git a/minimax.h b/minimax.h
new file mode 100644
index 0000000..8c77acd
--- /dev/null
+++ b/minimax.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "board.h"
+
+#define MINIMAX_LARGE 10000
+
+
+// Returns:
+// - 1 if player 0 wins
+// - -1 if player 1 wins or can force a draw
+// - 0 if neither player can force a win within the given depth
+template <int player>
+int minimax(const Board &B, int alpha, int beta, int depth);
diff --git a/pnsearch.cpp b/pnsearch.cpp
new file mode 100644
index 0000000..a09d362
--- /dev/null
+++ b/pnsearch.cpp
@@ -0,0 +1,286 @@
+#include <array>
+#include <memory>
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <cassert>
+#include "pnsearch.h"
+#include "minimax.h"
+#include "listalloc.h"
+
+using namespace std;
+
+
+static uint64_t gNodesAllocated = 0;
+static uint64_t gEvalWins = 0;
+static uint64_t gNodesPruned = 0;
+
+struct AOTree {
+ enum type_t { AND, OR };
+ enum eval_t { TRUE, FALSE, UNKNOWN };
+
+ AOTree *const parent;
+ const Board B;
+ const type_t type;
+
+ bool expanded = false;
+ int pnum = 1, dnum = 1;
+ array<ListAllocRef<AOTree, uint32_t>, 16> children;
+
+ AOTree(AOTree *parent, Board B, type_t type) noexcept;
+ ~AOTree() noexcept;
+
+ void writeDOT(FILE *f, bool root = true) const;
+ uint64_t treeSize() const;
+};
+
+static ListAlloc<AOTree, uint32_t> gAllocator(6ULL * 1024 * 1024 * 1024 / sizeof(AOTree));
+
+AOTree::AOTree(AOTree *parent, Board B, type_t type) noexcept
+ : parent(parent), B(B), type(type) {
+
+ gNodesAllocated++;
+}
+
+AOTree::~AOTree() noexcept {
+ gNodesAllocated--;
+
+ for (int i = 0; i < 16; i++) {
+ children[i].deallocate(gAllocator);
+ }
+}
+
+void AOTree::writeDOT(FILE *f, bool root) const {
+ if (root) fprintf(f, "digraph G {\n");
+
+ fprintf(f, "\t\"%p\" [shape=%s, label=\"%d %d\"];\n",
+ this, type == AND ? "ellipse" : "rectangle",
+ pnum, dnum);
+
+ for (auto &child : children) {
+ if (child) {
+ fprintf(f, "\t\"%p\" -> \"%p\"\n", this, child.get(gAllocator));
+ child.get(gAllocator)->writeDOT(f, false);
+ }
+ }
+
+ if (root) fprintf(f, "}\n");
+}
+
+uint64_t AOTree::treeSize() const {
+ uint64_t s = 1;
+ for (auto &child : children) {
+ if (child) s += child.get(gAllocator)->treeSize();
+ }
+ return s;
+}
+
+
+template <typename F>
+static int pnum_sum(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) {
+ int sum = 0;
+
+ for (int i = 0; i < 16; i++) {
+ if (arr[i]) {
+ int value = func(*arr[i].get(gAllocator));
+ if (value == -1) return -1;
+ sum += value;
+ }
+ }
+
+ return sum;
+}
+
+template <typename F>
+static int pnum_min(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) {
+ int lowest = -1;
+
+ for (int i = 0; i < 16; i++) {
+ if (arr[i]) {
+ int value = func(*arr[i].get(gAllocator));
+ if (value == -1) continue;
+ if (lowest == -1) lowest = value;
+ else lowest = min(lowest, value);
+ }
+ }
+
+ return lowest;
+}
+
+
+static AOTree::eval_t evaluate(const Board &B) {
+#if 0
+ // Basic win check
+ switch (B.checkWin()) {
+ case WIN_NONE: return AOTree::UNKNOWN;
+ case WIN_P0: return AOTree::TRUE;
+ case WIN_P1: return AOTree::FALSE;
+ case WIN_DRAW: return AOTree::FALSE;
+ }
+ assert(false);
+#else
+ // Shallow minimax
+ auto func = B.playerToMove() == 0 ? minimax<0> : minimax<1>;
+ switch (func(B, -MINIMAX_LARGE, MINIMAX_LARGE, 4)) {
+ case 1: return AOTree::TRUE;
+ case -1: return AOTree::FALSE;
+ case 0: return AOTree::UNKNOWN;
+ }
+ assert(false);
+#endif
+}
+
+void generateAllChildren(AOTree &node) {
+ int onturn = node.B.playerToMove();
+ int onturn_child = 1 - onturn;
+ AOTree::type_t child_type = onturn_child == 0 ? AOTree::OR : AOTree::AND;
+
+ for (int i = 0; i < 16; i++) {
+ if (node.B.stkFull(i)) continue;
+
+ Board C(node.B);
+ C.drop(i, onturn);
+ node.children[i] = gAllocator.allocate(&node, C, child_type);
+ }
+}
+
+static void developNode(AOTree &node) {
+ generateAllChildren(node);
+
+ node.expanded = true;
+
+ for (int i = 0; i < 16; i++) {
+ if (!node.children[i]) continue;
+ AOTree &child = *node.children[i].get(gAllocator);
+
+ AOTree::eval_t eval = evaluate(child.B);
+ switch (eval) {
+ case AOTree::FALSE: child.pnum = -1; child.dnum = 0; break;
+ case AOTree::TRUE: child.pnum = 0; child.dnum = -1; break;
+ case AOTree::UNKNOWN: {
+ // ENHANCEMENT: initial proof/disproof numbers according to number of children
+ int numSubChildren = child.B.numValidMoves();
+ switch (child.type) {
+ case AOTree::OR: child.pnum = 1; child.dnum = numSubChildren; break;
+ case AOTree::AND: child.pnum = numSubChildren; child.dnum = 1; break;
+ }
+ break;
+ }
+ }
+ gEvalWins += eval == AOTree::TRUE || eval == AOTree::FALSE;
+ }
+}
+
+static void updateAncestors(AOTree &node_) {
+ AOTree *node = &node_;
+
+ while (node) {
+ assert(node->expanded);
+
+ switch (node->type) {
+ case AOTree::AND:
+ node->pnum = pnum_sum(node->children, [](const AOTree &n) { return n.pnum; });
+ node->dnum = pnum_min(node->children, [](const AOTree &n) { return n.dnum; });
+ break;
+
+ case AOTree::OR:
+ node->pnum = pnum_min(node->children, [](const AOTree &n) { return n.pnum; });
+ node->dnum = pnum_sum(node->children, [](const AOTree &n) { return n.dnum; });
+ break;
+ }
+
+ for (int i = 0; i < 16; i++) {
+ if (!node->children[i]) continue;
+
+ AOTree &child = *node->children[i].get(gAllocator);
+
+ // ENHANCEMENT: Delete solved subtrees
+ if (child.pnum == 0 || child.dnum == 0) {
+ gNodesPruned += child.treeSize();
+ node->children[i].deallocate(gAllocator);
+ }
+ }
+
+ node = node->parent;
+ }
+}
+
+static AOTree& selectMostProving(AOTree &node_) {
+ AOTree *node = &node_;
+
+ while (node->expanded) {
+ switch (node->type) {
+ case AOTree::OR: {
+ auto &laref = *find_if(node->children.begin(), node->children.end(),
+ [&node](ListAllocRef<AOTree, uint32_t> &child) {
+ if (!child) return false;
+ AOTree &obj = *child.get(gAllocator);
+ return obj.pnum == node->pnum;
+ });
+ node = laref.get(gAllocator);
+ break;
+ }
+
+ case AOTree::AND: {
+ auto &laref = *find_if(node->children.begin(), node->children.end(),
+ [&node](ListAllocRef<AOTree, uint32_t> &child) {
+ if (!child) return false;
+ AOTree &obj = *child.get(gAllocator);
+ return obj.dnum == node->dnum;
+ });
+ node = laref.get(gAllocator);
+ break;
+ }
+ }
+ }
+
+ return *node;
+}
+
+static AOTree::eval_t proofNumberSearch(AOTree &root) {
+ switch (evaluate(root.B)) {
+ case AOTree::TRUE: root.pnum = 0; root.dnum = -1; break;
+ case AOTree::FALSE: root.pnum = -1; root.dnum = 0; break;
+ case AOTree::UNKNOWN: break;
+ }
+
+ for (int iter = 0; root.pnum != 0 && root.dnum != 0; iter++) {
+ // char filename[128];
+ // sprintf(filename, "pntree-%d.dot", iter);
+ // FILE *f = fopen(filename, "w");
+ // root.writeDOT(f);
+ // fclose(f);
+
+ // if (iter == 20) break;
+
+ if (iter % 50000 == 0) {
+ printf("%d iterations done, %lu nodes allocated, %lu terminals, %lu pruned, root %d,%d\n", iter, gNodesAllocated, gEvalWins, gNodesPruned, root.pnum, root.dnum);
+ }
+
+ AOTree &mostProvingNode = selectMostProving(root);
+ developNode(mostProvingNode);
+ updateAncestors(mostProvingNode);
+ }
+
+ if (root.pnum == 0) return AOTree::TRUE;
+ else if (root.dnum == 0) return AOTree::FALSE;
+ else return AOTree::UNKNOWN;
+}
+
+int pnsearch(const Board &B) {
+ // printf("\x1B[1mWARNING: THIS IS GOING TO CONSUME A LOT OF MEMORY VERY QUICKLY. PLEASE WATCH IN htop AND BE READY TO ^C.\x1B[0m\n");
+
+ AOTree root(nullptr, B, B.numFilled() % 2 == 0 ? AOTree::OR : AOTree::AND);
+ AOTree::eval_t eval = proofNumberSearch(root);
+
+ switch (eval) {
+ case AOTree::TRUE: return 0;
+ case AOTree::FALSE: return 1;
+ case AOTree::UNKNOWN:
+ printf("UNKNOWN from proofNumberSearch!\n");
+ exit(0);
+ assert(false);
+ }
+
+ assert(false);
+}
diff --git a/pnsearch.h b/pnsearch.h
new file mode 100644
index 0000000..e0023bf
--- /dev/null
+++ b/pnsearch.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include "board.h"
+
+
+int pnsearch(const Board &B);