summaryrefslogtreecommitdiff
path: root/problem.cpp
blob: 3eb7ecc601b980a0e2df030f31437a426152f0e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#include <iomanip>
#include <stdexcept>
#include <cassert>
#include "problem.h"

#define DEBUG

using namespace std;

//Builds a problem by parsing its textual description from `in`.
//All parsing (and any exception raised on malformed input) is delegated
//to the stream extraction operator for Problem.
Problem::Problem(istream &in){
	operator>>(in,*this);
}

//Row operation: dst += times * src, applied to the first n entries.
//dst and src only need to support indexing with operator[].
template <typename T,typename U>
void rowopadd(int n,T &dst,const U &src,double times){
	int col=0;
	while(col<n){
		dst[col]+=src[col]*times;
		col++;
	}
}

//Row operation: dst *= times, applied to the first n entries.
//dst only needs to support indexing with operator[].
template <typename T>
void rowopmult(int n,T &dst,double times){
	int col=0;
	while(col<n){
		dst[col]*=times;
		col++;
	}
}

//Returns the index of the smallest element of v.
//When several elements compare equal to the minimum, the first one wins.
//Throws logic_error if v is empty.
template <typename T>
int min_index(const vector<T> &v){
	if(v.empty())throw logic_error("No minimum element in empty vector");
	const int count=v.size();
	int best=0;
	int idx=1;
	while(idx<count){
		if(v[idx]<v[best])best=idx; //strict comparison keeps the earliest minimum
		idx++;
	}
	return best;
}

//Eliminates variable `varidx` from the objective row using the restriction
//row in which that variable is the basis variable.
//
//If the objective coefficient of the variable is already zero nothing is
//done. Otherwise zfunc[varidx] times the variable's basis row is subtracted
//from the objective row (zfunc) and the objective value (zvalue) is adjusted
//accordingly.
//
//Throws logic_error if the variable has a nonzero objective coefficient but
//is not listed in `basis`, or if its entry in its basis row is zero.
void Problem::removefrombasis(int varidx){
	if(zfunc[varidx]==0){
#ifdef DEBUG
		cerr<<"Not needing to remove variable "<<varidx<<" from basis"<<endl;
#endif
		return;
	}

	const int nrestr=restr.height(),nvars=restr.width();

	//locate the restriction row whose basis variable is varidx
	int restridx=-1;
	for(int i=0;i<nrestr;i++){
		if(basis[i]==varidx){
			restridx=i;
			break;
		}
	}
	//only an error when the search above came up empty
	if(restridx==-1)throw logic_error("Variable to remove from basis is not in basis");

	//NOTE(review): only a zero entry is rejected here; the elimination below
	//zeroes zfunc[varidx] exactly only if this entry is 1 - confirm callers
	//always leave basis columns normalised.
	if(restr[restridx][varidx]==0)throw logic_error("Basis variable column not in identity form");

	//the multiplier is passed by value, so it is read before zfunc changes
	zvalue-=zfunc[varidx]*bvec[restridx];
	rowopadd(nvars,zfunc,restr[restridx],-zfunc[varidx]);

#ifdef DEBUG
	cerr<<"Removed variable "<<varidx<<" from basis:"<<endl;
	cerr<<*this<<endl<<endl;
#endif
}

//Solves the linear program stored in this object.
//
//If at least one variable is artificial (VT_ART), an auxiliary problem is
//solved first: the objective is temporarily replaced by the sum of the
//artificial variables and minimised with solve_noart(). A nonzero optimum
//means the restrictions cannot be satisfied and invalid_argument is thrown.
//Afterwards the original objective is restored and the trailing artificial
//columns are dropped from vartype, zfunc and restr.
//
//In every case (with or without artificial variables) the problem is then
//solved with solve_noart(); any exception it raises propagates to the caller.
void Problem::solve(){
	const int nvars=restr.width();

	bool haveart=false;
	for(VarType type : vartype){
		if(type==VT_ART){
			haveart=true;
			break;
		}
	}
	if(haveart){
#ifdef DEBUG
		cerr<<" === SOLVING ARTIFICIAL PROBLEM ==="<<endl;
#endif
		//auxiliary objective: coefficient 1 for artificial variables, 0 otherwise
		vector<double> origzfunc=zfunc;
		for(int j=0;j<nvars;j++){
			zfunc[j]=vartype[j]==VT_ART;
		}
#ifdef DEBUG
		cerr<<*this<<endl<<endl;
#endif
		solve_noart();
#ifdef DEBUG
		cerr<<*this<<endl<<endl;
#endif
		//NOTE(review): exact floating-point comparison; rounding error in the
		//pivots could make a feasible problem look infeasible.
		if(zvalue!=0)throw invalid_argument("No feasible solution (no zero value of artificial variables)");
		zfunc=origzfunc;

		//artificial variables are expected to occupy the last columns; count
		//them from the end and drop those columns
		int nart=0;
		for(int j=nvars-1;j>=0&&vartype[j]==VT_ART;j--){
			removefrombasis(j);
			nart++;
		}
		vartype.resize(nvars-nart);
		zfunc.resize(nvars-nart);
		restr.resize(nvars-nart,restr.height());

#ifdef DEBUG
		cerr<<" === SOLVING ORIGINAL PROBLEM ==="<<endl;
#endif
	}

	//the actual problem must be solved whether or not an auxiliary phase ran
	solve_noart();
}

//Runs the simplex iteration on the current tableau (restr, bvec, zfunc,
//zvalue), minimising the objective row.
//
//Steps:
// 1. Rebuild `basis` by looking for columns that have exactly one nonzero
//    entry, that entry being positive; the row is scaled so the entry is 1.
// 2. Rewrite the objective row so that every basis variable has coefficient 0.
// 3. Repeatedly pivot on the column with the most negative objective
//    coefficient until no coefficient is negative.
//
//Throws invalid_argument if a column is entirely zero, if some row ends up
//without a basis variable, or if a pivot column has no positive entry.
void Problem::solve_noart(){
	const int nrestr=restr.height(),nvars=restr.width();

	//start from scratch: -1 marks a row without a basis variable
	basis.clear();
	basis.resize(nrestr,-1);
	for(int j=0;j<nvars;j++){ //find basis variables
		int i;
		//first row with a nonzero entry in column j
		for(i=0;i<nrestr;i++)if(restr[i][j]!=0)break;
		if(i==nrestr)throw invalid_argument("Variable with zero matrix column");
		if(restr[i][j]<=0)continue; //a negative entry disqualifies the column
		int onerestr=i;
		if(basis[onerestr]!=-1)continue; //row already has a basis variable
		i++;
		//any further nonzero entry means the column is not a unit column
		for(;i<nrestr;i++)if(restr[i][j]!=0)break;
		if(i!=nrestr)continue;
		basis[onerestr]=j;
		double value=restr[onerestr][j];
		assert(value!=0);
		//scale the row (and its right-hand side) so the basis entry becomes 1
		for(int j2=0;j2<nvars;j2++)restr[onerestr][j2]/=value;
		bvec[onerestr]/=value;
	}

	//every restriction row needs a basis variable to start from
	for(int j : basis){
		if(j==-1)throw invalid_argument("No feasible solution (overspecified system)");
	}

#ifdef DEBUG
	cerr<<"basis =";
	for(int i=0;i<nrestr;i++)cerr<<' '<<basis[i];
	cerr<<endl;
	cerr<<*this<<endl<<endl;
#endif

	for(int i=0;i<nrestr;i++){ //express z in terms of non-basis variables
		if(zfunc[basis[i]]!=0){
			//subtract zfunc[basis[i]] times row i from the objective row;
			//the multiplier is passed by value, so it is read before zfunc changes
			zvalue-=zfunc[basis[i]]*bvec[i];
			rowopadd(nvars,zfunc,restr[i],-zfunc[basis[i]]);
		}
	}

#ifdef DEBUG
	cerr<<"Expressed z in terms of non-basis variables:"<<endl;
	cerr<<*this<<endl<<endl;
#endif

	while(true){
		int pivotvar=min_index(zfunc); //find pivot column
		if(zfunc[pivotvar]>=0)break; //optimal solution found
		int pivotrestr=-1;
		for(int i=0;i<nrestr;i++){ //find pivot row
			if(restr[i][pivotvar]<=0)continue; //only positive entries are candidates
			//pick the smallest ratio bvec/entry; strict '<' keeps the first row on ties
			if(pivotrestr==-1||bvec[i]/restr[i][pivotvar]<bvec[pivotrestr]/restr[pivotrestr][pivotvar]){
				pivotrestr=i;
			}
		}
		if(pivotrestr==-1){
			throw invalid_argument("Unbounded problem (no positive entry in pivot column)");
		}
		if(restr[pivotrestr][pivotvar]!=1){ //normalise row
			//bvec is divided first, while the pivot entry still has its old value
			bvec[pivotrestr]/=restr[pivotrestr][pivotvar];
			//NOTE(review): presumably restr[...] yields a row handle, so changes
			//made through `row` reach the matrix - confirm against the matrix class
			auto row=restr[pivotrestr];
			rowopmult(nvars,row,1/restr[pivotrestr][pivotvar]);
		}
		for(int i=0;i<nrestr;i++){ //zero this column in other rows
			if(i==pivotrestr)continue;
			if(restr[i][pivotvar]==0)continue;
			bvec[i]-=restr[i][pivotvar]*bvec[pivotrestr];
			auto row=restr[i]; //see the note on `row` above
			rowopadd(nvars,row,restr[pivotrestr],-restr[i][pivotvar]);
		}
		if(zfunc[pivotvar]!=0){ //and the z row
			zvalue-=zfunc[pivotvar]*bvec[pivotrestr];
			rowopadd(nvars,zfunc,restr[pivotrestr],-zfunc[pivotvar]);
		}
		basis[pivotrestr]=pivotvar; //replace variable in basis

#ifdef DEBUG
		cerr<<"basis =";
		for(int i=0;i<nrestr;i++)cerr<<' '<<basis[i];
		cerr<<endl;
#endif
	}
}

//Returns the negated stored objective value (zvalue).
//A stored value of zero is reported as positive zero rather than -0.
double Problem::solutionValue() const {
	if(zvalue==0)return 0; //prevent -0
	return -zvalue;
}

vector<double> Problem::solutionVars() const {
	vector<double> normal;
	int nrestr=restr.height(),nvars=restr.width();
	int nnormal;
	for(nnormal=0;nnormal<nvars;nnormal++)if(vartype[nnormal]!=VT_NORMAL)break;
	normal.resize(nnormal);
	for(int i=0;i<nrestr;i++){
		if(vartype[basis[i]]==VT_NORMAL){
			normal[basis[i]]=bvec[i];
			if(normal[basis[i]]==0)normal[basis[i]]=0; //prevent -0
		}
	}
	return normal;
}

//Extracts one value of type T from `in` into dst.
//Throws invalid_argument if the extraction leaves the stream in a failed state.
template <typename T>
void readthrow(istream &in,T &dst){
	if(!(in>>dst)){
		throw invalid_argument("Failure reading input");
	}
}

istream& operator>>(istream &in,Problem &prob){
	string type;
	in>>type;
	bool negate=false;
	if(type=="max")negate=true;
	else if(type!="min")throw invalid_argument("Invalid LP problem type");
	int nvars,nrestr;
	in>>nvars>>nrestr;
	if(nvars<=0)throw invalid_argument("Invalid number of variables");
	if(nrestr<=0)throw invalid_argument("Invalid number of restrictions");
	prob.vartype.clear();  prob.vartype.resize(nvars);
	prob.zfunc.clear();  prob.zfunc.resize(nvars);
	prob.restr.clear();  prob.restr.resize(nvars,nrestr);
	prob.bvec.clear();  prob.bvec.resize(nrestr);
	prob.basis.clear();  prob.basis.resize(nrestr,-1);
	for(int i=0;i<nvars;i++){
		readthrow(in,prob.zfunc[i]);
		if(negate)prob.zfunc[i]=-prob.zfunc[i];
	}
	vector<double> addslack(nrestr,0),addart(nrestr,0);
	for(int i=0;i<nrestr;i++){
		for(int j=0;j<nvars;j++){
			readthrow(in,prob.restr[i][j]);
		}
		readthrow(in,type);
		readthrow(in,prob.bvec[i]);
		if(prob.bvec[i]<0){
			prob.bvec[i]=-prob.bvec[i];
			for(int j=0;j<nvars;j++)prob.restr[i][j]=-prob.restr[i][j];
			if(type=="<=")type=">=";
			else if(type==">=")type="<=";
		}
		if(type=="<="){
			addslack[i]=1;
		} else if(type==">="){
			addslack[i]=-1;
			addart[i]=1;
		} else if(type=="="){
			addart[i]=1;
		} else {
			throw invalid_argument("Invalid restriction type");
		}
	}
	for(int i=0;i<nrestr;i++){
		if(addslack[i]==0)continue;
		prob.vartype.push_back(Problem::VT_SLACK);
		prob.zfunc.emplace_back();
		prob.restr.addcolumn();
		prob.restr[i].back()=addslack[i];
	}
	for(int i=0;i<nrestr;i++){
		if(addart[i]==0)continue;
		prob.vartype.push_back(Problem::VT_ART);
		prob.zfunc.emplace_back();
		prob.restr.addcolumn();
		prob.restr[i].back()=addart[i];
	}
	return in;
}

//Writes the tableau of prob to os in a fixed-width text layout:
//  a "min." header line with one letter per column (N, S or A by variable type),
//  the objective row "z = ..." followed by the stored objective value,
//  one line per restriction with its right-hand side and basis variable index.
//No newline is written after the last restriction line.
ostream& operator<<(ostream &os,const Problem &prob){
	const int cellwid=4;

	//header: variable kind of every column
	os<<"min.";
	for(Problem::VarType type : prob.vartype){
		char letter;
		if(type==Problem::VT_NORMAL)letter='N';
		else if(type==Problem::VT_SLACK)letter='S';
		else if(type==Problem::VT_ART)letter='A';
		else{
			assert(false);
			continue; //nothing is printed for an unknown type
		}
		os<<' '<<setw(cellwid)<<letter;
	}
	os<<endl;

	//objective row
	os<<"z = ";
	for(double coeff : prob.zfunc){
		os<<' '<<setw(cellwid)<<coeff;
	}
	os<<" = "<<setw(cellwid)<<prob.zvalue<<endl;

	//restriction rows
	const int nrows=prob.restr.height();
	const int ncols=prob.restr.width();
	for(int r=0;r<nrows;r++){
		os<<(r==0?"s.t.":"    ");
		for(int c=0;c<ncols;c++){
			os<<' '<<setw(cellwid)<<prob.restr[r][c];
		}
		os<<" = "<<setw(cellwid)<<prob.bvec[r];
		os<<" (var "<<prob.basis[r]<<')';
		const bool lastrow=(r==nrows-1);
		if(!lastrow)os<<endl;
	}

	return os;
}