summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom.smeding@gmail.com>2019-02-13 21:08:37 +0100
committerTom Smeding <tom.smeding@gmail.com>2019-02-13 21:08:37 +0100
commit6175cb1c53772cc92d91a39a254c38bdf8f64905 (patch)
tree3a7dd5f8810d0c859faaf5ab985afc51841899a8
Initial: working monte carlo AI
-rw-r--r--.gitignore2
-rw-r--r--Makefile35
-rw-r--r--board.cpp216
-rw-r--r--board.h85
-rw-r--r--main.cpp39
-rw-r--r--mc.cpp79
-rw-r--r--mc.h11
-rw-r--r--params.h17
-rw-r--r--util.h20
9 files changed, 504 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4f457ba
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+ai
+.objs/
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..1a1e21d
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,35 @@
+CXX := g++
+CXXFLAGS := -Wall -Wextra -std=c++11 -g -O2
+LDFLAGS :=
+
+TARGET := ai
+OBJDIR := .objs
+
+CXX_SOURCES := $(wildcard *.cpp)
+OBJ_FILES := $(patsubst %.cpp,$(OBJDIR)/%.o,$(CXX_SOURCES))
+DEP_FILES := $(patsubst %.cpp,$(OBJDIR)/%.d,$(CXX_SOURCES))
+
+
+.PHONY: all clean
+
+all: $(TARGET)
+
+clean:
+ rm -f $(TARGET)
+ rm -rf $(OBJDIR)
+
+$(TARGET): $(OBJ_FILES)
+ @echo LD $@
+ @$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
+
+$(OBJDIR)/%.o: %.cpp
+ @mkdir -p $(dir $@)
+ @echo CXX $<
+ @$(CXX) $(CXXFLAGS) -c -o $@ $<
+
+$(OBJDIR)/%.d: %.cpp
+ @mkdir -p $(dir $@)
+ @echo DEP $<
+ @$(CXX) -MT $(OBJDIR)/$*.o -MM $(CXXFLAGS) $< >$@
+
+-include $(DEP_FILES)
diff --git a/board.cpp b/board.cpp
new file mode 100644
index 0000000..c84e72f
--- /dev/null
+++ b/board.cpp
@@ -0,0 +1,216 @@
+#include <cstdlib>
+#include <cstring>
+#include <cassert>
+#include "board.h"
+#include "util.h"
+
+using namespace std;
+
+
+Bag Bag::uninitialised() {
+ return Bag();
+}
+
+Bag Bag::makeFull() {
+ Bag bag;
+ for (int i = 0; i < NC; i++) bag.num[i] = N;
+ bag.sum = N * NC;
+ return bag;
+}
+
+int Bag::numLeft(uint8_t clr) const {
+ return num[clr - 1];
+}
+
+int Bag::totalLeft() const {
+ return sum;
+}
+
+uint8_t Bag::peekRandom() const {
+ assert(sum != 0);
+
+ int i = random() % sum;
+ int accum = 0;
+ for (int clr = 1; clr <= NC; clr++) {
+ accum += num[clr - 1];
+ if (i < accum) {
+ return clr;
+ }
+ }
+
+ assert(false);
+}
+
+uint8_t Bag::drawRandom() {
+ uint8_t clr = peekRandom();
+ num[clr - 1]--;
+ sum--;
+ return clr;
+}
+
+void Bag::drawColour(uint8_t clr) {
+ assert(num[clr - 1] != 0);
+ num[clr - 1]--;
+ sum--;
+}
+
+void Bag::replace(uint8_t clr) {
+ num[clr - 1]++;
+ sum++;
+}
+
+
+Board Board::makeEmpty() {
+ static_assert(NC + 1 <= 256, "Too many colours");
+
+ Board bd;
+ bd.bag = Bag::makeFull();
+ memset(bd.bd, 0, BSZ * BSZ);
+ return bd;
+}
+
+// Do not call with clr == 0
+int Board::countStones(uint8_t clr, int idx, int delta) const {
+ // Since clr != 0 and the stones will never reach the edge of the board
+ // (it's too large for that), this loop will always terminate before
+ // accessing invalid memory.
+ int num = 0;
+ while (bd[idx] == clr) {
+ num++;
+ idx += delta;
+ }
+ return num;
+}
+
+void Board::newEdgeCand(int idx) {
+ // cerr << "(newEdgeCand(" << Idx(idx) << "))" << endl;
+ if (!inEdgeCands.test(idx)) {
+ edgeCands.push_back(idx);
+ inEdgeCands.set(idx);
+ }
+}
+
+void Board::put(int idx, uint8_t clr) {
+ bd[idx] = clr;
+
+ newEdgeCand(idx - 1);
+ newEdgeCand(idx + 1);
+ newEdgeCand(idx - BSZ);
+ newEdgeCand(idx + BSZ);
+}
+
+void Board::undo(int idx) {
+ bd[idx] = 0;
+
+ newEdgeCand(idx);
+}
+
+uint8_t Board::putCW(int idx, uint8_t clr) {
+ put(idx, clr);
+
+ int count[8];
+#define DO_COUNT_STONES(_i, _dx, _dy) do { int _delta = BSZ * (_dy) + (_dx); count[_i] = countStones(clr, idx + _delta, _delta); } while (0)
+ DO_COUNT_STONES(0, 0, -1);
+ DO_COUNT_STONES(1, 1, -1);
+ DO_COUNT_STONES(2, 1, 0);
+ DO_COUNT_STONES(3, 1, 1);
+ DO_COUNT_STONES(4, 0, 1);
+ DO_COUNT_STONES(5, -1, 1);
+ DO_COUNT_STONES(6, -1, 0);
+ DO_COUNT_STONES(7, -1, -1);
+#undef DO_COUNT_STONES
+
+ for (int i = 0; i < 4; i++) {
+ if (1 + count[i] + count[4 + i] >= RLEN) return clr;
+ }
+
+ return 0;
+}
+
+bool Board::checkEdge(int idx) const {
+ // Because there are always two spaces free at the ends of the board, we
+ // can safely inspect the neighbours of this cell. This is because we only
+ // call this function on cells that have either had or neighboured a stone.
+ return bd[idx - 1] || bd[idx + 1] || bd[idx - BSZ] || bd[idx + BSZ];
+}
+
+void Board::forEachMove(const function<void(int idx)> &f) {
+ vector<int> eC = edgeCands;
+ bitset<BSZ * BSZ> iEC = inEdgeCands;
+
+ size_t j = 0;
+ for (size_t i = 0; i < eC.size(); i++) {
+ int idx = eC[i];
+ // cerr << "(fEM: candidate " << Idx(idx) << "; ";
+ if (bd[idx] == 0 && checkEdge(idx)) {
+ // cerr << "edge)" << endl;
+ f(idx);
+ eC[j++] = idx;
+ } else {
+ // cerr << "-)" << endl;
+ iEC.reset(idx);
+ }
+ }
+
+ eC.resize(j);
+
+ swap(edgeCands, eC);
+ swap(inEdgeCands, iEC);
+}
+
+Bounds Board::computeBounds() const {
+ Bounds bounds;
+
+ for (int y = 1; y < BSZ - 1; y++) {
+ for (int x = 1; x < BSZ - 1; x++) {
+ int idx = BSZ * y + x;
+ if (bd[idx] != 0 || checkEdge(idx)) {
+ bounds.left = min(bounds.left, x);
+ bounds.right = max(bounds.right, x);
+ bounds.top = min(bounds.top, y);
+ bounds.bottom = max(bounds.bottom, y);
+ }
+ }
+ }
+
+ return bounds;
+}
+
+void Board::write(ostream &os) const {
+ static const char *edgeStr = "\x1B[36m+\x1B[0m";
+ static const char *openStr = "ยท";
+
+ Bounds bounds = computeBounds();
+
+ for (int y = bounds.top; y <= bounds.bottom; y++) {
+ for (int x = bounds.left; x <= bounds.right; x++) {
+ if (x != bounds.left) os << ' ';
+
+ int idx = BSZ * y + x;
+ if (bd[idx] != 0) os << Stone(bd[idx]);
+ else if (checkEdge(idx)) os << edgeStr;
+ else os << openStr;
+ }
+ os << endl;
+ }
+}
+
+ostream& operator<<(ostream &os, Stone stone) {
+ static const char *alphabet[] = {
+ "\x1B[1;31mR\x1B[0m",
+ "\x1B[1;34mB\x1B[0m",
+ "\x1B[1;33mY\x1B[0m",
+ };
+ static_assert(
+ NC <= sizeof(alphabet) / sizeof(alphabet[0]),
+ "Increase alphabet in Board::write");
+
+ uint8_t clr = stone.clr;
+ assert(1 <= clr && clr <= NC);
+ return os << alphabet[clr - 1];
+}
+
+ostream& operator<<(ostream &os, const Board &bd) {
+ bd.write(os);
+ return os;
+}
diff --git a/board.h b/board.h
new file mode 100644
index 0000000..a144ed7
--- /dev/null
+++ b/board.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include <iostream>
+#include <vector>
+#include <bitset>
+#include <functional>
+#include <cstdint>
+#include "params.h"
+
+using namespace std;
+
+
+class Bag {
+public:
+ static Bag uninitialised();
+ static Bag makeFull();
+
+ int numLeft(uint8_t clr) const;
+ int totalLeft() const;
+ uint8_t drawRandom();
+ uint8_t peekRandom() const; // picks random colour but doesn't draw
+ void drawColour(uint8_t clr);
+ void replace(uint8_t clr);
+
+private:
+ // Number of stones left of the given colours. Index with clr-1.
+ int num[NC];
+
+ // Sum of num[0 .. NC-1].
+ int sum;
+
+ Bag() = default;
+};
+
+struct Bounds {
+ int left = BSZ, right = -1, top = BSZ, bottom = -1;
+};
+
+class Board {
+public:
+ // Up to the user to keep this in sync with the board. (It is initialised
+ // in makeEmpty().)
+ Bag bag = Bag::uninitialised();
+
+ static Board makeEmpty();
+
+ inline const uint8_t& operator[](int idx) const { return bd[idx]; }
+
+ void put(int idx, uint8_t clr);
+ void undo(int idx);
+
+ uint8_t putCW(int idx, uint8_t clr);
+
+ // The callback may modify the board, but must leave it as it was after returning.
+ void forEachMove(const function<void(int idx)> &f);
+
+ void write(ostream &os) const;
+
+ Bounds computeBounds() const;
+
+private:
+ // 0 = empty, 1...NC = coloured stones
+ uint8_t bd[BSZ * BSZ];
+
+ // Candidates for edge cells; all cells that can take a stone must be
+ // elements of this list, but there may be more.
+ vector<int> edgeCands;
+
+ // Whether a particular cell is in edgeCands.
+ bitset<BSZ * BSZ> inEdgeCands;
+
+ Board() = default;
+
+ int countStones(uint8_t clr, int idx, int delta) const;
+ void newEdgeCand(int idx);
+ bool checkEdge(int idx) const;
+};
+
+struct Stone {
+ uint8_t clr;
+ inline Stone(uint8_t clr) : clr(clr) {}
+};
+
+ostream& operator<<(ostream &os, Stone stone);
+ostream& operator<<(ostream &os, const Board &bd);
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..80d0fb8
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,39 @@
+#include <iostream>
+#include <cstdlib>
+#include <sys/time.h>
+#include "board.h"
+#include "mc.h"
+#include "util.h"
+
+using namespace std;
+
+
+int main() {
+ struct timeval tv;
+ gettimeofday(&tv, nullptr);
+ srandom(tv.tv_sec * 1000000U + tv.tv_usec);
+
+ Board bd = Board::makeEmpty();
+ cerr << "Initial stone at " << Idx(BSZ * BMID + BMID) << endl;
+ bd.put(BSZ * BMID + BMID, 1);
+
+ cout << bd << endl;
+
+ uint8_t onturn = 2;
+ while (bd.bag.totalLeft() > 0) {
+ cout << "--- NEXT TURN: " << Stone(onturn) << " ---" << endl;
+
+ int idx = MC::calcMove(bd, onturn);
+ uint8_t clr = bd.bag.drawRandom();
+ uint8_t win = bd.putCW(idx, clr);
+
+ cout << bd << endl;
+
+ if (win != 0) {
+ cout << "Winner: " << Stone(clr) << endl;
+ break;
+ }
+
+ onturn = NEXTTURN(onturn);
+ }
+}
diff --git a/mc.cpp b/mc.cpp
new file mode 100644
index 0000000..5e64102
--- /dev/null
+++ b/mc.cpp
@@ -0,0 +1,79 @@
+#include <iostream>
+#include <cassert>
+#include <climits>
+#include "mc.h"
+#include "util.h"
+
+using namespace std;
+
+#define NPLAYOUTS 1000
+#define SCORE_WIN 1
+#define SCORE_LOSE (-1)
+#define SCORE_TIE 0
+
+
+// Takes copy of board, since it probably isn't worth it to undo the whole rest
+// of the game.
+// TODO: Multiply the playout score with its cumulative probability (which is
+// pretty small!) to get a probabilistically correct estimate of the expected
+// score.
+static int playout(Board bd, uint8_t myclr) {
+ // cerr << " PLAYOUT" << endl;
+ while (bd.bag.totalLeft() > 0) {
+ uint8_t clr = bd.bag.drawRandom();
+
+ int moves[BSZ * BSZ], nmoves = 0;
+ bd.forEachMove([&moves, &nmoves](int idx) { moves[nmoves++] = idx; });
+
+ assert(nmoves > 0);
+ int idx = moves[random() % nmoves];
+
+ // cerr << " idx = " << Idx(idx) << " clr=" << (unsigned)clr << endl;
+
+ uint8_t win = bd.putCW(idx, clr);
+ if (win != 0) {
+ return win == myclr ? SCORE_WIN : SCORE_LOSE;
+ }
+ }
+
+ return SCORE_TIE;
+}
+
+
+int MC::calcMove(Board &bd, uint8_t myclr) {
+ assert(bd.bag.totalLeft() > 0);
+
+ float maxscore = INT_MIN;
+ int maxat = -1;
+
+ bd.forEachMove([&bd, myclr, &maxscore, &maxat](int idx) {
+ // cerr << "MC::calcMove: trying idx=" << Idx(idx) << endl;
+ float score = 0;
+
+ for (int i = 0; i < NPLAYOUTS; i++) {
+ // cerr << "playout " << i << endl;
+
+ uint8_t clr = bd.bag.peekRandom();
+ float probability = (float)bd.bag.numLeft(clr) / bd.bag.totalLeft();
+ bd.bag.drawColour(clr);
+
+ // cerr << " random clr=" << (unsigned)clr << endl;
+ uint8_t win = bd.putCW(idx, clr);
+ if (win != 0) {
+ score += probability * (win == myclr ? SCORE_WIN : SCORE_LOSE);
+ } else {
+ score += probability * playout(bd, myclr);
+ }
+
+ bd.bag.replace(clr);
+ bd.undo(idx);
+ }
+
+ if (score > maxscore) {
+ maxscore = score;
+ maxat = idx;
+ }
+ });
+
+ return maxat;
+}
diff --git a/mc.h b/mc.h
new file mode 100644
index 0000000..d4221d9
--- /dev/null
+++ b/mc.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "board.h"
+
+using namespace std;
+
+
+namespace MC {
+ // bd will be unchanged upon return
+ int calcMove(Board &bd, uint8_t myclr);
+}
diff --git a/params.h b/params.h
new file mode 100644
index 0000000..c9413b6
--- /dev/null
+++ b/params.h
@@ -0,0 +1,17 @@
+#pragma once
+
+
+// number of stones of 1 colour
+#define N 26
+
+// number of colours
+#define NC 2
+
+// winning row length
+#define RLEN 5
+
+// board size; this leaves 4 / 2 = 2 spaces around the board at all times
+#define BSZ (2 * N * NC - 1 + 4)
+
+// board middle, coords of first stone
+#define BMID (BSZ / 2)
diff --git a/util.h b/util.h
new file mode 100644
index 0000000..e890555
--- /dev/null
+++ b/util.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <iostream>
+#include "params.h"
+
+using namespace std;
+
+
+#define NEXTTURN(_clr) ((_clr) % NC + 1)
+
+
+struct Idx {
+ inline Idx(int idx) : x(idx % BSZ), y(idx / BSZ) {}
+
+ int x, y;
+};
+
+inline ostream& operator<<(ostream &os, const Idx &obj) {
+ return os << '(' << obj.x << ',' << obj.y << ')';
+}