summaryrefslogtreecommitdiff
path: root/ai_mcts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ai_mcts.cpp')
-rw-r--r--ai_mcts.cpp215
1 files changed, 215 insertions, 0 deletions
diff --git a/ai_mcts.cpp b/ai_mcts.cpp
new file mode 100644
index 0000000..b8a50ae
--- /dev/null
+++ b/ai_mcts.cpp
@@ -0,0 +1,215 @@
+#include <fstream>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <climits>
+#include <cmath>
+#include <cassert>
+#include "ai_mcts.h"
+
+
+static int mcts_niterations = 5000;
+// static int mcts_niterations = 5000;
+static int mcts_newnode_playouts = 3;
+
+
+static int playout(Board &bd, int player) {
+ Move poss[N * N * N];
+
+ while (true) {
+ int nposs = 0;
+ int winidx = -1;
+ bd.forEachMove(player, [&bd, &poss, &nposs, &winidx, player](Move mv) {
+ Board bd2 = bd;
+ int win = bd2.applyCW(mv);
+ if (win * player >= 0) {
+ poss[nposs++] = mv;
+ if (win != 0) {
+ winidx = nposs - 1;
+ return true;
+ }
+ }
+ return false;
+ });
+
+ if (nposs == 0) return -player;
+
+ int index = winidx == -1 ? rand() % nposs : winidx;
+
+ int win = bd.applyCW(poss[index]);
+ if (win != 0) return win;
+
+ player = -player;
+ }
+}
+
+struct Node {
+ int nwin = 0, ntotal = 0; // nwin: number of wins, as regarded from the player on turn after the parent node
+ bool terminal = false; // someone has won
+ Move inedge;
+ Node *parent;
+
+ vector<Node> children;
+ bool allExpanded = false;
+};
+
+static float scoreFormula(const Node &node) {
+ return (float)node.nwin / node.ntotal + sqrtf(2.0f * logf(node.parent->ntotal) / node.ntotal);
+}
+
+static Node& nodeSelect(Node &from, Board *bd, int *onturn) {
+ if (!from.allExpanded) return from;
+ assert(from.children.size() != 0);
+
+ float maxscore = -1;
+ Node *choice = nullptr;
+ for (Node &ch : from.children) {
+ float score = scoreFormula(ch);
+ if (score > maxscore) {
+ maxscore = score;
+ choice = &ch;
+ }
+ }
+
+ bd->apply(choice->inedge);
+ if (choice->terminal) {
+ return *choice;
+ }
+ *onturn = -*onturn;
+
+ return nodeSelect(*choice, bd, onturn);
+}
+
+static Node& expand(Node &from, const Board &bd, int onturn) {
+ if (from.terminal) {
+ from.ntotal += mcts_newnode_playouts;
+ if (from.nwin > 0) {
+ from.nwin += from.ntotal;
+ }
+ return from;
+ }
+
+ assert(!from.allExpanded);
+
+ if (from.children.size() == 0) {
+ // cerr << " expand: initialising children" << endl;
+ bd.forEachMove(onturn, [&from, &bd](Move mv) {
+ from.children.emplace_back();
+ Node &ch = from.children.back();
+ ch.inedge = mv;
+ ch.parent = &from;
+ return false;
+ });
+ }
+
+ vector<int> poss;
+ poss.reserve(from.children.size());
+
+ for (int i = 0; i < (int)from.children.size(); i++) {
+ if (from.children[i].ntotal == 0) {
+ poss.push_back(i);
+ }
+ }
+
+ assert(poss.size() != 0);
+ int index = poss[rand() % poss.size()];
+
+ // cerr << " expand: poss.size()=" << poss.size() << " index=" << index << " f.c[i].nt=" << from.children[index].ntotal << endl;
+
+ if (poss.size() == 1) from.allExpanded = true;
+
+ Node &node = from.children[index];
+
+ Board bd2 = bd;
+ int win = bd2.applyCW(node.inedge);
+ if (win != 0) {
+ node.terminal = true;
+ node.nwin = mcts_newnode_playouts * (win == onturn);
+ node.ntotal = mcts_newnode_playouts;
+ return node;
+ }
+
+ for (int i = 0; i < mcts_newnode_playouts; i++) {
+ Board bd3 = bd2;
+ win = playout(bd3, onturn);
+ node.nwin = win == onturn;
+ }
+
+ node.ntotal = mcts_newnode_playouts;
+
+ return node;
+}
+
+static void backPropagate(Node &node, int propWins, int propTotal) {
+ node.nwin += propWins;
+ node.ntotal += propTotal;
+
+ if (node.parent == nullptr) return;
+ backPropagate(*node.parent, propTotal - propWins, propTotal);
+}
+
+static string incrementalFilename() {
+ static int i = 1;
+ return "tree_" + to_string(i++) + ".dot";
+}
+
+static void writeTreeNode(const Node &node, ostream &stream, int maxdepth) {
+ stream << "\"" << &node << "\" [label=\"" << node.nwin << "/" << node.ntotal << "\\n" << node.inedge << "\"];\n";
+ if (maxdepth <= 0) return;
+
+ vector<const Node*> nexts;
+ nexts.reserve(node.children.size());
+ for (const Node &ch : node.children) {
+ nexts.push_back(&ch);
+ }
+ sort(nexts.begin(), nexts.end(), [](const Node *a, const Node *b) {
+ return a->ntotal < b->ntotal;
+ });
+
+ for (const Node *ch : nexts) {
+ stream << "\"" << &node << "\" -> \"" << ch << "\";\n";
+ writeTreeNode(*ch, stream, maxdepth - 1);
+ }
+}
+
+static void writeTree(const Node &root, const string &filename, int maxdepth = 1) {
+ ofstream f(filename);
+ assert(f);
+ f << "digraph G {\n";
+ writeTreeNode(root, f, maxdepth);
+ f << "}\n";
+ f.close();
+
+ cerr << "Wrote tree to \"" << filename << "\"" << endl;
+}
+
+Move AI::MCTS::findMove(const Board &bd, int player) {
+ Node root;
+ root.inedge = Move(-1, -1);
+ root.parent = nullptr;
+
+ for (int iter = 0; iter < mcts_niterations; iter++) {
+ // cerr << "ITERATION " << iter << " root.ntotal = " << root.ntotal << endl;
+ Board bd2 = bd;
+ int onturn = player;
+
+ Node &node = nodeSelect(root, &bd2, &onturn);
+ // cerr << "Selected " << &node << endl;
+ Node &newnode = expand(node, bd2, onturn);
+ // cerr << "Expanded " << &newnode << endl;
+ backPropagate(node, newnode.ntotal - newnode.nwin, newnode.ntotal);
+ }
+
+ int maxtotal = -1;
+ Move maxat;
+ for (const Node &node : root.children) {
+ if (node.ntotal > maxtotal) {
+ maxtotal = node.ntotal;
+ maxat = node.inedge;
+ }
+ }
+
+ // writeTree(root, incrementalFilename());
+
+ return maxat;
+}