diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | solve.cpp | 266 | ||||
| -rwxr-xr-x | test.sh | 20 | 
3 files changed, 288 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bd13184 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +solve +tests/*.tmp diff --git a/solve.cpp b/solve.cpp new file mode 100644 index 0000000..0107521 --- /dev/null +++ b/solve.cpp @@ -0,0 +1,266 @@ +#include <iostream> +#include <string> +#include <cstdlib> +#include <cstdint> +#include <cassert> + +using namespace std; + + +class Board{ +	int8_t *bd; +	 +	inline int8_t& operator[](int i){return bd[i];} +	inline int8_t& at(int x,int y,bool tr=false){ +		if(tr)return bd[S*x+y]; +		else return bd[S*y+x]; +	} + +	string row(int i,bool tr){return (tr?"column ":"row ")+to_string(i+1);} +	string column(int i,bool tr){return (tr?"row ":"column ")+to_string(i+1);} +	string pos(int x,int y,bool tr){return "("+to_string(tr?y+1:x+1)+","+to_string(tr?x+1:y+1)+")";} + +	bool solvefullrows(bool tr){ +		bool changed=false; +		for(int y=0;y<S;y++){ +			int n0=0,n1=0,ex=-1; +			for(int x=0;x<S;x++){ +				switch(at(x,y,tr)){ +					case -1: ex=x; break; +					case 0: n0++; break; +					case 1: n1++; break; +				} +			} +			if(ex!=-1&&n0+n1==S-1){ +				bool v=n0<S/2?0:1; +				at(ex,y,tr)=v; +				cerr<<"Fill "<<row(y,tr)<<": "<<pos(ex,y,tr)<<" -> "<<v<<endl<<*this<<endl<<endl; +				changed=true; +			} +		} +		return changed; +	} +	bool solvefullrows(){return solvefullrows(false)|solvefullrows(true);} + +	bool completerows(bool tr){ +		bool changed=false; +		for(int y=0;y<S;y++){ +			int n0=0,n1=0; +			for(int x=0;x<S;x++){ +				switch(at(x,y,tr)){ +					case 0: n0++; break; +					case 1: n1++; break; +				} +			} +			if(n0==S/2&&n1<S/2){ +				cerr<<"Complete "<<row(y,tr)<<" with ones:"; +				for(int x=0;x<S;x++)if(at(x,y,tr)==-1){ +					at(x,y,tr)=1; +					cerr<<' '<<pos(x,y,tr)<<" -> 1"; +				} +				cerr<<endl<<*this<<endl<<endl; +				changed=true; +			} else if(n1==S/2&&n0<S/2){ +				cerr<<"Complete "<<row(y,tr)<<" with zeros:"; +				for(int x=0;x<S;x++)if(at(x,y,tr)==-1){ +					at(x,y,tr)=0; +					cerr<<' '<<pos(x,y,tr)<<" -> 0"; +				} +				cerr<<endl<<*this<<endl<<endl; +				changed=true; +			} +		} +		return changed; +	} +	bool completerows(){return completerows(false)|completerows(true);} + +	bool solveequalrows(bool tr){ +		bool changed=false; +		for(int y1=0;y1<S;y1++){ +			for(int y2=y1+1;y2<S;y2++){ +				int x,n1=0,n2=0,exa=-1,exb=-1; +				for(x=0;x<S;x++){ +					int v1=at(x,y1,tr),v2=at(x,y2,tr); +					n1+=v1!=-1; n2+=v2!=-1; +					if(v1==-1||v2==-1){ +						exa=exb; +						exb=x; +					} +					if(v1!=-1&&v2!=-1&&v1!=v2)break; +				} +				if(x==S){ +					if(n1==S&&n2==S-2){ +						at(exa,y2,tr)=!at(exa,y1,tr); +						at(exb,y2,tr)=!at(exb,y1,tr); +						cerr<<"Prevent "<<row(y2,tr)<<" from matching "<<row(y1,tr) +						    <<": "<<pos(exa,y2,tr)<<" -> "<<!at(exa,y1,tr) +						    <<", "<<pos(exb,y2,tr)<<" -> "<<!at(exb,y1,tr)<<endl<<*this<<endl<<endl; +						changed=true; +					} else if(n1==S-2&&n2==S){ +						at(exa,y1,tr)=!at(exa,y2,tr); +						at(exb,y1,tr)=!at(exb,y2,tr); +						cerr<<"Prevent "<<row(y1,tr)<<" from matching "<<row(y2,tr) +						    <<": "<<pos(exa,y1,tr)<<" -> "<<!at(exa,y2,tr) +						    <<", "<<pos(exb,y1,tr)<<" -> "<<!at(exb,y2,tr)<<endl<<*this<<endl<<endl; +						changed=true; +					} +				} +			} +		} +		return changed; +	} +	bool solveequalrows(){return solveequalrows(false)|solveequalrows(true);} + +	bool solvetrits(bool tr){ +		bool changed=false; +		for(int y=0;y<S;y++){ +			for(int x=0;x<S;x++){ +				if(at(x,y,tr)!=-1)continue; +				bool opp; +				if(x>0&&x<S-1&&at(x-1,y,tr)!=-1&&at(x-1,y,tr)==at(x+1,y,tr))opp=at(x-1,y,tr); +				else if(x>1&&at(x-2,y,tr)!=-1&&at(x-2,y,tr)==at(x-1,y,tr))opp=at(x-1,y,tr); +				else if(x<S-2&&at(x+2,y,tr)!=-1&&at(x+2,y,tr)==at(x+1,y,tr))opp=at(x+1,y,tr); +				else continue; +				at(x,y,tr)=!opp; +				cerr<<"Prevent trit: "<<pos(x,y,tr)<<" -> "<<!opp<<endl<<*this<<endl<<endl; +				changed=true; +			} +		} +		return changed; +	} +	bool solvetrits(){return solvetrits(false)|solvetrits(true);} + + +	bool verifycounts(bool tr){ +		for(int y=0;y<S;y++){ +			int n0=0,n1=0; +			for(int x=0;x<S;x++){ +				switch(at(x,y,tr)){ +					case 0: n0++; break; +					case 1: n1++; break; +				} +			} +			if(n0!=S/2||n1!=S/2){ +				cerr<<"Count verification failure at "<<row(y,tr)<<endl; +				return false; +			} +		} +		return true; +	} +	bool verifycounts(){return verifycounts(false)&&verifycounts(true);} + +	bool verifytrits(bool tr){ +		for(int y=0;y<S;y++){ +			int pr=-1,n=0; +			for(int x=0;x<S;x++){ +				int v=at(x,y,tr); +				if(pr==v){ +					n++; +					if(pr!=-1&&n>=3){ +						cerr<<"Trit verification failure at "<<pos(x,y,tr)<<endl; +						return false; +					} +				} else { +					pr=v; +					n=0; +				} +			} +		} +		return true; +	} +	bool verifytrits(){return verifytrits(false)&&verifytrits(true);} + +	bool verifyequalrows(bool tr){ +		for(int y1=0;y1<S;y1++){ +			for(int y2=y1+1;y2<S;y2++){ +				int x; +				for(x=0;x<S;x++){ +					if(at(x,y1,tr)!=at(x,y2,tr))break; +				} +				if(x==S){ +					cerr<<"Equal row verification failure on "<<row(y1,tr)<<" and "<<row(y2,tr)<<endl; +					return false; +				} +			} +		} +		return true; +	} +	bool verifyequalrows(){return verifyequalrows(false)&&verifyequalrows(true);} + +public: +	const int S; + +	Board(int S):S(S){ +		if(S<2||S%2==1){ +			cerr<<"Invalid board size "<<S<<endl; +			exit(1); +		} +		bd=new int8_t[S*S]; +	} +	~Board(){ +		delete[] bd; +	} + +	inline int8_t operator[](int i) const {return bd[i];} +	inline int8_t at(int x,int y) const {return bd[S*y+x];} + +	void solve(){ +		while(solvefullrows() +		     |completerows() +		     |solveequalrows() +		     |solvetrits()); +	} + +	bool verify(){ +		return verifycounts()&&verifytrits()&&verifyequalrows(); +	} + +	friend istream& operator>>(istream &is,Board &bd); +	friend ostream& operator<<(ostream &os,const Board &bd); +}; + +istream& operator>>(istream &is,Board &bd){ +	char c; +	for(int i=0;i<bd.S*bd.S;i++){ +		is>>c; +		if(c=='0')bd[i]=0; +		else if(c=='1')bd[i]=1; +		else if(c=='.'||c=='_')bd[i]=-1; +		else { +			cerr<<"Invalid char inputted: '"<<c<<'\''<<endl; +			exit(1); +		} +	} +	return is; +} + +ostream& operator<<(ostream &os,const Board &bd){ +	for(int y=0;y<bd.S;y++){ +		if(y!=0)os<<'\n'; +		for(int x=0;x<bd.S;x++){ +			if(x!=0)os<<' '; +			switch(bd.at(x,y)){ +				case -1: os<<'.'; break; +				case 0: os<<'0'; break; +				case 1: os<<'1'; break; +				default: os<<'?'; break; +			} +		} +	} +	return os; +} + +int main(){ +	int S; +	cin>>S; +	Board bd(S); +	cin>>bd; +	cerr<<bd<<endl<<endl; +	bd.solve(); +	if(!bd.verify()){ +		cerr<<"VERIFICATION FAILURE!"<<endl; +		return 1; +	} +	cout<<bd<<endl; +	return 0; +} @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +total=0 +correct=0 +for inf in tests/*.in; do +	base="${inf//.in}" +	outf="$base.out" +	tmpf="$base.tmp" +	if ! ./solve <"$inf" >"$tmpf" 2>/dev/null; then +		echo "\x1B[33mERROR\x1B[0m solve returned $?" +	fi +	total=$((total+1)) +	if ! diff "$tmpf" "$outf" >/dev/null; then +		echo "failure on $inf" +	else +		rm "$tmpf" +		correct=$((correct+1)) +	fi +done +if test $correct -eq $total; then printf "\x1B[33mOK"; else printf "\x1B[31mFAILURE"; fi +printf "\x1B[0m: %d/%d\n" "$correct" "$total"  | 
