#include #include #include #include #include #include #include "eqsystem_solve.h" const Product Product::times(double x) const { return Product{cnst * x, vars}; } void Product::times_inplace(double x) { cnst *= x; } const Product Product::times(const Product &p) const { std::vector vars2{vars}; vars2.insert(vars2.end(), p.vars.begin(), p.vars.end()); return Product{cnst * p.cnst, vars2}; } bool Product::contains(const std::string &var) const { for (const std::string &v : vars) { if (v == var) return true; } return false; } const Product Product::without(const std::string &var) const { Product result{cnst, {}}; result.vars.reserve(vars.size()); for (const std::string &v : vars) { if (v != var) result.vars.push_back(v); } return result; } bool Product::operator==(const Product &other) const { if (cnst != other.cnst) return false; if (vars.size() != other.vars.size()) return false; std::unordered_set seen(vars.begin(), vars.end()); for (const std::string &v : other.vars) { auto it = seen.find(v); if (it == seen.end()) return false; seen.erase(it); } return true; } const Sum Sum::times(double x) const { Sum result; result.terms.reserve(terms.size()); for (const Product &p : terms) result.terms.push_back(p.times(x)); return result; } const Sum Sum::substitute(const std::string &var, Sum sum) const { Sum result; for (const Product &p : terms) { if (p.contains(var)) { Product rest = p.without(var); for (const Product &term : sum.terms) { result.terms.push_back(rest.times(term)); } } else { result.terms.push_back(p); } } return result; } void Sum::simplify() { double tot = 0.0; std::vector terms2; for (Product &p : terms) { if (p.vars.empty()) { tot += p.cnst; } else { sort(p.vars.begin(), p.vars.end()); bool added = false; for (Product &q : terms2) { if (q.vars == p.vars) { q.cnst += p.cnst; added = true; break; } } if (!added) terms2.push_back(std::move(p)); } } if (tot != 0.0) terms2.push_back(Product{tot, {}}); terms = terms2; } std::optional Sum::evaluate() const { double tot = 0.0; for (const Product &p : terms) { if (!p.vars.empty()) return std::nullopt; tot += p.cnst; } return tot; } std::optional Equation::validate() const { std::unordered_set seen; for (int i = 0; i < 2; i++) { const Sum &sum = i == 0 ? lhs : rhs; for (const Product &p : sum.terms) { for (const std::string &var : p.vars) { auto it = seen.find(var); if (it != seen.end()) { std::stringstream ss; ss << "In equation " << *this << ": variable " << var << " occurs multiple times"; return ss.str(); } seen.insert(var); } } } return std::nullopt; } const Equation Equation::substitute(const std::string &var, Sum sum) const { return Equation{lhs.substitute(var, sum), rhs.substitute(var, sum)}; } void Equation::substitute_inplace(const std::string &var, Sum sum) { *this = substitute(var, sum); } int Equation::num_vars() const { int count = 0; for (const Product &p : lhs.terms) count += p.vars.size(); for (const Product &p : rhs.terms) count += p.vars.size(); return count; } std::vector Equation::all_vars() const { std::vector result; for (const Product &p : lhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end()); for (const Product &p : rhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end()); return result; } bool Equation::isolatable(const std::string &target) const { int num_terms = 0; for (int i = 0; i < 2; i++) { const Sum &side = i == 0 ? lhs : rhs; for (const Product &p : side.terms) { if (p.contains(target)) { num_terms++; if (num_terms > 1) return false; if (p.vars.size() > 1) return false; } } } return true; } void Equation::isolate_left(const std::string &target) { Sum newlhs, newrhs; for (int i = 0; i < 2; i++) { const Sum &side = i == 0 ? lhs : rhs; for (const Product &p : side.terms) { if (p.contains(target)) newlhs.terms.push_back(p.times(i == 0 ? 1 : -1)); else newrhs.terms.push_back(p.times(i == 1 ? 1 : -1)); } } lhs = newlhs; rhs = newrhs; if (lhs.terms.size() == 1) { const double factor = lhs.terms.at(0).cnst; if (factor != 0) { for (Product &p : rhs.terms) p.times_inplace(1.0 / factor); } lhs.terms.at(0).cnst = 1.0; } } void Equation::simplify() { lhs.simplify(); rhs.simplify(); } // xmax = xmin + cplxwidth // ymax = ymin + cplxheight // xmin + xmax = 2 * cx // ymin + ymax = 2 * cy // imgwidth * cplxheight = cplxwidth * imgheight std::variant< std::string, // description of the error std::vector> // result assignment > System::solve_inplace(bool debug) { for (const Equation &eq : eqs) { if (auto err = eq.validate()) { throw std::invalid_argument("Invalid equation passed to System::solve(): " + *err); } } std::vector> assignment; std::vector assigned(eqs.size()); std::vector substituted(eqs.size()); success_try_again: if (debug) { std::cerr << std::endl << "\x1B[1mSolving:\x1B[0m" << std::endl; std::cerr << *this << std::endl << std::endl; } for (int stage = 1; stage <= 2; stage++) { for (size_t i = 0; i < eqs.size(); i++) { if (assigned[i] || (stage == 2 && substituted[i])) continue; Equation &eq = eqs[i]; if (stage >= 2 || eq.num_vars() == 1) { const std::string var = eq.all_vars().at(0); if (eq.isolatable(var)) { eq.isolate_left(var); eq.rhs.simplify(); if (eq.lhs.terms.size() == 1 && eq.lhs.terms.at(0) == Product{1.0, std::vector{var}}) { if (debug) std::cerr << "\x1B[1m[" << stage << "] Substituting " << var; if (auto value = eq.rhs.evaluate()) { if (debug) std::cerr << " = " << *value; assignment.emplace_back( eq.lhs.terms.at(0).vars.at(0), *value ); assigned[i] = true; } if (debug) std::cerr << "\x1B[0m" << std::endl; for (size_t j = 0; j < eqs.size(); j++) { if (j == i || assigned[j]) continue; eqs[j].substitute_inplace(var, eq.rhs); eqs[j].simplify(); } substituted[i] = true; goto success_try_again; } } } } } std::stringstream errmsg; errmsg << "Equations could not be solved:\n"; bool have_error = false; // If we arrive here, no more equations could be solved for (size_t i = 0; i < eqs.size(); i++) { if (assigned[i]) continue; have_error = true; errmsg << " " << eqs[i] << '\n'; } if (have_error) return errmsg.str(); else return assignment; } void System::substitute_inplace(const std::string &var, Sum sum) { for (Equation &eq : eqs) { eq.substitute_inplace(var, sum); } } std::ostream& operator<<(std::ostream &os, const Product &p) { bool first = true; if (p.cnst != 1 || p.vars.empty()) { os << p.cnst; first = false; } for (const std::string &v : p.vars) { if (first) first = false; else os << "*"; os << v; } return os; } std::ostream& operator<<(std::ostream &os, const Sum &s) { if (s.terms.empty()) { os << "0"; return os; } bool first = true; for (const Product &p : s.terms) { if (first) first = false; else os << " + "; os << p; } return os; } std::ostream& operator<<(std::ostream &os, const Equation &eq) { os << eq.lhs << " = " << eq.rhs; return os; } std::ostream& operator<<(std::ostream &os, const System &sys) { bool first = true; for (const Equation &eq : sys.eqs) { if (first) first = false; else os << '\n'; os << eq; } return os; } // vim: set sw=4 ts=4 noet: