summaryrefslogtreecommitdiff
path: root/eqsystem_solve.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eqsystem_solve.cpp')
-rw-r--r--eqsystem_solve.cpp330
1 files changed, 330 insertions, 0 deletions
diff --git a/eqsystem_solve.cpp b/eqsystem_solve.cpp
new file mode 100644
index 0000000..878ddb6
--- /dev/null
+++ b/eqsystem_solve.cpp
@@ -0,0 +1,330 @@
+#include <iostream>
+#include <sstream>
+#include <unordered_set>
+#include <stdexcept>
+#include <algorithm>
+#include <cassert>
+#include "eqsystem_solve.h"
+
+
+const Product Product::times(double x) const {
+ return Product{cnst * x, vars};
+}
+
+void Product::times_inplace(double x) {
+ cnst *= x;
+}
+
+const Product Product::times(const Product &p) const {
+ std::vector<std::string> vars2{vars};
+ vars2.insert(vars2.end(), p.vars.begin(), p.vars.end());
+ return Product{cnst * p.cnst, vars2};
+}
+
+bool Product::contains(const std::string &var) const {
+ for (const std::string &v : vars) {
+ if (v == var) return true;
+ }
+ return false;
+}
+
+const Product Product::without(const std::string &var) const {
+ Product result{cnst, {}};
+ result.vars.reserve(vars.size());
+ for (const std::string &v : vars) {
+ if (v != var) result.vars.push_back(v);
+ }
+ return result;
+}
+
+bool Product::operator==(const Product &other) const {
+ if (cnst != other.cnst) return false;
+ if (vars.size() != other.vars.size()) return false;
+ std::unordered_set<std::string> seen(vars.begin(), vars.end());
+ for (const std::string &v : other.vars) {
+ auto it = seen.find(v);
+ if (it == seen.end()) return false;
+ seen.erase(it);
+ }
+ return true;
+}
+
+const Sum Sum::times(double x) const {
+ Sum result;
+ result.terms.reserve(terms.size());
+ for (const Product &p : terms) result.terms.push_back(p.times(x));
+ return result;
+}
+
+const Sum Sum::substitute(const std::string &var, Sum sum) const {
+ Sum result;
+
+ for (const Product &p : terms) {
+ if (p.contains(var)) {
+ Product rest = p.without(var);
+ for (const Product &term : sum.terms) {
+ result.terms.push_back(rest.times(term));
+ }
+ } else {
+ result.terms.push_back(p);
+ }
+ }
+
+ return result;
+}
+
+void Sum::simplify() {
+ double tot = 0.0;
+ std::vector<Product> terms2;
+
+ for (Product &p : terms) {
+ if (p.vars.empty()) {
+ tot += p.cnst;
+ } else {
+ sort(p.vars.begin(), p.vars.end());
+ bool added = false;
+ for (Product &q : terms2) {
+ if (q.vars == p.vars) {
+ q.cnst += p.cnst;
+ added = true;
+ break;
+ }
+ }
+ if (!added) terms2.push_back(std::move(p));
+ }
+ }
+
+ if (tot != 0.0) terms2.push_back(Product{tot, {}});
+ terms = terms2;
+}
+
+std::optional<double> Sum::evaluate() const {
+ double tot = 0.0;
+ for (const Product &p : terms) {
+ if (!p.vars.empty()) return std::nullopt;
+ tot += p.cnst;
+ }
+ return tot;
+}
+
+std::optional<std::string> Equation::validate() const {
+ std::unordered_set<std::string> seen;
+
+ for (int i = 0; i < 2; i++) {
+ const Sum &sum = i == 0 ? lhs : rhs;
+
+ for (const Product &p : sum.terms) {
+ for (const std::string &var : p.vars) {
+ auto it = seen.find(var);
+ if (it != seen.end()) {
+ std::stringstream ss;
+ ss << "In equation " << *this << ": variable " << var << " occurs multiple times";
+ return ss.str();
+ }
+ seen.insert(var);
+ }
+ }
+ }
+
+ return std::nullopt;
+}
+
+const Equation Equation::substitute(const std::string &var, Sum sum) const {
+ return Equation{lhs.substitute(var, sum), rhs.substitute(var, sum)};
+}
+
+void Equation::substitute_inplace(const std::string &var, Sum sum) {
+ *this = substitute(var, sum);
+}
+
+int Equation::num_vars() const {
+ int count = 0;
+ for (const Product &p : lhs.terms) count += p.vars.size();
+ for (const Product &p : rhs.terms) count += p.vars.size();
+ return count;
+}
+
+std::vector<std::string> Equation::all_vars() const {
+ std::vector<std::string> result;
+ for (const Product &p : lhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end());
+ for (const Product &p : rhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end());
+ return result;
+}
+
+bool Equation::isolatable(const std::string &target) const {
+ int num_terms = 0;
+
+ for (int i = 0; i < 2; i++) {
+ const Sum &side = i == 0 ? lhs : rhs;
+ for (const Product &p : side.terms) {
+ if (p.contains(target)) {
+ num_terms++;
+ if (num_terms > 1) return false;
+ if (p.vars.size() > 1) return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+void Equation::isolate_left(const std::string &target) {
+ Sum newlhs, newrhs;
+ for (int i = 0; i < 2; i++) {
+ const Sum &side = i == 0 ? lhs : rhs;
+ for (const Product &p : side.terms) {
+ if (p.contains(target)) newlhs.terms.push_back(p.times(i == 0 ? 1 : -1));
+ else newrhs.terms.push_back(p.times(i == 1 ? 1 : -1));
+ }
+ }
+
+ lhs = newlhs;
+ rhs = newrhs;
+
+ if (lhs.terms.size() == 1) {
+ const double factor = lhs.terms.at(0).cnst;
+ if (factor != 0) {
+ for (Product &p : rhs.terms) p.times_inplace(1.0 / factor);
+ }
+ lhs.terms.at(0).cnst = 1.0;
+ }
+}
+
+void Equation::simplify() {
+ lhs.simplify();
+ rhs.simplify();
+}
+
+// xmax = xmin + cplxwidth
+// ymax = ymin + cplxheight
+// xmin + xmax = 2 * cx
+// ymin + ymax = 2 * cy
+// imgwidth * cplxheight = cplxwidth * imgheight
+
+std::variant<
+ std::string, // description of the error
+ std::vector<std::pair<std::string, double>> // result assignment
+> System::solve_inplace(bool debug) {
+ for (const Equation &eq : eqs) {
+ if (auto err = eq.validate()) {
+ throw std::invalid_argument("Invalid equation passed to System::solve(): " + *err);
+ }
+ }
+
+ std::vector<std::pair<std::string, double>> assignment;
+
+ std::vector<bool> assigned(eqs.size());
+ std::vector<bool> substituted(eqs.size());
+
+success_try_again:
+ if (debug) {
+ std::cerr << std::endl << "\x1B[1mSolving:\x1B[0m" << std::endl;
+ std::cerr << *this << std::endl << std::endl;
+ }
+
+ for (int stage = 1; stage <= 2; stage++) {
+ for (size_t i = 0; i < eqs.size(); i++) {
+ if (assigned[i] || (stage == 2 && substituted[i])) continue;
+
+ Equation &eq = eqs[i];
+
+ if (stage >= 2 || eq.num_vars() == 1) {
+ const std::string var = eq.all_vars().at(0);
+ if (eq.isolatable(var)) {
+ eq.isolate_left(var);
+ eq.rhs.simplify();
+ if (eq.lhs.terms.size() == 1 &&
+ eq.lhs.terms.at(0) == Product{1.0, std::vector<std::string>{var}}) {
+ if (debug) std::cerr << "\x1B[1m[" << stage << "] Substituting " << var;
+ if (auto value = eq.rhs.evaluate()) {
+ if (debug) std::cerr << " = " << *value;
+ assignment.emplace_back(
+ eq.lhs.terms.at(0).vars.at(0),
+ *value
+ );
+ assigned[i] = true;
+ }
+ if (debug) std::cerr << "\x1B[0m" << std::endl;
+
+ for (size_t j = 0; j < eqs.size(); j++) {
+ if (j == i || assigned[j]) continue;
+ eqs[j].substitute_inplace(var, eq.rhs);
+ eqs[j].simplify();
+ }
+ substituted[i] = true;
+ goto success_try_again;
+ }
+ }
+ }
+ }
+ }
+
+ std::stringstream errmsg;
+ errmsg << "Equations could not be solved:\n";
+ bool have_error = false;
+
+ // If we arrive here, no more equations could be solved
+ for (size_t i = 0; i < eqs.size(); i++) {
+ if (assigned[i]) continue;
+ have_error = true;
+ errmsg << " " << eqs[i] << '\n';
+ }
+
+ if (have_error) return errmsg.str();
+ else return assignment;
+}
+
+void System::substitute_inplace(const std::string &var, Sum sum) {
+ for (Equation &eq : eqs) {
+ eq.substitute_inplace(var, sum);
+ }
+}
+
+
+std::ostream& operator<<(std::ostream &os, const Product &p) {
+ bool first = true;
+ if (p.cnst != 1 || p.vars.empty()) {
+ os << p.cnst;
+ first = false;
+ }
+
+ for (const std::string &v : p.vars) {
+ if (first) first = false;
+ else os << "*";
+ os << v;
+ }
+
+ return os;
+}
+
+std::ostream& operator<<(std::ostream &os, const Sum &s) {
+ if (s.terms.empty()) {
+ os << "0";
+ return os;
+ }
+
+ bool first = true;
+ for (const Product &p : s.terms) {
+ if (first) first = false;
+ else os << " + ";
+ os << p;
+ }
+ return os;
+}
+
+std::ostream& operator<<(std::ostream &os, const Equation &eq) {
+ os << eq.lhs << " = " << eq.rhs;
+ return os;
+}
+
+std::ostream& operator<<(std::ostream &os, const System &sys) {
+ bool first = true;
+ for (const Equation &eq : sys.eqs) {
+ if (first) first = false;
+ else os << '\n';
+ os << eq;
+ }
+ return os;
+}
+
+// vim: set sw=4 ts=4 noet: