diff options
Diffstat (limited to 'pnsearch.cpp')
-rw-r--r-- | pnsearch.cpp | 286 |
1 files changed, 286 insertions, 0 deletions
diff --git a/pnsearch.cpp b/pnsearch.cpp new file mode 100644 index 0000000..a09d362 --- /dev/null +++ b/pnsearch.cpp @@ -0,0 +1,286 @@ +#include <array> +#include <memory> +#include <algorithm> +#include <cstdio> +#include <cstring> +#include <cassert> +#include "pnsearch.h" +#include "minimax.h" +#include "listalloc.h" + +using namespace std; + + +static uint64_t gNodesAllocated = 0; +static uint64_t gEvalWins = 0; +static uint64_t gNodesPruned = 0; + +struct AOTree { + enum type_t { AND, OR }; + enum eval_t { TRUE, FALSE, UNKNOWN }; + + AOTree *const parent; + const Board B; + const type_t type; + + bool expanded = false; + int pnum = 1, dnum = 1; + array<ListAllocRef<AOTree, uint32_t>, 16> children; + + AOTree(AOTree *parent, Board B, type_t type) noexcept; + ~AOTree() noexcept; + + void writeDOT(FILE *f, bool root = true) const; + uint64_t treeSize() const; +}; + +static ListAlloc<AOTree, uint32_t> gAllocator(6ULL * 1024 * 1024 * 1024 / sizeof(AOTree)); + +AOTree::AOTree(AOTree *parent, Board B, type_t type) noexcept + : parent(parent), B(B), type(type) { + + gNodesAllocated++; +} + +AOTree::~AOTree() noexcept { + gNodesAllocated--; + + for (int i = 0; i < 16; i++) { + children[i].deallocate(gAllocator); + } +} + +void AOTree::writeDOT(FILE *f, bool root) const { + if (root) fprintf(f, "digraph G {\n"); + + fprintf(f, "\t\"%p\" [shape=%s, label=\"%d %d\"];\n", + this, type == AND ? "ellipse" : "rectangle", + pnum, dnum); + + for (auto &child : children) { + if (child) { + fprintf(f, "\t\"%p\" -> \"%p\"\n", this, child.get(gAllocator)); + child.get(gAllocator)->writeDOT(f, false); + } + } + + if (root) fprintf(f, "}\n"); +} + +uint64_t AOTree::treeSize() const { + uint64_t s = 1; + for (auto &child : children) { + if (child) s += child.get(gAllocator)->treeSize(); + } + return s; +} + + +template <typename F> +static int pnum_sum(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) { + int sum = 0; + + for (int i = 0; i < 16; i++) { + if (arr[i]) { + int value = func(*arr[i].get(gAllocator)); + if (value == -1) return -1; + sum += value; + } + } + + return sum; +} + +template <typename F> +static int pnum_min(const array<ListAllocRef<AOTree, uint32_t>, 16> &arr, F func) { + int lowest = -1; + + for (int i = 0; i < 16; i++) { + if (arr[i]) { + int value = func(*arr[i].get(gAllocator)); + if (value == -1) continue; + if (lowest == -1) lowest = value; + else lowest = min(lowest, value); + } + } + + return lowest; +} + + +static AOTree::eval_t evaluate(const Board &B) { +#if 0 + // Basic win check + switch (B.checkWin()) { + case WIN_NONE: return AOTree::UNKNOWN; + case WIN_P0: return AOTree::TRUE; + case WIN_P1: return AOTree::FALSE; + case WIN_DRAW: return AOTree::FALSE; + } + assert(false); +#else + // Shallow minimax + auto func = B.playerToMove() == 0 ? minimax<0> : minimax<1>; + switch (func(B, -MINIMAX_LARGE, MINIMAX_LARGE, 4)) { + case 1: return AOTree::TRUE; + case -1: return AOTree::FALSE; + case 0: return AOTree::UNKNOWN; + } + assert(false); +#endif +} + +void generateAllChildren(AOTree &node) { + int onturn = node.B.playerToMove(); + int onturn_child = 1 - onturn; + AOTree::type_t child_type = onturn_child == 0 ? AOTree::OR : AOTree::AND; + + for (int i = 0; i < 16; i++) { + if (node.B.stkFull(i)) continue; + + Board C(node.B); + C.drop(i, onturn); + node.children[i] = gAllocator.allocate(&node, C, child_type); + } +} + +static void developNode(AOTree &node) { + generateAllChildren(node); + + node.expanded = true; + + for (int i = 0; i < 16; i++) { + if (!node.children[i]) continue; + AOTree &child = *node.children[i].get(gAllocator); + + AOTree::eval_t eval = evaluate(child.B); + switch (eval) { + case AOTree::FALSE: child.pnum = -1; child.dnum = 0; break; + case AOTree::TRUE: child.pnum = 0; child.dnum = -1; break; + case AOTree::UNKNOWN: { + // ENHANCEMENT: initial proof/disproof numbers according to number of children + int numSubChildren = child.B.numValidMoves(); + switch (child.type) { + case AOTree::OR: child.pnum = 1; child.dnum = numSubChildren; break; + case AOTree::AND: child.pnum = numSubChildren; child.dnum = 1; break; + } + break; + } + } + gEvalWins += eval == AOTree::TRUE || eval == AOTree::FALSE; + } +} + +static void updateAncestors(AOTree &node_) { + AOTree *node = &node_; + + while (node) { + assert(node->expanded); + + switch (node->type) { + case AOTree::AND: + node->pnum = pnum_sum(node->children, [](const AOTree &n) { return n.pnum; }); + node->dnum = pnum_min(node->children, [](const AOTree &n) { return n.dnum; }); + break; + + case AOTree::OR: + node->pnum = pnum_min(node->children, [](const AOTree &n) { return n.pnum; }); + node->dnum = pnum_sum(node->children, [](const AOTree &n) { return n.dnum; }); + break; + } + + for (int i = 0; i < 16; i++) { + if (!node->children[i]) continue; + + AOTree &child = *node->children[i].get(gAllocator); + + // ENHANCEMENT: Delete solved subtrees + if (child.pnum == 0 || child.dnum == 0) { + gNodesPruned += child.treeSize(); + node->children[i].deallocate(gAllocator); + } + } + + node = node->parent; + } +} + +static AOTree& selectMostProving(AOTree &node_) { + AOTree *node = &node_; + + while (node->expanded) { + switch (node->type) { + case AOTree::OR: { + auto &laref = *find_if(node->children.begin(), node->children.end(), + [&node](ListAllocRef<AOTree, uint32_t> &child) { + if (!child) return false; + AOTree &obj = *child.get(gAllocator); + return obj.pnum == node->pnum; + }); + node = laref.get(gAllocator); + break; + } + + case AOTree::AND: { + auto &laref = *find_if(node->children.begin(), node->children.end(), + [&node](ListAllocRef<AOTree, uint32_t> &child) { + if (!child) return false; + AOTree &obj = *child.get(gAllocator); + return obj.dnum == node->dnum; + }); + node = laref.get(gAllocator); + break; + } + } + } + + return *node; +} + +static AOTree::eval_t proofNumberSearch(AOTree &root) { + switch (evaluate(root.B)) { + case AOTree::TRUE: root.pnum = 0; root.dnum = -1; break; + case AOTree::FALSE: root.pnum = -1; root.dnum = 0; break; + case AOTree::UNKNOWN: break; + } + + for (int iter = 0; root.pnum != 0 && root.dnum != 0; iter++) { + // char filename[128]; + // sprintf(filename, "pntree-%d.dot", iter); + // FILE *f = fopen(filename, "w"); + // root.writeDOT(f); + // fclose(f); + + // if (iter == 20) break; + + if (iter % 50000 == 0) { + printf("%d iterations done, %lu nodes allocated, %lu terminals, %lu pruned, root %d,%d\n", iter, gNodesAllocated, gEvalWins, gNodesPruned, root.pnum, root.dnum); + } + + AOTree &mostProvingNode = selectMostProving(root); + developNode(mostProvingNode); + updateAncestors(mostProvingNode); + } + + if (root.pnum == 0) return AOTree::TRUE; + else if (root.dnum == 0) return AOTree::FALSE; + else return AOTree::UNKNOWN; +} + +int pnsearch(const Board &B) { + // printf("\x1B[1mWARNING: THIS IS GOING TO CONSUME A LOT OF MEMORY VERY QUICKLY. PLEASE WATCH IN htop AND BE READY TO ^C.\x1B[0m\n"); + + AOTree root(nullptr, B, B.numFilled() % 2 == 0 ? AOTree::OR : AOTree::AND); + AOTree::eval_t eval = proofNumberSearch(root); + + switch (eval) { + case AOTree::TRUE: return 0; + case AOTree::FALSE: return 1; + case AOTree::UNKNOWN: + printf("UNKNOWN from proofNumberSearch!\n"); + exit(0); + assert(false); + } + + assert(false); +} |