summaryrefslogtreecommitdiff
path: root/pnsearch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'pnsearch.cpp')
-rw-r--r--pnsearch.cpp286
1 files changed, 286 insertions, 0 deletions
diff --git a/pnsearch.cpp b/pnsearch.cpp
new file mode 100644
index 0000000..a09d362
--- /dev/null
+++ b/pnsearch.cpp
@@ -0,0 +1,286 @@
+#include <array>
+#include <memory>
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <cassert>
+#include "pnsearch.h"
+#include "minimax.h"
+#include "listalloc.h"
+
+using namespace std;
+
+
+static uint64_t gNodesAllocated = 0;
+static uint64_t gEvalWins = 0;
+static uint64_t gNodesPruned = 0;
+
+struct AOTree {
+ enum type_t { AND, OR };
+ enum eval_t { TRUE, FALSE, UNKNOWN };
+
+ AOTree *const parent;
+ const Board B;
+ const type_t type;
+
+ bool expanded = false;
+ int pnum = 1, dnum = 1;
+ array<ListAllocRef<AOTree, uint32_t>, 16> children;
+
+ AOTree(AOTree *parent, Board B, type_t type) noexcept;
+ ~AOTree() noexcept;
+
+ void writeDOT(FILE *f, bool root = true) const;
+ uint64_t treeSize() const;
+};
+
+static ListAlloc<AOTree, uint32_t> gAllocator(6ULL * 1024 * 1024 * 1024 / sizeof(AOTree));
+
+AOTree::AOTree(AOTree *parent, Board B, type_t type) noexcept
+ : parent(parent), B(B), type(type) {
+
+ gNodesAllocated++;
+}
+
+AOTree::~AOTree() noexcept {
+ gNodesAllocated--;
+
+ for (int i = 0; i < 16; i++) {
+ children[i].deallocate(gAllocator);
+ }
+}
+
+void AOTree::writeDOT(FILE *f, bool root) const {
+ if (root) fprintf(f, "digraph G {\n");
+
+ fprintf(f, "\t\"%p\" [shape=%s, label=\"%d %d\"];\n",
+ this, type == AND ? "ellipse" : "rectangle",
+ pnum, dnum);
+
+ for (auto &child : children) {
+ if (child) {
+ fprintf(f, "\t\"%p\" -> \"%p\"\n", this, child.get(gAllocator));
+ child.get(gAllocator)->writeDOT(f, false);
+ }
+ }
+
+ if (root) fprintf(f, "}\n");
+}
+
+uint64_t AOTree::treeSize() const {
+ uint64_t s = 1;
+ for (auto &child : children) {
+ if (child) s += child.get(gAllocator)->treeSize();
+ }
+ return s;
+}
+
+
+template <typename F>
+static int pnum_sum(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) {
+ int sum = 0;
+
+ for (int i = 0; i < 16; i++) {
+ if (arr[i]) {
+ int value = func(*arr[i].get(gAllocator));
+ if (value == -1) return -1;
+ sum += value;
+ }
+ }
+
+ return sum;
+}
+
+template <typename F>
+static int pnum_min(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) {
+ int lowest = -1;
+
+ for (int i = 0; i < 16; i++) {
+ if (arr[i]) {
+ int value = func(*arr[i].get(gAllocator));
+ if (value == -1) continue;
+ if (lowest == -1) lowest = value;
+ else lowest = min(lowest, value);
+ }
+ }
+
+ return lowest;
+}
+
+
+static AOTree::eval_t evaluate(const Board &B) {
+#if 0
+ // Basic win check
+ switch (B.checkWin()) {
+ case WIN_NONE: return AOTree::UNKNOWN;
+ case WIN_P0: return AOTree::TRUE;
+ case WIN_P1: return AOTree::FALSE;
+ case WIN_DRAW: return AOTree::FALSE;
+ }
+ assert(false);
+#else
+ // Shallow minimax
+ auto func = B.playerToMove() == 0 ? minimax<0> : minimax<1>;
+ switch (func(B, -MINIMAX_LARGE, MINIMAX_LARGE, 4)) {
+ case 1: return AOTree::TRUE;
+ case -1: return AOTree::FALSE;
+ case 0: return AOTree::UNKNOWN;
+ }
+ assert(false);
+#endif
+}
+
+void generateAllChildren(AOTree &node) {
+ int onturn = node.B.playerToMove();
+ int onturn_child = 1 - onturn;
+ AOTree::type_t child_type = onturn_child == 0 ? AOTree::OR : AOTree::AND;
+
+ for (int i = 0; i < 16; i++) {
+ if (node.B.stkFull(i)) continue;
+
+ Board C(node.B);
+ C.drop(i, onturn);
+ node.children[i] = gAllocator.allocate(&node, C, child_type);
+ }
+}
+
+static void developNode(AOTree &node) {
+ generateAllChildren(node);
+
+ node.expanded = true;
+
+ for (int i = 0; i < 16; i++) {
+ if (!node.children[i]) continue;
+ AOTree &child = *node.children[i].get(gAllocator);
+
+ AOTree::eval_t eval = evaluate(child.B);
+ switch (eval) {
+ case AOTree::FALSE: child.pnum = -1; child.dnum = 0; break;
+ case AOTree::TRUE: child.pnum = 0; child.dnum = -1; break;
+ case AOTree::UNKNOWN: {
+ // ENHANCEMENT: initial proof/disproof numbers according to number of children
+ int numSubChildren = child.B.numValidMoves();
+ switch (child.type) {
+ case AOTree::OR: child.pnum = 1; child.dnum = numSubChildren; break;
+ case AOTree::AND: child.pnum = numSubChildren; child.dnum = 1; break;
+ }
+ break;
+ }
+ }
+ gEvalWins += eval == AOTree::TRUE || eval == AOTree::FALSE;
+ }
+}
+
+static void updateAncestors(AOTree &node_) {
+ AOTree *node = &node_;
+
+ while (node) {
+ assert(node->expanded);
+
+ switch (node->type) {
+ case AOTree::AND:
+ node->pnum = pnum_sum(node->children, [](const AOTree &n) { return n.pnum; });
+ node->dnum = pnum_min(node->children, [](const AOTree &n) { return n.dnum; });
+ break;
+
+ case AOTree::OR:
+ node->pnum = pnum_min(node->children, [](const AOTree &n) { return n.pnum; });
+ node->dnum = pnum_sum(node->children, [](const AOTree &n) { return n.dnum; });
+ break;
+ }
+
+ for (int i = 0; i < 16; i++) {
+ if (!node->children[i]) continue;
+
+ AOTree &child = *node->children[i].get(gAllocator);
+
+ // ENHANCEMENT: Delete solved subtrees
+ if (child.pnum == 0 || child.dnum == 0) {
+ gNodesPruned += child.treeSize();
+ node->children[i].deallocate(gAllocator);
+ }
+ }
+
+ node = node->parent;
+ }
+}
+
+static AOTree& selectMostProving(AOTree &node_) {
+ AOTree *node = &node_;
+
+ while (node->expanded) {
+ switch (node->type) {
+ case AOTree::OR: {
+ auto &laref = *find_if(node->children.begin(), node->children.end(),
+ [&node](ListAllocRef<AOTree, uint32_t> &child) {
+ if (!child) return false;
+ AOTree &obj = *child.get(gAllocator);
+ return obj.pnum == node->pnum;
+ });
+ node = laref.get(gAllocator);
+ break;
+ }
+
+ case AOTree::AND: {
+ auto &laref = *find_if(node->children.begin(), node->children.end(),
+ [&node](ListAllocRef<AOTree, uint32_t> &child) {
+ if (!child) return false;
+ AOTree &obj = *child.get(gAllocator);
+ return obj.dnum == node->dnum;
+ });
+ node = laref.get(gAllocator);
+ break;
+ }
+ }
+ }
+
+ return *node;
+}
+
+static AOTree::eval_t proofNumberSearch(AOTree &root) {
+ switch (evaluate(root.B)) {
+ case AOTree::TRUE: root.pnum = 0; root.dnum = -1; break;
+ case AOTree::FALSE: root.pnum = -1; root.dnum = 0; break;
+ case AOTree::UNKNOWN: break;
+ }
+
+ for (int iter = 0; root.pnum != 0 && root.dnum != 0; iter++) {
+ // char filename[128];
+ // sprintf(filename, "pntree-%d.dot", iter);
+ // FILE *f = fopen(filename, "w");
+ // root.writeDOT(f);
+ // fclose(f);
+
+ // if (iter == 20) break;
+
+ if (iter % 50000 == 0) {
+ printf("%d iterations done, %lu nodes allocated, %lu terminals, %lu pruned, root %d,%d\n", iter, gNodesAllocated, gEvalWins, gNodesPruned, root.pnum, root.dnum);
+ }
+
+ AOTree &mostProvingNode = selectMostProving(root);
+ developNode(mostProvingNode);
+ updateAncestors(mostProvingNode);
+ }
+
+ if (root.pnum == 0) return AOTree::TRUE;
+ else if (root.dnum == 0) return AOTree::FALSE;
+ else return AOTree::UNKNOWN;
+}
+
+int pnsearch(const Board &B) {
+ // printf("\x1B[1mWARNING: THIS IS GOING TO CONSUME A LOT OF MEMORY VERY QUICKLY. PLEASE WATCH IN htop AND BE READY TO ^C.\x1B[0m\n");
+
+ AOTree root(nullptr, B, B.numFilled() % 2 == 0 ? AOTree::OR : AOTree::AND);
+ AOTree::eval_t eval = proofNumberSearch(root);
+
+ switch (eval) {
+ case AOTree::TRUE: return 0;
+ case AOTree::FALSE: return 1;
+ case AOTree::UNKNOWN:
+ printf("UNKNOWN from proofNumberSearch!\n");
+ exit(0);
+ assert(false);
+ }
+
+ assert(false);
+}