#include #include #include #include #include #include #include #include "ai_mcts.h" static int mcts_niterations = 5000; // static int mcts_niterations = 5000; static int mcts_newnode_playouts = 3; static int playout(Board &bd, int player) { Move poss[N * N * N]; while (true) { int nposs = 0; int winidx = -1; bd.forEachMove(player, [&bd, &poss, &nposs, &winidx, player](Move mv) { Board bd2 = bd; int win = bd2.applyCW(mv); if (win * player >= 0) { poss[nposs++] = mv; if (win != 0) { winidx = nposs - 1; return true; } } return false; }); if (nposs == 0) return -player; int index = winidx == -1 ? rand() % nposs : winidx; int win = bd.applyCW(poss[index]); if (win != 0) return win; player = -player; } } struct Node { int nwin = 0, ntotal = 0; // nwin: number of wins, as regarded from the player on turn after the parent node bool terminal = false; // someone has won Move inedge; Node *parent; vector children; bool allExpanded = false; }; static float scoreFormula(const Node &node) { return (float)node.nwin / node.ntotal + sqrtf(2.0f * logf(node.parent->ntotal) / node.ntotal); } static Node& nodeSelect(Node &from, Board *bd, int *onturn) { if (!from.allExpanded) return from; assert(from.children.size() != 0); float maxscore = -1; Node *choice = nullptr; for (Node &ch : from.children) { float score = scoreFormula(ch); if (score > maxscore) { maxscore = score; choice = &ch; } } bd->apply(choice->inedge); if (choice->terminal) { return *choice; } *onturn = -*onturn; return nodeSelect(*choice, bd, onturn); } static Node& expand(Node &from, const Board &bd, int onturn) { if (from.terminal) { from.ntotal += mcts_newnode_playouts; if (from.nwin > 0) { from.nwin += from.ntotal; } return from; } assert(!from.allExpanded); if (from.children.size() == 0) { // cerr << " expand: initialising children" << endl; bd.forEachMove(onturn, [&from, &bd](Move mv) { from.children.emplace_back(); Node &ch = from.children.back(); ch.inedge = mv; ch.parent = &from; return false; }); } vector poss; poss.reserve(from.children.size()); for (int i = 0; i < (int)from.children.size(); i++) { if (from.children[i].ntotal == 0) { poss.push_back(i); } } assert(poss.size() != 0); int index = poss[rand() % poss.size()]; // cerr << " expand: poss.size()=" << poss.size() << " index=" << index << " f.c[i].nt=" << from.children[index].ntotal << endl; if (poss.size() == 1) from.allExpanded = true; Node &node = from.children[index]; Board bd2 = bd; int win = bd2.applyCW(node.inedge); if (win != 0) { node.terminal = true; node.nwin = mcts_newnode_playouts * (win == onturn); node.ntotal = mcts_newnode_playouts; return node; } for (int i = 0; i < mcts_newnode_playouts; i++) { Board bd3 = bd2; win = playout(bd3, onturn); node.nwin = win == onturn; } node.ntotal = mcts_newnode_playouts; return node; } static void backPropagate(Node &node, int propWins, int propTotal) { node.nwin += propWins; node.ntotal += propTotal; if (node.parent == nullptr) return; backPropagate(*node.parent, propTotal - propWins, propTotal); } static string incrementalFilename() { static int i = 1; return "tree_" + to_string(i++) + ".dot"; } static void writeTreeNode(const Node &node, ostream &stream, int maxdepth) { stream << "\"" << &node << "\" [label=\"" << node.nwin << "/" << node.ntotal << "\\n" << node.inedge << "\"];\n"; if (maxdepth <= 0) return; vector nexts; nexts.reserve(node.children.size()); for (const Node &ch : node.children) { nexts.push_back(&ch); } sort(nexts.begin(), nexts.end(), [](const Node *a, const Node *b) { return a->ntotal < b->ntotal; }); for (const Node *ch : nexts) { stream << "\"" << &node << "\" -> \"" << ch << "\";\n"; writeTreeNode(*ch, stream, maxdepth - 1); } } static void writeTree(const Node &root, const string &filename, int maxdepth = 1) { ofstream f(filename); assert(f); f << "digraph G {\n"; writeTreeNode(root, f, maxdepth); f << "}\n"; f.close(); cerr << "Wrote tree to \"" << filename << "\"" << endl; } Move AI::MCTS::findMove(const Board &bd, int player) { Node root; root.inedge = Move(-1, -1); root.parent = nullptr; for (int iter = 0; iter < mcts_niterations; iter++) { // cerr << "ITERATION " << iter << " root.ntotal = " << root.ntotal << endl; Board bd2 = bd; int onturn = player; Node &node = nodeSelect(root, &bd2, &onturn); // cerr << "Selected " << &node << endl; Node &newnode = expand(node, bd2, onturn); // cerr << "Expanded " << &newnode << endl; backPropagate(node, newnode.ntotal - newnode.nwin, newnode.ntotal); } int maxtotal = -1; Move maxat; for (const Node &node : root.children) { if (node.ntotal > maxtotal) { maxtotal = node.ntotal; maxat = node.inedge; } } // writeTree(root, incrementalFilename()); return maxat; }