#include <stdexcept>
#include <string>
#include "environment.h"
#include "error.h"

using namespace std;

#undef DEBUG


// Merge every definition and every hook of `other` into this environment.
// Entries whose name already exists here are overwritten by the version from `other`.
void Environment::load(const Environment &other){
	for(const auto &entry : other.defs){
		const auto &entryName=entry.first;
		defs[entryName]=entry.second;
	}
	for(const auto &entry : other.hooks){
		const auto &entryName=entry.first;
		hooks[entryName]=entry.second;
	}
}

// Bind `name` to the term `ast`, replacing any previous definition of that name.
void Environment::define(const Name &name,const AST &ast){
	auto &slot=defs[name];
	slot=ast;
}

// Register a native one-argument hook under `name`, replacing any previous hook of that name.
void Environment::define(const Name &name,const Hook &hook){
	auto &slot=hooks[name];
	slot=hook;
}

void Environment::define2(const Name &name,const Hook2 &hook2){
	hooks[name]=[hook2](Environment &env,const AST &arg1) -> AST {
		return AST::makeNative([&env,arg1,hook2](const AST &arg2) -> AST {
			return hook2(env,arg1,arg2);
		});
	};
}

// Look up `name`. A plain definition wins over a hook with the same name.
// A hook is wrapped in a native AST that calls it with this Environment.
// The native captures `this`, so it must not outlive the Environment.
// Throws NameError if the name is neither defined nor hooked.
AST Environment::get(const Name &name){
	auto defIt=defs.find(name);
	if(defIt!=defs.end()){
		return defIt->second;
	}

	auto hookIt=hooks.find(name);
	if(hookIt==hooks.end()){
		throw NameError(name);
	}

	Hook hook=hookIt->second;
	return AST::makeNative([this,hook](const AST &arg){
		return hook(*this,arg);
	});
}


// ------------------------------------------------------------


// Turn every unquoted occurrence of the variable `name` inside `ast` into an index node.
// `index` is the value used at the current level. It grows by one for each lambda
// entered, and the walk stops at any lambda whose parameter is also `name` (shadowing).
void indexReplace(AST &ast,const Name &name,Index index){
	// Quoted subtrees are left exactly as written.
	if(ast.quoted){
		return;
	}

	if(ast.type==AST::Type::name){
		if(ast.nameval==name){
			ast.type=AST::Type::index;
			ast.indexval=index;
		}
	} else if(ast.type==AST::Type::tuple){
		for(AST &child : ast.terms){
			indexReplace(child,name,index);
		}
	} else if(ast.type==AST::Type::lambda){
		if(ast.lambdaval.arg!=name){
			indexReplace(*ast.lambdaval.body,name,index+1);
		}
	}
	// number, string, index and native nodes contain no names: nothing to do.
}


// Convert every unquoted lambda in `ast` from named parameters to de Bruijn-style indices.
// Each lambda's parameter occurrences become index 1 at its own body level, one higher
// per nested lambda. The parameter name is then cleared, and the body is processed the same way.
void indexify(AST &ast){
	if(ast.quoted){
		return;
	}

	if(ast.type==AST::Type::lambda){
		AST &body=*ast.lambdaval.body;
		indexReplace(body,ast.lambdaval.arg,1);
		ast.lambdaval.arg="";
		indexify(body);
	} else if(ast.type==AST::Type::tuple){
		for(AST &child : ast.terms){
			indexify(child);
		}
	}
	// number, string, name, index and native nodes are left untouched.
}

// Normalise tuples into binary applications, recursively and skipping quoted subtrees:
//   ()        stays empty
//   (x)       becomes x
//   (a b c d) becomes (((a b) c) d)   (left-associated)
// Lambda bodies are normalised too.
void singlify(AST &ast){
	if(ast.quoted){
		return;
	}

	if(ast.type==AST::Type::lambda){
		singlify(*ast.lambdaval.body);
		return;
	}
	if(ast.type!=AST::Type::tuple){
		// number, string, name, index, native: nothing to restructure.
		return;
	}

	auto &terms=ast.terms;
	switch(terms.size()){
		case 0:
			return;
		case 1: {
			// Copy the only element out first: it is owned by `ast`, which is overwritten.
			AST inner=terms[0];
			ast=inner;
			singlify(ast);
			return;
		}
		default:
			break;
	}

	// Fold the first two terms into one tuple until exactly two remain.
	while(terms.size()>2){
		AST merged=AST::makeTuple({terms[0],terms[1]});
		terms.erase(terms.begin()+1);
		terms[0]=move(merged);
	}
	singlify(terms[0]);
	singlify(terms[1]);
}


// Evaluate a parsed program and return the resulting term.
// The passes rewrite the tree in place, so they operate on a private copy of the input.
AST Environment::run(const AST &astinput){
	AST program(astinput);

	// Lambda-bound names -> indices.
	indexify(program);
#ifdef DEBUG
	cerr<<"indexify gave "<<program<<endl;
#endif

	// n-ary tuples -> nested binary applications.
	singlify(program);
#ifdef DEBUG
	cerr<<"singlify gave "<<program<<endl;
#endif

	reduce(program);
	return program;
}


// Append to `nodes` every unquoted index node in `ast` that refers to the variable
// addressed by `index` at the top level. Each lambda crossed raises the expected index
// by one, and the matched value is stored alongside the node pointer.
// The pointers point into `ast`. They stay valid only while no term vector on the path
// to them is resized.
void recursiveFindLevel(AST &ast,Index index,vector<pair<AST*,Index>> &nodes){
	if(ast.quoted){
		return;
	}

	if(ast.type==AST::Type::index){
		if(ast.indexval==index){
			nodes.emplace_back(&ast,index);
		}
	} else if(ast.type==AST::Type::tuple){
		for(AST &child : ast.terms){
			recursiveFindLevel(child,index,nodes);
		}
	} else if(ast.type==AST::Type::lambda){
		recursiveFindLevel(*ast.lambdaval.body,index+1,nodes);
	}
	// number, string, name and native nodes hold no indices.
}


// Add `amount` to every unquoted index in `ast` that is >= `fromIndex`. The threshold
// rises by one inside each lambda, so only indices that point outside `ast` (when
// called with fromIndex=1) are shifted.
void increaseFree(AST &ast,Index amount,Index fromIndex){
	if(ast.quoted){
		return;
	}

	switch(ast.type){
		case AST::Type::lambda:
			increaseFree(*ast.lambdaval.body,amount,fromIndex+1);
			return;

		case AST::Type::tuple:
			for(AST &child : ast.terms){
				increaseFree(child,amount,fromIndex);
			}
			return;

		case AST::Type::index:
			if(ast.indexval>=fromIndex){
				ast.indexval+=amount;
			}
			return;

		case AST::Type::number:
		case AST::Type::string:
		case AST::Type::name:
		case AST::Type::native:
			return;
	}
}


// Report whether `ast` refers to anything not bound inside it. That means an index
// >= `fromIndex` (the threshold rises by one per lambda entered) or any remaining
// name node. Quoted subtrees are treated as closed.
// Throws std::logic_error if a node carries a type value not handled below.
bool hasFree(AST &ast,Index fromIndex){
	if(ast.quoted){
		return false;
	}

	switch(ast.type){
		case AST::Type::number:
		case AST::Type::string:
		case AST::Type::native:
			return false;

		case AST::Type::name:
			return true;

		case AST::Type::index:
			return ast.indexval>=fromIndex;

		case AST::Type::tuple:
			for(AST &term : ast.terms){
				if(hasFree(term,fromIndex)){
					return true;
				}
			}
			return false;

		case AST::Type::lambda:
			return hasFree(*ast.lambdaval.body,fromIndex+1);
	}
	// Every enumerator returns above. Reaching this point means `type` holds an
	// unexpected value. Fail loudly instead of flowing off the end of a
	// value-returning function, which is undefined behaviour.
	throw logic_error("hasFree: unknown AST node type");
}


#ifdef DEBUG
static string indent(i64 depth){
	// One "| " pair per nesting level, so nested debug traces line up as a tree.
	string prefix;
	prefix.reserve(2*depth);
	for(i64 level=0;level<depth;level++){
		prefix+="| ";
	}
	return prefix;
}
#endif

// Perform one reduction step on the application `ast` = (head argument), in place.
// Returns true if a step was taken. Returns false (leaving `ast` unchanged) for quoted
// nodes, non-tuples, tuples that do not have exactly two terms, and heads that are
// neither a lambda, a native, nor the name `do`.
//   lambda head : substitute `argument` for the lambda's bound index (de Bruijn beta step).
//   native head : reduce `argument`, then replace `ast` by native(argument).
//   `do` head   : reduce `argument`, discard it, and collapse back to the name `do`.
// `depth` is only used to indent DEBUG traces.
bool Environment::betareduce(AST &ast,i64 depth){
	if(ast.quoted){
		return false;
	}

	if(ast.type!=AST::Type::tuple){
		return false;
	}
	if(ast.terms.size()!=2){
		return false;
	}

#ifdef DEBUG
	cerr<<indent(depth)<<"Betareducing "<<ast<<endl;
#endif

	bool success=false;

	if(ast.terms[0].type==AST::Type::lambda){
		// Replace the lambda by its body. The body is copied first because it is owned
		// by ast.terms[0], which the assignment overwrites.
		AST body=*ast.terms[0].lambdaval.body;
		ast.terms[0]=move(body);

		// Occurrences of the removed binder: index 1 at the body's top level, one
		// higher under each nested lambda. Each is recorded with that level.
		vector<pair<AST*,Index>> repl;
		recursiveFindLevel(ast.terms[0],1,repl);

		// One binder disappeared, so indices pointing past it drop by one.
		increaseFree(ast.terms[0],-1,2);

		// Substitute the argument. Its free indices are shifted by the number of
		// lambdas crossed to reach each occurrence (level - 1).
		for(const pair<AST*,Index> &p : repl){
			*p.first=ast.terms[1];
			increaseFree(*p.first,p.second-1,1);
		}

		// Move the result out of `ast` before overwriting `ast` with it.
		AST result=move(ast.terms[0]);
		ast=move(result);
		success=true;
	} else if(ast.terms[0].type==AST::Type::native){
		reduce(ast.terms[1],depth+1);
		ast=ast.terms[0].nativeval(ast.terms[1]);
		success=true;
	} else if(ast.terms[0].type==AST::Type::name&&ast.terms[0].nameval=="do"){
		reduce(ast.terms[1],depth+1);
		// Take the head out before assigning: `ast=ast.terms[0]` would copy from a
		// child that the assignment itself destroys (self-aliasing).
		AST head=move(ast.terms[0]);
		ast=move(head);
		success=true;
	}

#ifdef DEBUG
	cerr<<indent(depth)<<"'=> "<<ast<<endl;
#endif
	return success;
}


// Placeholder, presumably for eta-reduction (judging by the name). It currently leaves
// every node, quoted or not, unchanged.
void etareduce(AST &ast){
	if(!ast.quoted){
		// TODO: not implemented yet; unquoted nodes are intentionally left as-is.
	}
}


// Reduce `ast` in place.
//   - Names other than `do` are replaced by their definition from get(), then indexified.
//   - Lambda bodies are reduced, then etareduce() is applied.
//   - Applications: the head is reduced and a beta step attempted (see betareduce)
//     repeatedly, for as long as the result is still a non-empty tuple and a step succeeds.
// Quoted nodes, numbers, strings, indices and natives are left unchanged.
// `depth` is only used to indent DEBUG traces.
void Environment::reduce(AST &ast,i64 depth){
#ifdef DEBUG
	cerr<<indent(depth)<<"Reducing "<<ast<<endl;
#endif
	if(ast.quoted){
		return;
	}

	switch(ast.type){
		case AST::Type::number:
		case AST::Type::string:
		case AST::Type::index:
		case AST::Type::native:
			break;

		case AST::Type::name:
			// `do` stays symbolic; betareduce handles applications of it.
			if(ast.nameval!="do"){
				ast=get(ast.nameval);
				// NOTE(review): unlike run(), the definition is indexified but not
				// singlified. A definition containing an application of 3+ terms would
				// therefore not beta-reduce (betareduce only handles 2-term tuples).
				// Confirm definitions are stored pre-singlified.
				indexify(ast);
			}
			break;

		case AST::Type::lambda:
			reduce(*ast.lambdaval.body,depth+1);
			etareduce(ast);
			break;

		case AST::Type::tuple:
			// The emptiness check runs before the first iteration as well: the empty
			// tuple `()` (kept by singlify) has no head to reduce.
			// The argument (terms[1]) is not reduced here. betareduce reduces it for
			// native and `do` heads, and substitutes it unreduced into lambdas.
			while(ast.type==AST::Type::tuple&&!ast.terms.empty()){
				reduce(ast.terms[0],depth+1);
				if(!betareduce(ast,depth+1)){
					break;
				}
			}
			// A beta step can leave a bare name. For example, applying an identity
			// lambda to an unreduced name argument yields that name, or a `do`
			// application collapses to `do`. Such a name still has to be resolved.
			if(ast.type==AST::Type::name){
				reduce(ast,depth+1);
			}
			break;
	}
#ifdef DEBUG
	cerr<<indent(depth)<<"'-> "<<ast<<endl;
#endif
}