diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | Makefile | 24 | ||||
-rw-r--r-- | bigint.cpp | 498 | ||||
-rw-r--r-- | bigint.h | 87 | ||||
-rwxr-xr-x | biginttest.py | 38 | ||||
-rw-r--r-- | main.cpp | 141 | ||||
-rw-r--r-- | numalgo.cpp | 50 | ||||
-rw-r--r-- | numalgo.h | 11 | ||||
-rw-r--r-- | rsa.cpp | 12 | ||||
-rw-r--r-- | rsa.h | 15 |
10 files changed, 880 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..89805cd --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.dSYM +main +*.txt diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..14220de --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +CXX = g++ +CXXFLAGS = -Wall -Wextra -std=c++11 -fwrapv +ifneq ($(DEBUG),) + CXXFLAGS += -g +else + CXXFLAGS += -O2 +endif +BIN = main + +.PHONY: all clean remake + +all: $(BIN) + +clean: + rm -rf $(BIN) *.o *.dSYM + +remake: clean all + + +$(BIN): $(patsubst %.cpp,%.o,$(wildcard *.cpp)) + $(CXX) -o $@ $^ + +%.o: %.cpp $(wildcard *.h) + $(CXX) $(CXXFLAGS) -c -o $@ $< diff --git a/bigint.cpp b/bigint.cpp new file mode 100644 index 0000000..5182d07 --- /dev/null +++ b/bigint.cpp @@ -0,0 +1,498 @@ +#include <iomanip> +#include <stdexcept> +#include <cctype> +#include <cassert> +#include "bigint.h" +#include "numalgo.h" + +using namespace std; + +Bigint::Bigint() + :sign(1){} + +Bigint::Bigint(slongdigit_t v) + :digits(1,abs(v)),sign(v>=0?1:-1){ + static_assert(sizeof(longdigit_t)==2*sizeof(digit_t), + "longdigit_t should be twice as large as digit_t"); + v=abs(v); + if(v>digits[0])digits.push_back(v>>digit_bits); + else if(v==0)digits.clear(); + checkconsistent(); +} + +void Bigint::add(Bigint &a,const Bigint &b){ + if(a.digits.size()<b.digits.size())a.digits.resize(b.digits.size()); + int sz=a.digits.size(); + int carry=0; + for(int i=0;i<sz;i++){ + longdigit_t bdig=i<(int)b.digits.size()?b.digits[i]:0; + longdigit_t sum=a.digits[i]+bdig+carry; + a.digits[i]=sum; + // carry=sum>=((longdigit_t)1<<digit_bits); + carry=sum>>digit_bits; + } + if(carry)a.digits.push_back(1); + a.normalise(); + a.checkconsistent(); +} + +void Bigint::subtract(Bigint &a,const Bigint &b){ + if(a.digits.size()<b.digits.size()){ + a.digits.resize(b.digits.size()); //adds zeros + } + assert(a.digits.size()>=b.digits.size()); + if(a.digits.size()==0){ + a.checkconsistent(); + return; + } + assert(a.digits.size()>0); + int sz=a.digits.size(); + int carry=0; + for(int i=0;i<sz;i++){ + if(i>=(int)b.digits.size()&&!carry)break; + digit_t adig=a.digits[i]; + digit_t bdig=i<(int)b.digits.size()?b.digits[i]:0; + digit_t res=adig-(bdig+carry); + // cerr<<"carry="<<carry<<" res="<<res<<" adig="<<adig<<" bdig="<<bdig<<endl; + carry=(bdig||carry)&&res>=adig; + a.digits[i]=res; + } + if(carry){ + // cerr<<"2s complement"<<endl; + //we do a fake 2s complement, sort of + carry=0; + for(int i=0;i<sz;i++){ + a.digits[i]=~a.digits[i]; + a.digits[i]+=(i==0)+carry; + carry=a.digits[i]<=(digit_t)carry; + } + a.sign=-a.sign; + } + a.shrink(); + a.normalise(); + a.checkconsistent(); +} + +Bigint Bigint::product(const Bigint &a,const Bigint &b){ + int asz=a.digits.size(),bsz=b.digits.size(); + if(asz==0||bsz==0)return Bigint(); + Bigint res; + res.digits.resize(asz+bsz); + for(int i=0;i<asz;i++){ + digit_t carry=0; + for(int j=0;j<bsz;j++){ + longdigit_t pr=(longdigit_t)a.digits[i]*b.digits[j]+carry; + longdigit_t newd=pr+res.digits[i+j]; //this always fits, I checked + res.digits[i+j]=(digit_t)newd; + carry=newd>>digit_bits; + // cerr<<"carry="<<carry<<endl; + } + for(int j=bsz;carry;j++){ + assert(i+j<(int)res.digits.size()); + longdigit_t newd=res.digits[i+j]+carry; + res.digits[i+j]=newd; + carry=newd>>digit_bits; + // cerr<<"(2) carry="<<carry<<endl; + } + } + res.sign=a.sign*b.sign; + res.shrink(); + res.normalise(); + res.checkconsistent(); + return res; +} + +void Bigint::shrink(){ + while(digits.size()&&digits.back()==0)digits.pop_back(); +} + +void Bigint::normalise(){ + if(digits.size()==0&&sign==-1)sign=1; +} + +void Bigint::checkconsistent(){ + assert(digits.size()==0||digits.back()!=0); + assert(digits.size()!=0||sign==1); +} + +Bigint& Bigint::operator=(slongdigit_t v){ + digits.resize(1); + sign=v>=0?1:-1; + v*=sign; + digits[0]=v; + if(v>digits[0])digits.push_back(v>>digit_bits); + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator+=(const Bigint &o){ + if(&o==this){ + return *this=Bigint(*this)+=o; + } + if(sign==1){ + if(o.sign==1)add(*this,o); + else subtract(*this,o); + } else { + if(o.sign==1)subtract(*this,o); + else add(*this,o); + } + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator-=(const Bigint &o){ + if(&o==this){ + return *this=Bigint(*this)-=o; + } + if(sign==1){ + if(o.sign==1)subtract(*this,o); + else add(*this,o); + } else { + if(o.sign==1)add(*this,o); + else subtract(*this,o); + } + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator*=(const Bigint &o){ + *this=product(*this,o); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator<<=(int sh){ + if(sh==0)return *this; + if(digits.size()==0)return *this; + if(sh<0)return *this>>=-sh; + if(sh/digit_bits>0){ + digits.insert(digits.begin(),sh/digit_bits,0); + sh%=digit_bits; + if(sh==0){ + checkconsistent(); + return *this; + } + } + digits.push_back(0); + for(int i=digits.size()-2;i>=0;i--){ + digits[i+1]|=digits[i]>>(digit_bits-sh); + digits[i]<<=sh; + } + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::operator>>=(int sh){ + if(sh==0)return *this; + if(digits.size()==0)return *this; + if(sh<0)return *this<<=-sh; + if(sh/digit_bits>0){ + if(sh/digit_bits>=(int)digits.size()){ + digits.clear(); + sign=1; + checkconsistent(); + return *this; + } + digits.erase(digits.begin(),digits.begin()+sh/digit_bits); + sh%=digit_bits; + if(sh==0){ + checkconsistent(); + return *this; + } + } + digits[0]>>=sh; + int sz=digits.size(); + for(int i=1;i<sz;i++){ + digits[i-1]|=digits[i]<<(digit_bits-sh); + digits[i]>>=sh; + } + shrink(); + normalise(); + checkconsistent(); + return *this; +} + +Bigint& Bigint::negate(){ + sign=-sign; + return *this; +} + +Bigint Bigint::operator+(const Bigint &o) const { + return Bigint(*this)+=o; +} + +Bigint Bigint::operator-(const Bigint &o) const { + return Bigint(*this)-=o; +} + +Bigint Bigint::operator*(const Bigint &o) const { + return product(*this,o); +} + +int depthrecord; + +pair<Bigint,Bigint> Bigint::divmod(const Bigint &div) const { + pair<Bigint,Bigint> res=divmod(div,1); + //cerr<<hex<<*this<<' '<<div<<' '; + //cerr<<dec<<depthrecord<<endl; + return res; +} + +inline void record(int depth){ + depthrecord=depth; +} + +pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const { + //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<')'<<endl; + if(div.digits.size()==0)throw domain_error("Bigint divide by zero"); + if(digits.size()==0){record(depth); return make_pair(Bigint(0),Bigint(0));} + + int cmp=compareAbs(div); + if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint(0));} + if(cmp<0){record(depth); return make_pair(Bigint(0),*this);} + //now *this is greater in magnitude than the divisor + +#if 0 + int twoexp=bitcount()-div.bitcount(); + if(twoexp<0)twoexp=0; + Bigint quotient(sign*div.sign); + quotient<<=twoexp; + Bigint guess(div); //guess == quotient * div + guess<<=twoexp; + //cerr<<"guess="<<hex<<guess<<endl; + //twoexp now becomes irrelevant +#else + Bigint guess,quotient; + if(digits.size()==div.digits.size()){ + quotient=digits.back()/div.digits.back(); + guess=quotient; + guess*=div; + } else { + assert(digits.size()>div.digits.size()); + assert(digits.size()>=2&&div.digits.size()>=1); + longdigit_t factor=((longdigit_t)1<<digit_bits)*digits.back()+digits[digits.size()-2]; + factor/=div.digits.back()+1; + quotient=factor; + quotient<<=(digits.size()-div.digits.size()-1)*digit_bits; + guess=quotient; + guess*=div; + } + // cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<endl; +#endif + cmp=guess.compareAbs(*this); + if(cmp<0){ + Bigint guess2(guess); + while(true){ + guess2<<=1; + int cmp=guess2.compareAbs(*this); + if(cmp>0)break; + guess<<=1; + quotient<<=1; + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + } + Bigint rest(*this); + rest-=guess; //also correct for *this and guess negative + pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); + dm.first+=quotient; + return dm; + } + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + //then cmp>0, so our guess is too large + Bigint one(1); + do { + if(quotient.digits[0]&1){ + guess-=div; + // quotient-=one; // not necessary, since we shift the bit out anyway + } + guess>>=1; + quotient>>=1; + cmp=guess.compareAbs(*this); + } while(cmp>0); + if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + Bigint rest(*this); + rest-=guess; + pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); + dm.first+=quotient; + return dm; +} + +bool Bigint::operator==(const Bigint &o) const {return compare(o)==0;} +bool Bigint::operator!=(const Bigint &o) const {return compare(o)!=0;} +bool Bigint::operator<(const Bigint &o) const {return compare(o)<0;} +bool Bigint::operator>(const Bigint &o) const {return compare(o)>0;} +bool Bigint::operator<=(const Bigint &o) const {return compare(o)<=0;} +bool Bigint::operator>=(const Bigint &o) const {return compare(o)>=0;} + +bool Bigint::operator==(slongdigit_t v) const {return compare(v)==0;} +bool Bigint::operator!=(slongdigit_t v) const {return compare(v)!=0;} +bool Bigint::operator<(slongdigit_t v) const {return compare(v)<0;} +bool Bigint::operator>(slongdigit_t v) const {return compare(v)>0;} +bool Bigint::operator<=(slongdigit_t v) const {return compare(v)<=0;} +bool Bigint::operator>=(slongdigit_t v) const {return compare(v)>=0;} + +int Bigint::compare(const Bigint &o) const { + if(sign>o.sign)return 1; + if(sign<o.sign)return -1; + return sign*compareAbs(o); +} + +int Bigint::compare(slongdigit_t v) const { + if(sign==-1&&v>=0)return -1; + if(sign==1&&v<0)return 1; + return sign*compareAbs(v); +} + +int Bigint::compareAbs(const Bigint &o) const { + int sz=digits.size(),osz=o.digits.size(); + if(sz>osz)return 1; + if(sz<osz)return -1; + for(int i=sz-1;i>=0;i--){ + if(digits[i]>o.digits[i])return 1; + if(digits[i]<o.digits[i])return -1; + } + return 0; +} + +int Bigint::compareAbs(slongdigit_t v) const { + v=abs(v); + if(digits.size()>2)return 1; + if(digits.size()==0)return v==0?0:-1; + if(digits.size()==2){ + if(digits[1]>(digit_t)(v>>digit_bits))return 1; + if(digits[1]<(digit_t)(v>>digit_bits))return -1; + } + if(digits[0]<(digit_t)v)return -1; + if(digits[0]>(digit_t)v)return 1; + return 0; +} + +int Bigint::bitcount() const { + if(digits.size()==0)return 0; + return (digits.size()-1)*8+ilog2(digits.back())+1; +} + +Bigint::slongdigit_t Bigint::lowdigits() const { + if(digits.size()==0)return 0; + if(digits.size()==1)return digits[0]; + longdigit_t mask=~((longdigit_t)1<<(digit_bits-1)); + return ((slongdigit_t)1<<digit_bits)*(digits[1]&mask)+digits[0]; +} + +vector<char> Bigint::serialise() const { + vector<char> v(1+digits.size()*sizeof(digit_t)); + v[0]=sign; + int sz=digits.size(); + for(int i=0;i<sz;i++){ + for(int j=0;j<(int)sizeof(digit_t);j++){ + v[1+i*sizeof(digit_t)+j]=(digits[i]>>(8*j))&255; + } + } + return v; +} + +void Bigint::deserialise(const vector<char> &v){ + assert(v.size()%4==1); + sign=(int)v[0]; + assert(sign==1||sign==-1); + int sz=v.size()/4; + digits.resize(sz); + for(int i=0;i<sz;i++){ + digits[i]=0; + for(int j=0;j<(int)sizeof(digit_t);j++){ + digits[i]|=v[1+i*sizeof(digit_t)+j]<<(8*j); + } + } + shrink(); + normalise(); + checkconsistent(); +} + +vector<bool> Bigint::bits() const { + if(digits.size()==0)return {}; + vector<bool> v(digit_bits*(digits.size()-1)+ilog2(digits.back())+1); + int sz=digits.size(); + for(int i=0;i<sz;i++){ + digit_t dig=digits[i]; + for(int j=0;dig;j++){ + if(dig&1)v[digit_bits*i+j]=true; + dig>>=1; + } + } + return v; +} + +Bigint::digit_t Bigint::_digit(int idx) const { + return digits.at(idx); +} + +istream& operator>>(istream &is,Bigint &b){ + while(isspace(is.peek()))is.get(); + if(!is)return is; + b.digits.resize(0); + b.sign=1; + Bigint ten(10); + bool acted=false; + while(true){ + char c=is.peek(); + if(!isdigit(c))break; + acted=true; + is.get(); + if(!is){ + b.checkconsistent(); + return is; + } + b*=ten; + b+=Bigint(c-'0'); + // cerr<<"b="<<b<<endl; + } + if(!acted)is.setstate(ios_base::failbit); + b.checkconsistent(); + return is; +} + +std::ostream& operator<<(std::ostream &os,Bigint b){ + if(b<0){ + os<<'-'; + b.negate(); + } + if(os.flags()&ios_base::hex){ + os<<"0x"; + if(b.digits.size()==0)return os<<'0'; + os<<b.digits.back(); + for(int i=b.digits.size()-2;i>=0;i--){ + os<<setw(Bigint::digit_bits/4)<<setfill('0')<<b.digits[i]; + } + return os; + } +#if 0 + assert(b.digits.size()<=2); + if(b.sign==-1)os<<'-'; + if(b.digits.size()==2)os<<((uint64_t)b.digits[1]<<32)+b.digits[0]; + else if(b.digits.size()==1)os<<b.digits[0]; + else os<<'0'; + return os; +#else + if(b==0)return os<<'0'; + Bigint div(1000000000000000000LL); + vector<Bigint::longdigit_t> outbuf; + while(b!=0){ + pair<Bigint,Bigint> dm=b.divmod(div); + b=dm.first; + Bigint::longdigit_t val=0; + assert(dm.second.digits.size()<=2); + if(dm.second.digits.size()>=2) + val+=((Bigint::longdigit_t)1<<Bigint::digit_bits)*dm.second.digits[1]; + if(dm.second.digits.size()>=1) + val+=dm.second.digits[0]; + outbuf.push_back(val); + } + for(int i=outbuf.size()-1;i>=0;i--){ + (i==(int)outbuf.size()-1?os:os<<setfill('0')<<setw(18))<<outbuf[i]; + } + return os; +#endif +} diff --git a/bigint.h b/bigint.h new file mode 100644 index 0000000..6579054 --- /dev/null +++ b/bigint.h @@ -0,0 +1,87 @@ +#pragma once + +#include <iostream> +#include <vector> +#include <utility> +#include <cstdint> + +class Bigint{ +public: + using digit_t=uint32_t; + using longdigit_t=uint64_t; + using slongdigit_t=int64_t; + static const int digit_bits=8*sizeof(digit_t); + +private: + std::vector<digit_t> digits; + int sign; + + static void add(Bigint&,const Bigint&); //ignores sign of arguments + static void subtract(Bigint&,const Bigint&); //ignores sign of arguments; assumes a>=b + static Bigint product(const Bigint&,const Bigint&); + + void shrink(); + void normalise(); + void checkconsistent(); + + std::pair<Bigint,Bigint> divmod(const Bigint&,int depth) const; + +public: + Bigint(); + Bigint(const Bigint&)=default; + Bigint(Bigint&&)=default; + explicit Bigint(slongdigit_t); + + Bigint& operator=(const Bigint&)=default; + Bigint& operator=(Bigint&&)=default; + Bigint& operator=(slongdigit_t); + + Bigint& operator+=(const Bigint&); + Bigint& operator-=(const Bigint&); + Bigint& operator*=(const Bigint&); + Bigint& operator<<=(int); + Bigint& operator>>=(int); + Bigint& negate(); + + Bigint operator+(const Bigint&) const; + Bigint operator-(const Bigint&) const; + Bigint operator*(const Bigint&) const; + Bigint operator<<(int) const; + Bigint operator>>(int) const; + std::pair<Bigint,Bigint> divmod(const Bigint&) const; + + bool operator==(const Bigint&) const; + bool operator!=(const Bigint&) const; + bool operator<(const Bigint&) const; + bool operator>(const Bigint&) const; + bool operator<=(const Bigint&) const; + bool operator>=(const Bigint&) const; + bool operator==(slongdigit_t) const; + bool operator!=(slongdigit_t) const; + bool operator<(slongdigit_t) const; + bool operator>(slongdigit_t) const; + bool operator<=(slongdigit_t) const; + bool operator>=(slongdigit_t) const; + + int compare(const Bigint&) const; //-1: <; 0: ==; 1: > + int compare(slongdigit_t) const; + int compareAbs(const Bigint&) const; //-1: <; 0: ==; 1: >; disregards sign + int compareAbs(slongdigit_t) const; + + int bitcount() const; + slongdigit_t lowdigits() const; + + std::vector<char> serialise() const; + void deserialise(const std::vector<char>&); + std::vector<bool> bits() const; + + friend std::istream& operator>>(std::istream&,Bigint&); + friend std::ostream& operator<<(std::ostream&,Bigint); + + digit_t _digit(int idx) const; +}; + +Bigint pow(const Bigint &b,const Bigint &ex); + +std::istream& operator>>(std::istream&,Bigint&); +std::ostream& operator<<(std::ostream&,Bigint); diff --git a/biginttest.py b/biginttest.py new file mode 100755 index 0000000..f125953 --- /dev/null +++ b/biginttest.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +import sys, random, subprocess + +ntimes=10000 +maxn=1e100 + +def check(desc,x,y): + if x==y: return + print("{}: {} != {}".format(desc,x,y)) + assert False + +def gendata(): + for _ in range(ntimes): + yield random.randint(0,maxn), random.randint(1,maxn) + +def proctest(): + proc=subprocess.Popen(["./main"],stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=sys.stderr) + + for (a,b) in gendata(): + proc.stdin.write("div {} {}\n".format(a,b).encode("ascii")) + proc.stdin.write("mod {} {}\n".format(a,b).encode("ascii")) + proc.stdin.flush() + + ans=int(proc.stdout.readline()) + check("{}/{}".format(a,b),ans,a//b) + + ans=int(proc.stdout.readline()) + check("{}%{}".format(a,b),ans,a%b) + + proc.kill() + +def justprint(): + for (a,b) in gendata(): + print("div {} {}".format(a,b)) + print("mod {} {}".format(a,b)) + +#justprint() +proctest() diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..ce1573f --- /dev/null +++ b/main.cpp @@ -0,0 +1,141 @@ +#include <iostream> +#include <fstream> +#include <sstream> +#include <stdexcept> +#include <cstdlib> +#include <cctype> +#include <ctime> +#include <cassert> +#include "bigint.h" +#include "numalgo.h" +#include "rsa.h" + +using namespace std; + +class eof_error : public runtime_error{ +public: + eof_error() + :runtime_error("EOF"){} +}; + +int64_t rand64(){ + return ((int64_t)rand()<<32)+(((int64_t)rand()%2)<<31)+rand(); +} + +Bigint readevalexpr(istream &is){ + Bigint a; + is>>a; + if(is.eof())throw eof_error(); + // cerr<<"Read "<<a<<endl; + if(!is.fail())return a; + is.clear(); + string s; + is>>s; + assert(!is.fail()); + a=readevalexpr(is); + Bigint b=readevalexpr(is); + //cerr<<"Operation "<<s<<" on "<<a<<" and "<<b<<endl; + if(s=="add")return a+b; + else if(s=="sub")return a-b; + else if(s=="mul")return a*b; + else if(s=="div")return a.divmod(b).first; + else if(s=="mod")return a.divmod(b).second; + else { + cerr<<"Unknown operation '"<<s<<'\''<<endl; + assert(false); + } +} + +void biginttest(){ + srand(time(NULL)); + + // cerr<<Bigint(599428191)*Bigint(10)<<endl; + // cerr<<hex<<Bigint(599428191)*Bigint(10)<<endl; + +#if 1 + { + Bigint bi; + assert(RAND_MAX==(1U<<31)-1); + for(int i=0;i<500000;i++){ + int64_t a=rand64(),b=rand64(); + if(a+b<0){i--; continue;} + stringstream s1,s2,s3; + s1<<a+b; + s2<<Bigint(a+b); + s3<<Bigint(a)+Bigint(b); + assert(s1.str()==s2.str()&&s1.str()==s3.str()); + } + } +#endif + +#if 1 + { + for(int i=0;i<1000;i++){ + int64_t n=rand64(); + istringstream ss(to_string(n)); + Bigint bi; + ss>>bi; + assert(bi==Bigint(n)); + } + } +#endif + +#if 1 + { + string s="4405994068155852661780322209877856931246944549396705884037139443014164958640201650440984581318995014"; + istringstream iss(s); + Bigint bi; + iss>>bi; + uint32_t digs[11]={1752788038,953502834,2175607868,1627159508,1754291416,1207689192,3196357285,3165170272,3313904421,3194703103,2062}; + for(int i=0;i<11;i++)assert(bi._digit(i)==digs[i]); + ostringstream oss; + oss<<bi; + assert(oss.str()==s); + } +#endif +} + +void repl(int argc,char **argv){ + istream *in; + if(argc==2)in=new ifstream(argv[1]); + else in=&cin; + for(int i=0;;i++){ + try { + cout<<readevalexpr(*in)<<endl; + } catch(eof_error){ + break; + } + } +} + +void performrsa(){ + PrivateKey privkey; + Bigint p(1000000007),q(3000000019); + privkey.pub.mod=3000000040000000133LL; + privkey.pub.exp=65537; + { + Bigint x; + Bigint one(1); + egcd((p-one)*(q-one),privkey.pub.exp,x,privkey.pexp); + } + cout<<"d = "<<privkey.pexp<<endl; + Bigint msg(123456789); + cout<<"msg = "<<msg<<endl; + Bigint encr=encrypt(privkey.pub,msg); + cout<<"encr = "<<encr<<endl; + Bigint msg2=decrypt(privkey,encr); + cout<<"msg = "<<msg2<<endl; +} + +int main(int,char**){ + // biginttest(); + // repl(argc,argv); + performrsa(); + // cout<<Bigint(0)-Bigint(69255535LL)<<endl; + // cout<<Bigint(0)-Bigint(669255535LL)<<endl; + // cout<<Bigint(0)-Bigint(5669255535LL)<<endl; + // cout<<Bigint(0)-Bigint(75669255535LL)<<endl; + // cout<<Bigint(0)-Bigint(775669255535LL)<<endl; + // cout<<Bigint(0)-Bigint(5775669255535LL)<<endl; + // cout<<Bigint(0)-Bigint(45775669255535LL)<<endl; +} diff --git a/numalgo.cpp b/numalgo.cpp new file mode 100644 index 0000000..1db0763 --- /dev/null +++ b/numalgo.cpp @@ -0,0 +1,50 @@ +#include <cassert> +#include "numalgo.h" + +using namespace std; + +Bigint gcd(Bigint a,Bigint b){ + while(true){ + if(a==0)return b; + if(b==0)return a; + if(a>=b)a=a.divmod(b).second; + else b=b.divmod(a).second; + } +} + +Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y){ + Bigint x2(0),y2(1),r(a),r2(b); + x=1; y=0; + //cerr<<x<<"\t * "<<a<<"\t + "<<y<<"\t * "<<b<<"\t = "<<r<<endl; + while(r2!=0){ + pair<Bigint,Bigint> dm=r.divmod(r2); + //cerr<<x2<<"\t * "<<a<<"\t + "<<y2<<"\t * "<<b<<"\t = "<<r2<<" (q = "<<dm.first<<')'<<endl; + Bigint xn=x-dm.first*x2; + Bigint yn=y-dm.first*y2; + x=x2; x2=xn; + y=y2; y2=yn; + r=r2; r2=dm.second; + } + return r; +} + +Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){ + assert(e>=0); + assert(m>=1); + if(m==1)return Bigint(0); + Bigint res(1); + vector<bool> bits(e.bits()); + for(int i=bits.size()-1;i>=0;i--){ + res*=res; + if(bits[i])res*=b; + res=res.divmod(m).second; + } + return res; +} + +int ilog2(uint64_t i){ + assert(i); + int l=0; + while(i>>=1)l++; + return l; +} diff --git a/numalgo.h b/numalgo.h new file mode 100644 index 0000000..71e06f0 --- /dev/null +++ b/numalgo.h @@ -0,0 +1,11 @@ +#pragma once + +#include <cstdint> +#include "bigint.h" + +Bigint gcd(Bigint a,Bigint b); +Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y); + +Bigint expmod(const Bigint &base,const Bigint &exponent,const Bigint &modulus); + +int ilog2(uint64_t i); @@ -0,0 +1,12 @@ +#include <cassert> +#include "numalgo.h" +#include "rsa.h" + +Bigint encrypt(const PublicKey &pubkey,Bigint msg){ + assert(msg>1&&msg<pubkey.mod); + return expmod(msg,pubkey.exp,pubkey.mod); +} + +Bigint decrypt(const PrivateKey &privkey,Bigint encr){ + return expmod(encr,privkey.pexp,privkey.pub.mod); +} @@ -0,0 +1,15 @@ +#pragma once + +#include "bigint.h" + +struct PublicKey{ + Bigint mod,exp; +}; + +struct PrivateKey{ + PublicKey pub; + Bigint pexp; +}; + +Bigint encrypt(const PublicKey&,Bigint); +Bigint decrypt(const PrivateKey&,Bigint); |