diff options
Diffstat (limited to 'eqsystem_solve.cpp')
-rw-r--r-- | eqsystem_solve.cpp | 330 |
1 files changed, 330 insertions, 0 deletions
diff --git a/eqsystem_solve.cpp b/eqsystem_solve.cpp new file mode 100644 index 0000000..878ddb6 --- /dev/null +++ b/eqsystem_solve.cpp @@ -0,0 +1,330 @@ +#include <iostream> +#include <sstream> +#include <unordered_set> +#include <stdexcept> +#include <algorithm> +#include <cassert> +#include "eqsystem_solve.h" + + +const Product Product::times(double x) const { + return Product{cnst * x, vars}; +} + +void Product::times_inplace(double x) { + cnst *= x; +} + +const Product Product::times(const Product &p) const { + std::vector<std::string> vars2{vars}; + vars2.insert(vars2.end(), p.vars.begin(), p.vars.end()); + return Product{cnst * p.cnst, vars2}; +} + +bool Product::contains(const std::string &var) const { + for (const std::string &v : vars) { + if (v == var) return true; + } + return false; +} + +const Product Product::without(const std::string &var) const { + Product result{cnst, {}}; + result.vars.reserve(vars.size()); + for (const std::string &v : vars) { + if (v != var) result.vars.push_back(v); + } + return result; +} + +bool Product::operator==(const Product &other) const { + if (cnst != other.cnst) return false; + if (vars.size() != other.vars.size()) return false; + std::unordered_set<std::string> seen(vars.begin(), vars.end()); + for (const std::string &v : other.vars) { + auto it = seen.find(v); + if (it == seen.end()) return false; + seen.erase(it); + } + return true; +} + +const Sum Sum::times(double x) const { + Sum result; + result.terms.reserve(terms.size()); + for (const Product &p : terms) result.terms.push_back(p.times(x)); + return result; +} + +const Sum Sum::substitute(const std::string &var, Sum sum) const { + Sum result; + + for (const Product &p : terms) { + if (p.contains(var)) { + Product rest = p.without(var); + for (const Product &term : sum.terms) { + result.terms.push_back(rest.times(term)); + } + } else { + result.terms.push_back(p); + } + } + + return result; +} + +void Sum::simplify() { + double tot = 0.0; + std::vector<Product> terms2; + + for (Product &p : terms) { + if (p.vars.empty()) { + tot += p.cnst; + } else { + sort(p.vars.begin(), p.vars.end()); + bool added = false; + for (Product &q : terms2) { + if (q.vars == p.vars) { + q.cnst += p.cnst; + added = true; + break; + } + } + if (!added) terms2.push_back(std::move(p)); + } + } + + if (tot != 0.0) terms2.push_back(Product{tot, {}}); + terms = terms2; +} + +std::optional<double> Sum::evaluate() const { + double tot = 0.0; + for (const Product &p : terms) { + if (!p.vars.empty()) return std::nullopt; + tot += p.cnst; + } + return tot; +} + +std::optional<std::string> Equation::validate() const { + std::unordered_set<std::string> seen; + + for (int i = 0; i < 2; i++) { + const Sum &sum = i == 0 ? lhs : rhs; + + for (const Product &p : sum.terms) { + for (const std::string &var : p.vars) { + auto it = seen.find(var); + if (it != seen.end()) { + std::stringstream ss; + ss << "In equation " << *this << ": variable " << var << " occurs multiple times"; + return ss.str(); + } + seen.insert(var); + } + } + } + + return std::nullopt; +} + +const Equation Equation::substitute(const std::string &var, Sum sum) const { + return Equation{lhs.substitute(var, sum), rhs.substitute(var, sum)}; +} + +void Equation::substitute_inplace(const std::string &var, Sum sum) { + *this = substitute(var, sum); +} + +int Equation::num_vars() const { + int count = 0; + for (const Product &p : lhs.terms) count += p.vars.size(); + for (const Product &p : rhs.terms) count += p.vars.size(); + return count; +} + +std::vector<std::string> Equation::all_vars() const { + std::vector<std::string> result; + for (const Product &p : lhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end()); + for (const Product &p : rhs.terms) result.insert(result.end(), p.vars.begin(), p.vars.end()); + return result; +} + +bool Equation::isolatable(const std::string &target) const { + int num_terms = 0; + + for (int i = 0; i < 2; i++) { + const Sum &side = i == 0 ? lhs : rhs; + for (const Product &p : side.terms) { + if (p.contains(target)) { + num_terms++; + if (num_terms > 1) return false; + if (p.vars.size() > 1) return false; + } + } + } + + return true; +} + +void Equation::isolate_left(const std::string &target) { + Sum newlhs, newrhs; + for (int i = 0; i < 2; i++) { + const Sum &side = i == 0 ? lhs : rhs; + for (const Product &p : side.terms) { + if (p.contains(target)) newlhs.terms.push_back(p.times(i == 0 ? 1 : -1)); + else newrhs.terms.push_back(p.times(i == 1 ? 1 : -1)); + } + } + + lhs = newlhs; + rhs = newrhs; + + if (lhs.terms.size() == 1) { + const double factor = lhs.terms.at(0).cnst; + if (factor != 0) { + for (Product &p : rhs.terms) p.times_inplace(1.0 / factor); + } + lhs.terms.at(0).cnst = 1.0; + } +} + +void Equation::simplify() { + lhs.simplify(); + rhs.simplify(); +} + +// xmax = xmin + cplxwidth +// ymax = ymin + cplxheight +// xmin + xmax = 2 * cx +// ymin + ymax = 2 * cy +// imgwidth * cplxheight = cplxwidth * imgheight + +std::variant< + std::string, // description of the error + std::vector<std::pair<std::string, double>> // result assignment +> System::solve_inplace(bool debug) { + for (const Equation &eq : eqs) { + if (auto err = eq.validate()) { + throw std::invalid_argument("Invalid equation passed to System::solve(): " + *err); + } + } + + std::vector<std::pair<std::string, double>> assignment; + + std::vector<bool> assigned(eqs.size()); + std::vector<bool> substituted(eqs.size()); + +success_try_again: + if (debug) { + std::cerr << std::endl << "\x1B[1mSolving:\x1B[0m" << std::endl; + std::cerr << *this << std::endl << std::endl; + } + + for (int stage = 1; stage <= 2; stage++) { + for (size_t i = 0; i < eqs.size(); i++) { + if (assigned[i] || (stage == 2 && substituted[i])) continue; + + Equation &eq = eqs[i]; + + if (stage >= 2 || eq.num_vars() == 1) { + const std::string var = eq.all_vars().at(0); + if (eq.isolatable(var)) { + eq.isolate_left(var); + eq.rhs.simplify(); + if (eq.lhs.terms.size() == 1 && + eq.lhs.terms.at(0) == Product{1.0, std::vector<std::string>{var}}) { + if (debug) std::cerr << "\x1B[1m[" << stage << "] Substituting " << var; + if (auto value = eq.rhs.evaluate()) { + if (debug) std::cerr << " = " << *value; + assignment.emplace_back( + eq.lhs.terms.at(0).vars.at(0), + *value + ); + assigned[i] = true; + } + if (debug) std::cerr << "\x1B[0m" << std::endl; + + for (size_t j = 0; j < eqs.size(); j++) { + if (j == i || assigned[j]) continue; + eqs[j].substitute_inplace(var, eq.rhs); + eqs[j].simplify(); + } + substituted[i] = true; + goto success_try_again; + } + } + } + } + } + + std::stringstream errmsg; + errmsg << "Equations could not be solved:\n"; + bool have_error = false; + + // If we arrive here, no more equations could be solved + for (size_t i = 0; i < eqs.size(); i++) { + if (assigned[i]) continue; + have_error = true; + errmsg << " " << eqs[i] << '\n'; + } + + if (have_error) return errmsg.str(); + else return assignment; +} + +void System::substitute_inplace(const std::string &var, Sum sum) { + for (Equation &eq : eqs) { + eq.substitute_inplace(var, sum); + } +} + + +std::ostream& operator<<(std::ostream &os, const Product &p) { + bool first = true; + if (p.cnst != 1 || p.vars.empty()) { + os << p.cnst; + first = false; + } + + for (const std::string &v : p.vars) { + if (first) first = false; + else os << "*"; + os << v; + } + + return os; +} + +std::ostream& operator<<(std::ostream &os, const Sum &s) { + if (s.terms.empty()) { + os << "0"; + return os; + } + + bool first = true; + for (const Product &p : s.terms) { + if (first) first = false; + else os << " + "; + os << p; + } + return os; +} + +std::ostream& operator<<(std::ostream &os, const Equation &eq) { + os << eq.lhs << " = " << eq.rhs; + return os; +} + +std::ostream& operator<<(std::ostream &os, const System &sys) { + bool first = true; + for (const Equation &eq : sys.eqs) { + if (first) first = false; + else os << '\n'; + os << eq; + } + return os; +} + +// vim: set sw=4 ts=4 noet: |