summaryrefslogtreecommitdiff
path: root/solve_bt.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'solve_bt.cpp')
-rw-r--r--solve_bt.cpp248
1 files changed, 248 insertions, 0 deletions
diff --git a/solve_bt.cpp b/solve_bt.cpp
new file mode 100644
index 0000000..418e286
--- /dev/null
+++ b/solve_bt.cpp
@@ -0,0 +1,248 @@
+#include <iostream>
+#include <iomanip>
+#include <string>
+#include <vector>
+#include <cstring>
+#include <cstdint>
+#include <cassert>
+
+using namespace std;
+
+
+struct Sudoku {
+ int8_t bd[81];
+ uint16_t poss[81];
+ int nfull = 0;
+
+ Sudoku() {
+ for (int i = 0; i < 81; i++) {
+ poss[i] = 0x1ff;
+ }
+ }
+
+ inline int8_t& operator[](int8_t i) {return bd[i];}
+ inline const int8_t& operator[](int8_t i) const {return bd[i];}
+};
+
+static istream& operator>>(istream &in, Sudoku &su) {
+ su.nfull = 0;
+
+ for (int i = 0; i < 81; i++) {
+ char c;
+ in >> c;
+ if (c == '.') {
+ su[i] = -1;
+ } else {
+ assert('1' <= c && c <= '9');
+ su[i] = c - '1';
+ su.nfull++;
+ }
+ }
+
+ return in;
+}
+
+static ostream& operator<<(ostream &os, const Sudoku &su) {
+ for (int y = 0; y < 9; y++) {
+ if (y != 0) {
+ os << endl;
+ if (y % 3 == 0) os << endl;
+ }
+
+ for (int x = 0; x < 9; x++) {
+ if (x != 0) {
+ os << ' ';
+ if (x % 3 == 0) os << ' ';
+ }
+ int8_t v = su[9 * y + x];
+ if (v == -1) os << '.';
+ else os << (int)v + 1;
+ // string s;
+ // for (int i = 0; i < 9; i++) {
+ // if (su.poss[9 * y + x] & (1<<i)) s += '1' + i;
+ // }
+ // os << '(' << setw(9) << s << ')';
+ }
+ }
+
+ return os;
+}
+
+static bool isValid(const Sudoku &su) {
+ for (int y = 0; y < 9; y++) {
+ int i = 9 * y;
+ uint8_t seen[9];
+ memset(seen, 0, 9);
+ for (int x = 0; x < 9; x++, i++) {
+ int v = su[i];
+ if (v != -1) {
+ if (seen[v]) return false;
+ else seen[v] = 1;
+ }
+ }
+ }
+
+ for (int x = 0; x < 9; x++) {
+ int i = x;
+ uint8_t seen[9];
+ memset(seen, 0, 9);
+ for (int y = 0; y < 9; y++, i += 9) {
+ int v = su[i];
+ if (v != -1) {
+ if (seen[v]) return false;
+ else seen[v] = 1;
+ }
+ }
+ }
+
+ int i = 0;
+ for (int by = 0; by < 9; by += 3) {
+ for (int bx = 0; bx < 9; bx += 3) {
+ uint8_t seen[9];
+ memset(seen, 0, 9);
+
+ int j = i;
+ for (int y = 0; y < 3; y++) {
+ for (int x = 0; x < 3; x++) {
+ int v = su[j];
+ if (v != -1) {
+ if (seen[v]) return false;
+ else seen[v] = 1;
+ }
+ j++;
+ }
+ j += 6;
+ }
+
+ i += 3;
+ }
+ i += 18;
+ }
+
+ return true;
+}
+
+vector<Sudoku> solutions;
+
+static void scratchAround(Sudoku &su, int idx, int x, int y) {
+ uint16_t mask = 0x1ff & ~(1 << su[idx]);
+
+ for (int i = 0; i < 9; i++) {
+ su.poss[9 * y + i] &= mask;
+ su.poss[9 * i + x] &= mask;
+ }
+
+ int bx = x / 3 * 3, by = y / 3 * 3;
+ for (int yi = 0; yi < 3; yi++) {
+ for (int xi = 0; xi < 3; xi++) {
+ su.poss[9 * (by + yi) + bx + xi] &= mask;
+ }
+ }
+}
+
+static void scratchAround(Sudoku &su, int idx) {
+ scratchAround(su, idx, idx % 9, idx / 9);
+}
+
+static void scratchAround(Sudoku &su, int x, int y) {
+ scratchAround(su, 9 * y + x, x, y);
+}
+
+static void solveDestructive(Sudoku &su) {
+ // cerr << "solveDestructive(" << su.nfull << ")" << endl;
+
+ for (int y = 0, i = 0; y < 9; y++) {
+ for (int x = 0; x < 9; x++, i++) {
+ if (su[i] != -1) continue;
+
+ if (__builtin_popcount(su.poss[i]) == 1) {
+ su[i] = __builtin_ctz(su.poss[i]);
+ su.nfull++;
+ scratchAround(su, i);
+ }
+ }
+ }
+
+ if (su.nfull == 81) {
+ if (isValid(su)) {
+ solutions.push_back(su);
+ }
+ return;
+ }
+
+ int mincount = 99, minat = -1;
+
+ for (int y = 0, i = 0; y < 9; y++) {
+ for (int x = 0; x < 9; x++, i++) {
+ if (su[i] != -1) continue;
+
+ int count = __builtin_popcount((uint32_t)su.poss[i]);
+ if (count < mincount) {
+ mincount = count;
+ minat = i;
+ }
+ }
+ }
+
+ if (mincount == 1) {
+ // do a tail call
+ solveDestructive(su);
+ return;
+ }
+
+ if (mincount == 0) {
+ // invalid
+ return;
+ }
+
+ // cerr << "minat = " << minat << " ";
+ // cerr << "mincount = " << mincount << endl;
+
+ int i = minat;
+ for (int v = 0; v < 9; v++) {
+ if ((su.poss[i] & (1 << v)) == 0) continue;
+
+ Sudoku su2 = su;
+ su2[i] = v;
+ su2.nfull++;
+ // cerr << "try " << v << endl;
+ if (!isValid(su2)) {
+ // cerr << " invalid" << endl;
+ continue;
+ }
+
+ scratchAround(su2, i);
+ solveDestructive(su2);
+ }
+}
+
+static void solve(const Sudoku &su) {
+ Sudoku su2 = su;
+ solveDestructive(su2);
+}
+
+
+static void cleanPoss(Sudoku &su) {
+ for (int y = 0; y < 9; y++) {
+ for (int x = 0; x < 9; x++) {
+ scratchAround(su, x, y);
+ }
+ }
+}
+
+int main() {
+ Sudoku su;
+ cin >> su;
+
+ solutions.clear();
+ cleanPoss(su);
+ solve(su);
+
+ if (solutions.size() != 1) {
+ cout << solutions.size() << " solutions found:" << endl;
+ }
+
+ for (const Sudoku &su2 : solutions) {
+ cout << su2 << endl << endl;
+ }
+}