aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore4
-rw-r--r--Makefile24
-rw-r--r--bigint.cpp498
-rw-r--r--bigint.h87
-rwxr-xr-xbiginttest.py38
-rw-r--r--main.cpp141
-rw-r--r--numalgo.cpp50
-rw-r--r--numalgo.h11
-rw-r--r--rsa.cpp12
-rw-r--r--rsa.h15
10 files changed, 880 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..89805cd
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.o
+*.dSYM
+main
+*.txt
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..14220de
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,24 @@
+CXX = g++
+CXXFLAGS = -Wall -Wextra -std=c++11 -fwrapv
+ifneq ($(DEBUG),)
+ CXXFLAGS += -g
+else
+ CXXFLAGS += -O2
+endif
+BIN = main
+
+.PHONY: all clean remake
+
+all: $(BIN)
+
+clean:
+ rm -rf $(BIN) *.o *.dSYM
+
+remake: clean all
+
+
+$(BIN): $(patsubst %.cpp,%.o,$(wildcard *.cpp))
+ $(CXX) -o $@ $^
+
+%.o: %.cpp $(wildcard *.h)
+ $(CXX) $(CXXFLAGS) -c -o $@ $<
diff --git a/bigint.cpp b/bigint.cpp
new file mode 100644
index 0000000..5182d07
--- /dev/null
+++ b/bigint.cpp
@@ -0,0 +1,498 @@
+#include <iomanip>
+#include <stdexcept>
+#include <cctype>
+#include <cassert>
+#include "bigint.h"
+#include "numalgo.h"
+
+using namespace std;
+
+Bigint::Bigint()
+ :sign(1){}
+
+Bigint::Bigint(slongdigit_t v)
+ :digits(1,abs(v)),sign(v>=0?1:-1){
+ static_assert(sizeof(longdigit_t)==2*sizeof(digit_t),
+ "longdigit_t should be twice as large as digit_t");
+ v=abs(v);
+ if(v>digits[0])digits.push_back(v>>digit_bits);
+ else if(v==0)digits.clear();
+ checkconsistent();
+}
+
+void Bigint::add(Bigint &a,const Bigint &b){
+ if(a.digits.size()<b.digits.size())a.digits.resize(b.digits.size());
+ int sz=a.digits.size();
+ int carry=0;
+ for(int i=0;i<sz;i++){
+ longdigit_t bdig=i<(int)b.digits.size()?b.digits[i]:0;
+ longdigit_t sum=a.digits[i]+bdig+carry;
+ a.digits[i]=sum;
+ // carry=sum>=((longdigit_t)1<<digit_bits);
+ carry=sum>>digit_bits;
+ }
+ if(carry)a.digits.push_back(1);
+ a.normalise();
+ a.checkconsistent();
+}
+
+void Bigint::subtract(Bigint &a,const Bigint &b){
+ if(a.digits.size()<b.digits.size()){
+ a.digits.resize(b.digits.size()); //adds zeros
+ }
+ assert(a.digits.size()>=b.digits.size());
+ if(a.digits.size()==0){
+ a.checkconsistent();
+ return;
+ }
+ assert(a.digits.size()>0);
+ int sz=a.digits.size();
+ int carry=0;
+ for(int i=0;i<sz;i++){
+ if(i>=(int)b.digits.size()&&!carry)break;
+ digit_t adig=a.digits[i];
+ digit_t bdig=i<(int)b.digits.size()?b.digits[i]:0;
+ digit_t res=adig-(bdig+carry);
+ // cerr<<"carry="<<carry<<" res="<<res<<" adig="<<adig<<" bdig="<<bdig<<endl;
+ carry=(bdig||carry)&&res>=adig;
+ a.digits[i]=res;
+ }
+ if(carry){
+ // cerr<<"2s complement"<<endl;
+ //we do a fake 2s complement, sort of
+ carry=0;
+ for(int i=0;i<sz;i++){
+ a.digits[i]=~a.digits[i];
+ a.digits[i]+=(i==0)+carry;
+ carry=a.digits[i]<=(digit_t)carry;
+ }
+ a.sign=-a.sign;
+ }
+ a.shrink();
+ a.normalise();
+ a.checkconsistent();
+}
+
+Bigint Bigint::product(const Bigint &a,const Bigint &b){
+ int asz=a.digits.size(),bsz=b.digits.size();
+ if(asz==0||bsz==0)return Bigint();
+ Bigint res;
+ res.digits.resize(asz+bsz);
+ for(int i=0;i<asz;i++){
+ digit_t carry=0;
+ for(int j=0;j<bsz;j++){
+ longdigit_t pr=(longdigit_t)a.digits[i]*b.digits[j]+carry;
+ longdigit_t newd=pr+res.digits[i+j]; //this always fits, I checked
+ res.digits[i+j]=(digit_t)newd;
+ carry=newd>>digit_bits;
+ // cerr<<"carry="<<carry<<endl;
+ }
+ for(int j=bsz;carry;j++){
+ assert(i+j<(int)res.digits.size());
+ longdigit_t newd=res.digits[i+j]+carry;
+ res.digits[i+j]=newd;
+ carry=newd>>digit_bits;
+ // cerr<<"(2) carry="<<carry<<endl;
+ }
+ }
+ res.sign=a.sign*b.sign;
+ res.shrink();
+ res.normalise();
+ res.checkconsistent();
+ return res;
+}
+
+void Bigint::shrink(){
+ while(digits.size()&&digits.back()==0)digits.pop_back();
+}
+
+void Bigint::normalise(){
+ if(digits.size()==0&&sign==-1)sign=1;
+}
+
+void Bigint::checkconsistent(){
+ assert(digits.size()==0||digits.back()!=0);
+ assert(digits.size()!=0||sign==1);
+}
+
+Bigint& Bigint::operator=(slongdigit_t v){
+ digits.resize(1);
+ sign=v>=0?1:-1;
+ v*=sign;
+ digits[0]=v;
+ if(v>digits[0])digits.push_back(v>>digit_bits);
+ shrink();
+ normalise();
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::operator+=(const Bigint &o){
+ if(&o==this){
+ return *this=Bigint(*this)+=o;
+ }
+ if(sign==1){
+ if(o.sign==1)add(*this,o);
+ else subtract(*this,o);
+ } else {
+ if(o.sign==1)subtract(*this,o);
+ else add(*this,o);
+ }
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::operator-=(const Bigint &o){
+ if(&o==this){
+ return *this=Bigint(*this)-=o;
+ }
+ if(sign==1){
+ if(o.sign==1)subtract(*this,o);
+ else add(*this,o);
+ } else {
+ if(o.sign==1)add(*this,o);
+ else subtract(*this,o);
+ }
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::operator*=(const Bigint &o){
+ *this=product(*this,o);
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::operator<<=(int sh){
+ if(sh==0)return *this;
+ if(digits.size()==0)return *this;
+ if(sh<0)return *this>>=-sh;
+ if(sh/digit_bits>0){
+ digits.insert(digits.begin(),sh/digit_bits,0);
+ sh%=digit_bits;
+ if(sh==0){
+ checkconsistent();
+ return *this;
+ }
+ }
+ digits.push_back(0);
+ for(int i=digits.size()-2;i>=0;i--){
+ digits[i+1]|=digits[i]>>(digit_bits-sh);
+ digits[i]<<=sh;
+ }
+ shrink();
+ normalise();
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::operator>>=(int sh){
+ if(sh==0)return *this;
+ if(digits.size()==0)return *this;
+ if(sh<0)return *this<<=-sh;
+ if(sh/digit_bits>0){
+ if(sh/digit_bits>=(int)digits.size()){
+ digits.clear();
+ sign=1;
+ checkconsistent();
+ return *this;
+ }
+ digits.erase(digits.begin(),digits.begin()+sh/digit_bits);
+ sh%=digit_bits;
+ if(sh==0){
+ checkconsistent();
+ return *this;
+ }
+ }
+ digits[0]>>=sh;
+ int sz=digits.size();
+ for(int i=1;i<sz;i++){
+ digits[i-1]|=digits[i]<<(digit_bits-sh);
+ digits[i]>>=sh;
+ }
+ shrink();
+ normalise();
+ checkconsistent();
+ return *this;
+}
+
+Bigint& Bigint::negate(){
+ sign=-sign;
+ return *this;
+}
+
+Bigint Bigint::operator+(const Bigint &o) const {
+ return Bigint(*this)+=o;
+}
+
+Bigint Bigint::operator-(const Bigint &o) const {
+ return Bigint(*this)-=o;
+}
+
+Bigint Bigint::operator*(const Bigint &o) const {
+ return product(*this,o);
+}
+
+int depthrecord;
+
+pair<Bigint,Bigint> Bigint::divmod(const Bigint &div) const {
+ pair<Bigint,Bigint> res=divmod(div,1);
+ //cerr<<hex<<*this<<' '<<div<<' ';
+ //cerr<<dec<<depthrecord<<endl;
+ return res;
+}
+
+inline void record(int depth){
+ depthrecord=depth;
+}
+
+pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const {
+ //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<')'<<endl;
+ if(div.digits.size()==0)throw domain_error("Bigint divide by zero");
+ if(digits.size()==0){record(depth); return make_pair(Bigint(0),Bigint(0));}
+
+ int cmp=compareAbs(div);
+ if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint(0));}
+ if(cmp<0){record(depth); return make_pair(Bigint(0),*this);}
+ //now *this is greater in magnitude than the divisor
+
+#if 0
+ int twoexp=bitcount()-div.bitcount();
+ if(twoexp<0)twoexp=0;
+ Bigint quotient(sign*div.sign);
+ quotient<<=twoexp;
+ Bigint guess(div); //guess == quotient * div
+ guess<<=twoexp;
+ //cerr<<"guess="<<hex<<guess<<endl;
+ //twoexp now becomes irrelevant
+#else
+ Bigint guess,quotient;
+ if(digits.size()==div.digits.size()){
+ quotient=digits.back()/div.digits.back();
+ guess=quotient;
+ guess*=div;
+ } else {
+ assert(digits.size()>div.digits.size());
+ assert(digits.size()>=2&&div.digits.size()>=1);
+ longdigit_t factor=((longdigit_t)1<<digit_bits)*digits.back()+digits[digits.size()-2];
+ factor/=div.digits.back()+1;
+ quotient=factor;
+ quotient<<=(digits.size()-div.digits.size()-1)*digit_bits;
+ guess=quotient;
+ guess*=div;
+ }
+ // cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<endl;
+#endif
+ cmp=guess.compareAbs(*this);
+ if(cmp<0){
+ Bigint guess2(guess);
+ while(true){
+ guess2<<=1;
+ int cmp=guess2.compareAbs(*this);
+ if(cmp>0)break;
+ guess<<=1;
+ quotient<<=1;
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
+ }
+ Bigint rest(*this);
+ rest-=guess; //also correct for *this and guess negative
+ pair<Bigint,Bigint> dm=rest.divmod(div,depth+1);
+ dm.first+=quotient;
+ return dm;
+ }
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
+ //then cmp>0, so our guess is too large
+ Bigint one(1);
+ do {
+ if(quotient.digits[0]&1){
+ guess-=div;
+ // quotient-=one; // not necessary, since we shift the bit out anyway
+ }
+ guess>>=1;
+ quotient>>=1;
+ cmp=guess.compareAbs(*this);
+ } while(cmp>0);
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
+ Bigint rest(*this);
+ rest-=guess;
+ pair<Bigint,Bigint> dm=rest.divmod(div,depth+1);
+ dm.first+=quotient;
+ return dm;
+}
+
+bool Bigint::operator==(const Bigint &o) const {return compare(o)==0;}
+bool Bigint::operator!=(const Bigint &o) const {return compare(o)!=0;}
+bool Bigint::operator<(const Bigint &o) const {return compare(o)<0;}
+bool Bigint::operator>(const Bigint &o) const {return compare(o)>0;}
+bool Bigint::operator<=(const Bigint &o) const {return compare(o)<=0;}
+bool Bigint::operator>=(const Bigint &o) const {return compare(o)>=0;}
+
+bool Bigint::operator==(slongdigit_t v) const {return compare(v)==0;}
+bool Bigint::operator!=(slongdigit_t v) const {return compare(v)!=0;}
+bool Bigint::operator<(slongdigit_t v) const {return compare(v)<0;}
+bool Bigint::operator>(slongdigit_t v) const {return compare(v)>0;}
+bool Bigint::operator<=(slongdigit_t v) const {return compare(v)<=0;}
+bool Bigint::operator>=(slongdigit_t v) const {return compare(v)>=0;}
+
+int Bigint::compare(const Bigint &o) const {
+ if(sign>o.sign)return 1;
+ if(sign<o.sign)return -1;
+ return sign*compareAbs(o);
+}
+
+int Bigint::compare(slongdigit_t v) const {
+ if(sign==-1&&v>=0)return -1;
+ if(sign==1&&v<0)return 1;
+ return sign*compareAbs(v);
+}
+
+int Bigint::compareAbs(const Bigint &o) const {
+ int sz=digits.size(),osz=o.digits.size();
+ if(sz>osz)return 1;
+ if(sz<osz)return -1;
+ for(int i=sz-1;i>=0;i--){
+ if(digits[i]>o.digits[i])return 1;
+ if(digits[i]<o.digits[i])return -1;
+ }
+ return 0;
+}
+
+int Bigint::compareAbs(slongdigit_t v) const {
+ v=abs(v);
+ if(digits.size()>2)return 1;
+ if(digits.size()==0)return v==0?0:-1;
+ if(digits.size()==2){
+ if(digits[1]>(digit_t)(v>>digit_bits))return 1;
+ if(digits[1]<(digit_t)(v>>digit_bits))return -1;
+ }
+ if(digits[0]<(digit_t)v)return -1;
+ if(digits[0]>(digit_t)v)return 1;
+ return 0;
+}
+
+int Bigint::bitcount() const {
+ if(digits.size()==0)return 0;
+ return (digits.size()-1)*8+ilog2(digits.back())+1;
+}
+
+Bigint::slongdigit_t Bigint::lowdigits() const {
+ if(digits.size()==0)return 0;
+ if(digits.size()==1)return digits[0];
+ longdigit_t mask=~((longdigit_t)1<<(digit_bits-1));
+ return ((slongdigit_t)1<<digit_bits)*(digits[1]&mask)+digits[0];
+}
+
+vector<char> Bigint::serialise() const {
+ vector<char> v(1+digits.size()*sizeof(digit_t));
+ v[0]=sign;
+ int sz=digits.size();
+ for(int i=0;i<sz;i++){
+ for(int j=0;j<(int)sizeof(digit_t);j++){
+ v[1+i*sizeof(digit_t)+j]=(digits[i]>>(8*j))&255;
+ }
+ }
+ return v;
+}
+
+void Bigint::deserialise(const vector<char> &v){
+ assert(v.size()%4==1);
+ sign=(int)v[0];
+ assert(sign==1||sign==-1);
+ int sz=v.size()/4;
+ digits.resize(sz);
+ for(int i=0;i<sz;i++){
+ digits[i]=0;
+ for(int j=0;j<(int)sizeof(digit_t);j++){
+ digits[i]|=v[1+i*sizeof(digit_t)+j]<<(8*j);
+ }
+ }
+ shrink();
+ normalise();
+ checkconsistent();
+}
+
+vector<bool> Bigint::bits() const {
+ if(digits.size()==0)return {};
+ vector<bool> v(digit_bits*(digits.size()-1)+ilog2(digits.back())+1);
+ int sz=digits.size();
+ for(int i=0;i<sz;i++){
+ digit_t dig=digits[i];
+ for(int j=0;dig;j++){
+ if(dig&1)v[digit_bits*i+j]=true;
+ dig>>=1;
+ }
+ }
+ return v;
+}
+
+Bigint::digit_t Bigint::_digit(int idx) const {
+ return digits.at(idx);
+}
+
+istream& operator>>(istream &is,Bigint &b){
+ while(isspace(is.peek()))is.get();
+ if(!is)return is;
+ b.digits.resize(0);
+ b.sign=1;
+ Bigint ten(10);
+ bool acted=false;
+ while(true){
+ char c=is.peek();
+ if(!isdigit(c))break;
+ acted=true;
+ is.get();
+ if(!is){
+ b.checkconsistent();
+ return is;
+ }
+ b*=ten;
+ b+=Bigint(c-'0');
+ // cerr<<"b="<<b<<endl;
+ }
+ if(!acted)is.setstate(ios_base::failbit);
+ b.checkconsistent();
+ return is;
+}
+
+std::ostream& operator<<(std::ostream &os,Bigint b){
+ if(b<0){
+ os<<'-';
+ b.negate();
+ }
+ if(os.flags()&ios_base::hex){
+ os<<"0x";
+ if(b.digits.size()==0)return os<<'0';
+ os<<b.digits.back();
+ for(int i=b.digits.size()-2;i>=0;i--){
+ os<<setw(Bigint::digit_bits/4)<<setfill('0')<<b.digits[i];
+ }
+ return os;
+ }
+#if 0
+ assert(b.digits.size()<=2);
+ if(b.sign==-1)os<<'-';
+ if(b.digits.size()==2)os<<((uint64_t)b.digits[1]<<32)+b.digits[0];
+ else if(b.digits.size()==1)os<<b.digits[0];
+ else os<<'0';
+ return os;
+#else
+ if(b==0)return os<<'0';
+ Bigint div(1000000000000000000LL);
+ vector<Bigint::longdigit_t> outbuf;
+ while(b!=0){
+ pair<Bigint,Bigint> dm=b.divmod(div);
+ b=dm.first;
+ Bigint::longdigit_t val=0;
+ assert(dm.second.digits.size()<=2);
+ if(dm.second.digits.size()>=2)
+ val+=((Bigint::longdigit_t)1<<Bigint::digit_bits)*dm.second.digits[1];
+ if(dm.second.digits.size()>=1)
+ val+=dm.second.digits[0];
+ outbuf.push_back(val);
+ }
+ for(int i=outbuf.size()-1;i>=0;i--){
+ (i==(int)outbuf.size()-1?os:os<<setfill('0')<<setw(18))<<outbuf[i];
+ }
+ return os;
+#endif
+}
diff --git a/bigint.h b/bigint.h
new file mode 100644
index 0000000..6579054
--- /dev/null
+++ b/bigint.h
@@ -0,0 +1,87 @@
+#pragma once
+
+#include <iostream>
+#include <vector>
+#include <utility>
+#include <cstdint>
+
+class Bigint{
+public:
+ using digit_t=uint32_t;
+ using longdigit_t=uint64_t;
+ using slongdigit_t=int64_t;
+ static const int digit_bits=8*sizeof(digit_t);
+
+private:
+ std::vector<digit_t> digits;
+ int sign;
+
+ static void add(Bigint&,const Bigint&); //ignores sign of arguments
+ static void subtract(Bigint&,const Bigint&); //ignores sign of arguments; assumes a>=b
+ static Bigint product(const Bigint&,const Bigint&);
+
+ void shrink();
+ void normalise();
+ void checkconsistent();
+
+ std::pair<Bigint,Bigint> divmod(const Bigint&,int depth) const;
+
+public:
+ Bigint();
+ Bigint(const Bigint&)=default;
+ Bigint(Bigint&&)=default;
+ explicit Bigint(slongdigit_t);
+
+ Bigint& operator=(const Bigint&)=default;
+ Bigint& operator=(Bigint&&)=default;
+ Bigint& operator=(slongdigit_t);
+
+ Bigint& operator+=(const Bigint&);
+ Bigint& operator-=(const Bigint&);
+ Bigint& operator*=(const Bigint&);
+ Bigint& operator<<=(int);
+ Bigint& operator>>=(int);
+ Bigint& negate();
+
+ Bigint operator+(const Bigint&) const;
+ Bigint operator-(const Bigint&) const;
+ Bigint operator*(const Bigint&) const;
+ Bigint operator<<(int) const;
+ Bigint operator>>(int) const;
+ std::pair<Bigint,Bigint> divmod(const Bigint&) const;
+
+ bool operator==(const Bigint&) const;
+ bool operator!=(const Bigint&) const;
+ bool operator<(const Bigint&) const;
+ bool operator>(const Bigint&) const;
+ bool operator<=(const Bigint&) const;
+ bool operator>=(const Bigint&) const;
+ bool operator==(slongdigit_t) const;
+ bool operator!=(slongdigit_t) const;
+ bool operator<(slongdigit_t) const;
+ bool operator>(slongdigit_t) const;
+ bool operator<=(slongdigit_t) const;
+ bool operator>=(slongdigit_t) const;
+
+ int compare(const Bigint&) const; //-1: <; 0: ==; 1: >
+ int compare(slongdigit_t) const;
+ int compareAbs(const Bigint&) const; //-1: <; 0: ==; 1: >; disregards sign
+ int compareAbs(slongdigit_t) const;
+
+ int bitcount() const;
+ slongdigit_t lowdigits() const;
+
+ std::vector<char> serialise() const;
+ void deserialise(const std::vector<char>&);
+ std::vector<bool> bits() const;
+
+ friend std::istream& operator>>(std::istream&,Bigint&);
+ friend std::ostream& operator<<(std::ostream&,Bigint);
+
+ digit_t _digit(int idx) const;
+};
+
+Bigint pow(const Bigint &b,const Bigint &ex);
+
+std::istream& operator>>(std::istream&,Bigint&);
+std::ostream& operator<<(std::ostream&,Bigint);
diff --git a/biginttest.py b/biginttest.py
new file mode 100755
index 0000000..f125953
--- /dev/null
+++ b/biginttest.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+import sys, random, subprocess
+
+ntimes=10000
+maxn=1e100
+
+def check(desc,x,y):
+ if x==y: return
+ print("{}: {} != {}".format(desc,x,y))
+ assert False
+
+def gendata():
+ for _ in range(ntimes):
+ yield random.randint(0,maxn), random.randint(1,maxn)
+
+def proctest():
+ proc=subprocess.Popen(["./main"],stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=sys.stderr)
+
+ for (a,b) in gendata():
+ proc.stdin.write("div {} {}\n".format(a,b).encode("ascii"))
+ proc.stdin.write("mod {} {}\n".format(a,b).encode("ascii"))
+ proc.stdin.flush()
+
+ ans=int(proc.stdout.readline())
+ check("{}/{}".format(a,b),ans,a//b)
+
+ ans=int(proc.stdout.readline())
+ check("{}%{}".format(a,b),ans,a%b)
+
+ proc.kill()
+
+def justprint():
+ for (a,b) in gendata():
+ print("div {} {}".format(a,b))
+ print("mod {} {}".format(a,b))
+
+#justprint()
+proctest()
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..ce1573f
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,141 @@
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+#include <cstdlib>
+#include <cctype>
+#include <ctime>
+#include <cassert>
+#include "bigint.h"
+#include "numalgo.h"
+#include "rsa.h"
+
+using namespace std;
+
+class eof_error : public runtime_error{
+public:
+ eof_error()
+ :runtime_error("EOF"){}
+};
+
+int64_t rand64(){
+ return ((int64_t)rand()<<32)+(((int64_t)rand()%2)<<31)+rand();
+}
+
+Bigint readevalexpr(istream &is){
+ Bigint a;
+ is>>a;
+ if(is.eof())throw eof_error();
+ // cerr<<"Read "<<a<<endl;
+ if(!is.fail())return a;
+ is.clear();
+ string s;
+ is>>s;
+ assert(!is.fail());
+ a=readevalexpr(is);
+ Bigint b=readevalexpr(is);
+ //cerr<<"Operation "<<s<<" on "<<a<<" and "<<b<<endl;
+ if(s=="add")return a+b;
+ else if(s=="sub")return a-b;
+ else if(s=="mul")return a*b;
+ else if(s=="div")return a.divmod(b).first;
+ else if(s=="mod")return a.divmod(b).second;
+ else {
+ cerr<<"Unknown operation '"<<s<<'\''<<endl;
+ assert(false);
+ }
+}
+
+void biginttest(){
+ srand(time(NULL));
+
+ // cerr<<Bigint(599428191)*Bigint(10)<<endl;
+ // cerr<<hex<<Bigint(599428191)*Bigint(10)<<endl;
+
+#if 1
+ {
+ Bigint bi;
+ assert(RAND_MAX==(1U<<31)-1);
+ for(int i=0;i<500000;i++){
+ int64_t a=rand64(),b=rand64();
+ if(a+b<0){i--; continue;}
+ stringstream s1,s2,s3;
+ s1<<a+b;
+ s2<<Bigint(a+b);
+ s3<<Bigint(a)+Bigint(b);
+ assert(s1.str()==s2.str()&&s1.str()==s3.str());
+ }
+ }
+#endif
+
+#if 1
+ {
+ for(int i=0;i<1000;i++){
+ int64_t n=rand64();
+ istringstream ss(to_string(n));
+ Bigint bi;
+ ss>>bi;
+ assert(bi==Bigint(n));
+ }
+ }
+#endif
+
+#if 1
+ {
+ string s="4405994068155852661780322209877856931246944549396705884037139443014164958640201650440984581318995014";
+ istringstream iss(s);
+ Bigint bi;
+ iss>>bi;
+ uint32_t digs[11]={1752788038,953502834,2175607868,1627159508,1754291416,1207689192,3196357285,3165170272,3313904421,3194703103,2062};
+ for(int i=0;i<11;i++)assert(bi._digit(i)==digs[i]);
+ ostringstream oss;
+ oss<<bi;
+ assert(oss.str()==s);
+ }
+#endif
+}
+
+void repl(int argc,char **argv){
+ istream *in;
+ if(argc==2)in=new ifstream(argv[1]);
+ else in=&cin;
+ for(int i=0;;i++){
+ try {
+ cout<<readevalexpr(*in)<<endl;
+ } catch(eof_error){
+ break;
+ }
+ }
+}
+
+void performrsa(){
+ PrivateKey privkey;
+ Bigint p(1000000007),q(3000000019);
+ privkey.pub.mod=3000000040000000133LL;
+ privkey.pub.exp=65537;
+ {
+ Bigint x;
+ Bigint one(1);
+ egcd((p-one)*(q-one),privkey.pub.exp,x,privkey.pexp);
+ }
+ cout<<"d = "<<privkey.pexp<<endl;
+ Bigint msg(123456789);
+ cout<<"msg = "<<msg<<endl;
+ Bigint encr=encrypt(privkey.pub,msg);
+ cout<<"encr = "<<encr<<endl;
+ Bigint msg2=decrypt(privkey,encr);
+ cout<<"msg = "<<msg2<<endl;
+}
+
+int main(int,char**){
+ // biginttest();
+ // repl(argc,argv);
+ performrsa();
+ // cout<<Bigint(0)-Bigint(69255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(669255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(5669255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(75669255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(775669255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(5775669255535LL)<<endl;
+ // cout<<Bigint(0)-Bigint(45775669255535LL)<<endl;
+}
diff --git a/numalgo.cpp b/numalgo.cpp
new file mode 100644
index 0000000..1db0763
--- /dev/null
+++ b/numalgo.cpp
@@ -0,0 +1,50 @@
+#include <cassert>
+#include "numalgo.h"
+
+using namespace std;
+
+Bigint gcd(Bigint a,Bigint b){
+ while(true){
+ if(a==0)return b;
+ if(b==0)return a;
+ if(a>=b)a=a.divmod(b).second;
+ else b=b.divmod(a).second;
+ }
+}
+
+Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y){
+ Bigint x2(0),y2(1),r(a),r2(b);
+ x=1; y=0;
+ //cerr<<x<<"\t * "<<a<<"\t + "<<y<<"\t * "<<b<<"\t = "<<r<<endl;
+ while(r2!=0){
+ pair<Bigint,Bigint> dm=r.divmod(r2);
+ //cerr<<x2<<"\t * "<<a<<"\t + "<<y2<<"\t * "<<b<<"\t = "<<r2<<" (q = "<<dm.first<<')'<<endl;
+ Bigint xn=x-dm.first*x2;
+ Bigint yn=y-dm.first*y2;
+ x=x2; x2=xn;
+ y=y2; y2=yn;
+ r=r2; r2=dm.second;
+ }
+ return r;
+}
+
+Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){
+ assert(e>=0);
+ assert(m>=1);
+ if(m==1)return Bigint(0);
+ Bigint res(1);
+ vector<bool> bits(e.bits());
+ for(int i=bits.size()-1;i>=0;i--){
+ res*=res;
+ if(bits[i])res*=b;
+ res=res.divmod(m).second;
+ }
+ return res;
+}
+
+int ilog2(uint64_t i){
+ assert(i);
+ int l=0;
+ while(i>>=1)l++;
+ return l;
+}
diff --git a/numalgo.h b/numalgo.h
new file mode 100644
index 0000000..71e06f0
--- /dev/null
+++ b/numalgo.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <cstdint>
+#include "bigint.h"
+
+Bigint gcd(Bigint a,Bigint b);
+Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y);
+
+Bigint expmod(const Bigint &base,const Bigint &exponent,const Bigint &modulus);
+
+int ilog2(uint64_t i);
diff --git a/rsa.cpp b/rsa.cpp
new file mode 100644
index 0000000..11bf0eb
--- /dev/null
+++ b/rsa.cpp
@@ -0,0 +1,12 @@
+#include <cassert>
+#include "numalgo.h"
+#include "rsa.h"
+
+Bigint encrypt(const PublicKey &pubkey,Bigint msg){
+ assert(msg>1&&msg<pubkey.mod);
+ return expmod(msg,pubkey.exp,pubkey.mod);
+}
+
+Bigint decrypt(const PrivateKey &privkey,Bigint encr){
+ return expmod(encr,privkey.pexp,privkey.pub.mod);
+}
diff --git a/rsa.h b/rsa.h
new file mode 100644
index 0000000..ec3c349
--- /dev/null
+++ b/rsa.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "bigint.h"
+
+struct PublicKey{
+ Bigint mod,exp;
+};
+
+struct PrivateKey{
+ PublicKey pub;
+ Bigint pexp;
+};
+
+Bigint encrypt(const PublicKey&,Bigint);
+Bigint decrypt(const PrivateKey&,Bigint);