diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | ai_mcts.cpp | 215 | ||||
-rw-r--r-- | ai_mcts.h | 8 | ||||
-rw-r--r-- | main.cpp | 5 |
4 files changed, 227 insertions, 3 deletions
@@ -1,5 +1,5 @@ CXX = g++ -CXXFLAGS = -Wall -Wextra -O3 -std=c++17 -fwrapv -flto +CXXFLAGS = -Wall -Wextra -O3 -g -std=c++17 -fwrapv -flto TARGET = main diff --git a/ai_mcts.cpp b/ai_mcts.cpp new file mode 100644 index 0000000..b8a50ae --- /dev/null +++ b/ai_mcts.cpp @@ -0,0 +1,215 @@ +#include <fstream> +#include <string> +#include <vector> +#include <algorithm> +#include <climits> +#include <cmath> +#include <cassert> +#include "ai_mcts.h" + + +static int mcts_niterations = 5000; +// static int mcts_niterations = 5000; +static int mcts_newnode_playouts = 3; + + +static int playout(Board &bd, int player) { + Move poss[N * N * N]; + + while (true) { + int nposs = 0; + int winidx = -1; + bd.forEachMove(player, [&bd, &poss, &nposs, &winidx, player](Move mv) { + Board bd2 = bd; + int win = bd2.applyCW(mv); + if (win * player >= 0) { + poss[nposs++] = mv; + if (win != 0) { + winidx = nposs - 1; + return true; + } + } + return false; + }); + + if (nposs == 0) return -player; + + int index = winidx == -1 ? rand() % nposs : winidx; + + int win = bd.applyCW(poss[index]); + if (win != 0) return win; + + player = -player; + } +} + +struct Node { + int nwin = 0, ntotal = 0; // nwin: number of wins, as regarded from the player on turn after the parent node + bool terminal = false; // someone has won + Move inedge; + Node *parent; + + vector<Node> children; + bool allExpanded = false; +}; + +static float scoreFormula(const Node &node) { + return (float)node.nwin / node.ntotal + sqrtf(2.0f * logf(node.parent->ntotal) / node.ntotal); +} + +static Node& nodeSelect(Node &from, Board *bd, int *onturn) { + if (!from.allExpanded) return from; + assert(from.children.size() != 0); + + float maxscore = -1; + Node *choice = nullptr; + for (Node &ch : from.children) { + float score = scoreFormula(ch); + if (score > maxscore) { + maxscore = score; + choice = &ch; + } + } + + bd->apply(choice->inedge); + if (choice->terminal) { + return *choice; + } + *onturn = -*onturn; + + return nodeSelect(*choice, bd, onturn); +} + +static Node& expand(Node &from, const Board &bd, int onturn) { + if (from.terminal) { + from.ntotal += mcts_newnode_playouts; + if (from.nwin > 0) { + from.nwin += from.ntotal; + } + return from; + } + + assert(!from.allExpanded); + + if (from.children.size() == 0) { + // cerr << " expand: initialising children" << endl; + bd.forEachMove(onturn, [&from, &bd](Move mv) { + from.children.emplace_back(); + Node &ch = from.children.back(); + ch.inedge = mv; + ch.parent = &from; + return false; + }); + } + + vector<int> poss; + poss.reserve(from.children.size()); + + for (int i = 0; i < (int)from.children.size(); i++) { + if (from.children[i].ntotal == 0) { + poss.push_back(i); + } + } + + assert(poss.size() != 0); + int index = poss[rand() % poss.size()]; + + // cerr << " expand: poss.size()=" << poss.size() << " index=" << index << " f.c[i].nt=" << from.children[index].ntotal << endl; + + if (poss.size() == 1) from.allExpanded = true; + + Node &node = from.children[index]; + + Board bd2 = bd; + int win = bd2.applyCW(node.inedge); + if (win != 0) { + node.terminal = true; + node.nwin = mcts_newnode_playouts * (win == onturn); + node.ntotal = mcts_newnode_playouts; + return node; + } + + for (int i = 0; i < mcts_newnode_playouts; i++) { + Board bd3 = bd2; + win = playout(bd3, onturn); + node.nwin = win == onturn; + } + + node.ntotal = mcts_newnode_playouts; + + return node; +} + +static void backPropagate(Node &node, int propWins, int propTotal) { + node.nwin += propWins; + node.ntotal += propTotal; + + if (node.parent == nullptr) return; + backPropagate(*node.parent, propTotal - propWins, propTotal); +} + +static string incrementalFilename() { + static int i = 1; + return "tree_" + to_string(i++) + ".dot"; +} + +static void writeTreeNode(const Node &node, ostream &stream, int maxdepth) { + stream << "\"" << &node << "\" [label=\"" << node.nwin << "/" << node.ntotal << "\\n" << node.inedge << "\"];\n"; + if (maxdepth <= 0) return; + + vector<const Node*> nexts; + nexts.reserve(node.children.size()); + for (const Node &ch : node.children) { + nexts.push_back(&ch); + } + sort(nexts.begin(), nexts.end(), [](const Node *a, const Node *b) { + return a->ntotal < b->ntotal; + }); + + for (const Node *ch : nexts) { + stream << "\"" << &node << "\" -> \"" << ch << "\";\n"; + writeTreeNode(*ch, stream, maxdepth - 1); + } +} + +static void writeTree(const Node &root, const string &filename, int maxdepth = 1) { + ofstream f(filename); + assert(f); + f << "digraph G {\n"; + writeTreeNode(root, f, maxdepth); + f << "}\n"; + f.close(); + + cerr << "Wrote tree to \"" << filename << "\"" << endl; +} + +Move AI::MCTS::findMove(const Board &bd, int player) { + Node root; + root.inedge = Move(-1, -1); + root.parent = nullptr; + + for (int iter = 0; iter < mcts_niterations; iter++) { + // cerr << "ITERATION " << iter << " root.ntotal = " << root.ntotal << endl; + Board bd2 = bd; + int onturn = player; + + Node &node = nodeSelect(root, &bd2, &onturn); + // cerr << "Selected " << &node << endl; + Node &newnode = expand(node, bd2, onturn); + // cerr << "Expanded " << &newnode << endl; + backPropagate(node, newnode.ntotal - newnode.nwin, newnode.ntotal); + } + + int maxtotal = -1; + Move maxat; + for (const Node &node : root.children) { + if (node.ntotal > maxtotal) { + maxtotal = node.ntotal; + maxat = node.inedge; + } + } + + // writeTree(root, incrementalFilename()); + + return maxat; +} diff --git a/ai_mcts.h b/ai_mcts.h new file mode 100644 index 0000000..1ef5d7f --- /dev/null +++ b/ai_mcts.h @@ -0,0 +1,8 @@ +#pragma once + +#include "board.h" + + +namespace AI::MCTS { + Move findMove(const Board &bd, int player); +} @@ -4,12 +4,13 @@ #include "board.h" #include "ai_mc.h" #include "ai_mm.h" +#include "ai_mcts.h" #include "ai_rand.h" using namespace std; #ifndef AI_CHOICE -#define AI_CHOICE MC +#define AI_CHOICE MCTS #endif #define STR_(x) #x @@ -51,7 +52,7 @@ int main() { onturn = -onturn; } - int win; + int win = 0; string line; while (true) { |