diff options
author | tomsmeding <tom.smeding@gmail.com> | 2016-10-03 13:00:14 +0200 |
---|---|---|
committer | tomsmeding <tom.smeding@gmail.com> | 2016-10-03 13:00:14 +0200 |
commit | 2bf5effe95641667a1ed51c04eff7760f6a42ef4 (patch) | |
tree | ee02ed7d87ebcfae07ac1e69017d6edced3ed386 /bigint.cpp |
Initial
Diffstat (limited to 'bigint.cpp')
-rw-r--r-- | bigint.cpp | 498 |
1 files changed, 498 insertions, 0 deletions
diff --git a/bigint.cpp b/bigint.cpp new file mode 100644 index 0000000..5182d07 --- /dev/null +++ b/bigint.cpp @@ -0,0 +1,498 @@ +#include <iomanip> +#include <stdexcept> +#include <cctype> +#include <cassert> +#include "bigint.h" +#include "numalgo.h" + +using namespace std; + +Bigint::Bigint() + :sign(1){} + +Bigint::Bigint(slongdigit_t v) + :digits(1,abs(v)),sign(v>=0?1:-1){ + static_assert(sizeof(longdigit_t)==2*sizeof(digit_t), + "longdigit_t should be twice as large as digit_t"); + v=abs(v); + if(v>digits[0])digits.push_back(v>>digit_bits); + else if(v==0)digits.clear(); + checkconsistent(); +} + +void Bigint::add(Bigint &a,const Bigint &b){ + if(a.digits.size()<b.digits.size())a.digits.resize(b.digits.size()); + int sz=a.digits.size(); + int carry=0; + for(int i=0;i<sz;i++){ + longdigit_t bdig=i<(int)b.digits.size()?b.digits[i]:0; + longdigit_t sum=a.digits[i]+bdig+carry; + a.digits[i]=sum; + // carry=sum>=((longdigit_t)1<<digit_bits); + carry=sum>>digit_bits; + } + if(carry)a.digits.push_back(1); + a.normalise(); + a.checkconsistent(); +} + +void Bigint::subtract(Bigint &a,const Bigint &b){ + if(a.digits.size()<b.digits.size()){ + a.digits.resize(b.digits.size()); //adds zeros + } + assert(a.digits.size()>=b.digits.size()); + if(a.digits.size()==0){ + a.checkconsistent(); + return; + } + assert(a.digits.size()>0); + int sz=a.digits.size(); + int carry=0; + for(int i=0;i<sz;i++){ + if(i>=(int)b.digits.size()&&!carry)break; + digit_t adig=a.digits[i]; + digit_t bdig=i<(int)b.digits.size()?b.digits[i]:0; + digit_t res=adig-(bdig+carry); + // cerr<<"carry="<<carry<<" res="<<res<<" adig="<<adig<<" bdig="<<bdig<<endl; + carry=(bdig||carry)&&res>=adig; + a.digits[i]=res; + } + if(carry){ + // cerr<<"2s complement"<<endl; + //we do a fake 2s complement, sort of + carry=0; + for(int i=0;i<sz;i++){ + a.digits[i]=~a.digits[i]; + a.digits[i]+=(i==0)+carry; + carry=a.digits[i]<=(digit_t)carry; + } + a.sign=-a.sign; + } + a.shrink(); + a.normalise(); + a.checkconsistent(); +} + +Bigint Bigint::product(const Bigint &a,const Bigint &b){ + int asz=a.digits.size(),bsz=b.digits.size(); + if(asz==0||bsz==0)return Bigint(); + Bigint res; + res.digits.resize(asz+bsz); + for(int i=0;i<asz;i++){ + digit_t carry=0; + for(int j=0;j<bsz;j++){ + longdigit_t pr=(longdigit_t)a.digits[i]*b.digits[j]+carry; + longdigit_t newd=pr+res.digits[i+j]; //this always fits, I checked + res.digits[i+j]=(digit_t)newd; + carry=newd>>digit_bits; + // cerr<<"carry="<<carry<<endl; + } + for(int j=bsz;carry;j++){ + assert(i+j<(int)res.digits.size()); + longdigit_t newd=res.digits[i+j]+carry; + res.digits[i+j]=newd; + carry=newd>>digit_bits; + // cerr<<"(2) carry="<<carry<<endl; + } + } + res.sign=a.sign*b.sign; + res.shrink(); + res.normalise(); + res.checkconsistent(); + return res; +} + +void Bigint::shrink(){ + while(digits.size()&&digits.back()==0)digits.pop_back(); +} + +void Bigint::normalise(){ + if(digits.size()==0&&sign==-1)sign=1; +} + +void Bigint::checkconsistent(){ + assert(digits.size()==0||digits.back()!=0); + assert(digits.size()!=0||sign==1); +} + +Bigint& Bigint::operator=(slongdigit_t v){ + digits.resize(1); + sign=v>=0?1:-1; + v*=sign; + digits[0]=v; + if(v>digits[0])digits.push_back(v>>digit_bits); + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator+=(const Bigint &o){ + if(&o==this){ + return *this=Bigint(*this)+=o; + } + if(sign==1){ + if(o.sign==1)add(*this,o); + else subtract(*this,o); + } else { + if(o.sign==1)subtract(*this,o); + else add(*this,o); + } + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator-=(const Bigint &o){ + if(&o==this){ + return *this=Bigint(*this)-=o; + } + if(sign==1){ + if(o.sign==1)subtract(*this,o); + else add(*this,o); + } else { + if(o.sign==1)add(*this,o); + else subtract(*this,o); + } + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator*=(const Bigint &o){ + *this=product(*this,o); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator<<=(int sh){ + if(sh==0)return *this; + if(digits.size()==0)return *this; + if(sh<0)return *this>>=-sh; + if(sh/digit_bits>0){ + digits.insert(digits.begin(),sh/digit_bits,0); + sh%=digit_bits; + if(sh==0){ + checkconsistent(); + return *this; + } + } + digits.push_back(0); + for(int i=digits.size()-2;i>=0;i--){ + digits[i+1]|=digits[i]>>(digit_bits-sh); + digits[i]<<=sh; + } + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator>>=(int sh){ + if(sh==0)return *this; + if(digits.size()==0)return *this; + if(sh<0)return *this<<=-sh; + if(sh/digit_bits>0){ + if(sh/digit_bits>=(int)digits.size()){ + digits.clear(); + sign=1; + checkconsistent(); + return *this; + } + digits.erase(digits.begin(),digits.begin()+sh/digit_bits); + sh%=digit_bits; + if(sh==0){ + checkconsistent(); + return *this; + } + } + digits[0]>>=sh; + int sz=digits.size(); + for(int i=1;i<sz;i++){ + digits[i-1]|=digits[i]<<(digit_bits-sh); + digits[i]>>=sh; + } + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::negate(){ + sign=-sign; + return *this; +} + +Bigint Bigint::operator+(const Bigint &o) const { + return Bigint(*this)+=o; +} + +Bigint Bigint::operator-(const Bigint &o) const { + return Bigint(*this)-=o; +} + +Bigint Bigint::operator*(const Bigint &o) const { + return product(*this,o); +} + +int depthrecord; + +pair<Bigint,Bigint> Bigint::divmod(const Bigint &div) const { + pair<Bigint,Bigint> res=divmod(div,1); + //cerr<<hex<<*this<<' '<<div<<' '; + //cerr<<dec<<depthrecord<<endl; + return res; +} + +inline void record(int depth){ + depthrecord=depth; +} + +pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const { + //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<')'<<endl; + if(div.digits.size()==0)throw domain_error("Bigint divide by zero"); + if(digits.size()==0){record(depth); return make_pair(Bigint(0),Bigint(0));} + + int cmp=compareAbs(div); + if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint(0));} + if(cmp<0){record(depth); return make_pair(Bigint(0),*this);} + //now *this is greater in magnitude than the divisor + +#if 0 + int twoexp=bitcount()-div.bitcount(); + if(twoexp<0)twoexp=0; + Bigint quotient(sign*div.sign); + quotient<<=twoexp; + Bigint guess(div); //guess == quotient * div + guess<<=twoexp; + //cerr<<"guess="<<hex<<guess<<endl; + //twoexp now becomes irrelevant +#else + Bigint guess,quotient; + if(digits.size()==div.digits.size()){ + quotient=digits.back()/div.digits.back(); + guess=quotient; + guess*=div; + } else { + assert(digits.size()>div.digits.size()); + assert(digits.size()>=2&&div.digits.size()>=1); + longdigit_t factor=((longdigit_t)1<<digit_bits)*digits.back()+digits[digits.size()-2]; + factor/=div.digits.back()+1; + quotient=factor; + quotient<<=(digits.size()-div.digits.size()-1)*digit_bits; + guess=quotient; + guess*=div; + } + // cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<endl; +#endif + cmp=guess.compareAbs(*this); + if(cmp<0){ + Bigint guess2(guess); + while(true){ + guess2<<=1; + int cmp=guess2.compareAbs(*this); + if(cmp>0)break; + guess<<=1; + quotient<<=1; + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + } + Bigint rest(*this); + rest-=guess; //also correct for *this and guess negative + pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); + dm.first+=quotient; + return dm; + } + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + //then cmp>0, so our guess is too large + Bigint one(1); + do { + if(quotient.digits[0]&1){ + guess-=div; + // quotient-=one; // not necessary, since we shift the bit out anyway + } + guess>>=1; + quotient>>=1; + cmp=guess.compareAbs(*this); + } while(cmp>0); + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + Bigint rest(*this); + rest-=guess; + pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); + dm.first+=quotient; + return dm; +} + +bool Bigint::operator==(const Bigint &o) const {return compare(o)==0;} +bool Bigint::operator!=(const Bigint &o) const {return compare(o)!=0;} +bool Bigint::operator<(const Bigint &o) const {return compare(o)<0;} +bool Bigint::operator>(const Bigint &o) const {return compare(o)>0;} +bool Bigint::operator<=(const Bigint &o) const {return compare(o)<=0;} +bool Bigint::operator>=(const Bigint &o) const {return compare(o)>=0;} + +bool Bigint::operator==(slongdigit_t v) const {return compare(v)==0;} +bool Bigint::operator!=(slongdigit_t v) const {return compare(v)!=0;} +bool Bigint::operator<(slongdigit_t v) const {return compare(v)<0;} +bool Bigint::operator>(slongdigit_t v) const {return compare(v)>0;} +bool Bigint::operator<=(slongdigit_t v) const {return compare(v)<=0;} +bool Bigint::operator>=(slongdigit_t v) const {return compare(v)>=0;} + +int Bigint::compare(const Bigint &o) const { + if(sign>o.sign)return 1; + if(sign<o.sign)return -1; + return sign*compareAbs(o); +} + +int Bigint::compare(slongdigit_t v) const { + if(sign==-1&&v>=0)return -1; + if(sign==1&&v<0)return 1; + return sign*compareAbs(v); +} + +int Bigint::compareAbs(const Bigint &o) const { + int sz=digits.size(),osz=o.digits.size(); + if(sz>osz)return 1; + if(sz<osz)return -1; + for(int i=sz-1;i>=0;i--){ + if(digits[i]>o.digits[i])return 1; + if(digits[i]<o.digits[i])return -1; + } + return 0; +} + +int Bigint::compareAbs(slongdigit_t v) const { + v=abs(v); + if(digits.size()>2)return 1; + if(digits.size()==0)return v==0?0:-1; + if(digits.size()==2){ + if(digits[1]>(digit_t)(v>>digit_bits))return 1; + if(digits[1]<(digit_t)(v>>digit_bits))return -1; + } + if(digits[0]<(digit_t)v)return -1; + if(digits[0]>(digit_t)v)return 1; + return 0; +} + +int Bigint::bitcount() const { + if(digits.size()==0)return 0; + return (digits.size()-1)*8+ilog2(digits.back())+1; +} + +Bigint::slongdigit_t Bigint::lowdigits() const { + if(digits.size()==0)return 0; + if(digits.size()==1)return digits[0]; + longdigit_t mask=~((longdigit_t)1<<(digit_bits-1)); + return ((slongdigit_t)1<<digit_bits)*(digits[1]&mask)+digits[0]; +} + +vector<char> Bigint::serialise() const { + vector<char> v(1+digits.size()*sizeof(digit_t)); + v[0]=sign; + int sz=digits.size(); + for(int i=0;i<sz;i++){ + for(int j=0;j<(int)sizeof(digit_t);j++){ + v[1+i*sizeof(digit_t)+j]=(digits[i]>>(8*j))&255; + } + } + return v; +} + +void Bigint::deserialise(const vector<char> &v){ + assert(v.size()%4==1); + sign=(int)v[0]; + assert(sign==1||sign==-1); + int sz=v.size()/4; + digits.resize(sz); + for(int i=0;i<sz;i++){ + digits[i]=0; + for(int j=0;j<(int)sizeof(digit_t);j++){ + digits[i]|=v[1+i*sizeof(digit_t)+j]<<(8*j); + } + } + shrink(); + normalise(); + checkconsistent(); +} + +vector<bool> Bigint::bits() const { + if(digits.size()==0)return {}; + vector<bool> v(digit_bits*(digits.size()-1)+ilog2(digits.back())+1); + int sz=digits.size(); + for(int i=0;i<sz;i++){ + digit_t dig=digits[i]; + for(int j=0;dig;j++){ + if(dig&1)v[digit_bits*i+j]=true; + dig>>=1; + } + } + return v; +} + +Bigint::digit_t Bigint::_digit(int idx) const { + return digits.at(idx); +} + +istream& operator>>(istream &is,Bigint &b){ + while(isspace(is.peek()))is.get(); + if(!is)return is; + b.digits.resize(0); + b.sign=1; + Bigint ten(10); + bool acted=false; + while(true){ + char c=is.peek(); + if(!isdigit(c))break; + acted=true; + is.get(); + if(!is){ + b.checkconsistent(); + return is; + } + b*=ten; + b+=Bigint(c-'0'); + // cerr<<"b="<<b<<endl; + } + if(!acted)is.setstate(ios_base::failbit); + b.checkconsistent(); + return is; +} + +std::ostream& operator<<(std::ostream &os,Bigint b){ + if(b<0){ + os<<'-'; + b.negate(); + } + if(os.flags()&ios_base::hex){ + os<<"0x"; + if(b.digits.size()==0)return os<<'0'; + os<<b.digits.back(); + for(int i=b.digits.size()-2;i>=0;i--){ + os<<setw(Bigint::digit_bits/4)<<setfill('0')<<b.digits[i]; + } + return os; + } +#if 0 + assert(b.digits.size()<=2); + if(b.sign==-1)os<<'-'; + if(b.digits.size()==2)os<<((uint64_t)b.digits[1]<<32)+b.digits[0]; + else if(b.digits.size()==1)os<<b.digits[0]; + else os<<'0'; + return os; +#else + if(b==0)return os<<'0'; + Bigint div(1000000000000000000LL); + vector<Bigint::longdigit_t> outbuf; + while(b!=0){ + pair<Bigint,Bigint> dm=b.divmod(div); + b=dm.first; + Bigint::longdigit_t val=0; + assert(dm.second.digits.size()<=2); + if(dm.second.digits.size()>=2) + val+=((Bigint::longdigit_t)1<<Bigint::digit_bits)*dm.second.digits[1]; + if(dm.second.digits.size()>=1) + val+=dm.second.digits[0]; + outbuf.push_back(val); + } + for(int i=outbuf.size()-1;i>=0;i--){ + (i==(int)outbuf.size()-1?os:os<<setfill('0')<<setw(18))<<outbuf[i]; + } + return os; +#endif +} |