diff options
-rw-r--r-- | bigint.cpp | 106 | ||||
-rw-r--r-- | bigint.h | 17 | ||||
-rw-r--r-- | main.cpp | 34 | ||||
-rw-r--r-- | numalgo.cpp | 61 | ||||
-rw-r--r-- | numalgo.h | 5 | ||||
-rw-r--r-- | primes.cpp | 82 | ||||
-rw-r--r-- | primes.h | 19 |
7 files changed, 292 insertions, 32 deletions
@@ -7,6 +7,10 @@ using namespace std; +Bigint Bigint::zero(0); +Bigint Bigint::one(1); +Bigint Bigint::mone(-1); + Bigint::Bigint() :sign(1){} @@ -163,6 +167,11 @@ Bigint& Bigint::operator*=(const Bigint &o){ return *this; } +//TODO: optimise these functions +Bigint& Bigint::operator+=(slongdigit_t n){return *this+=Bigint(n);} +Bigint& Bigint::operator-=(slongdigit_t n){return *this-=Bigint(n);} +Bigint& Bigint::operator*=(slongdigit_t n){return *this*=Bigint(n);} + Bigint& Bigint::operator<<=(int sh){ if(sh==0)return *this; if(digits.size()==0)return *this; @@ -233,12 +242,26 @@ Bigint Bigint::operator*(const Bigint &o) const { return product(*this,o); } +//TODO: optimise these functions +Bigint Bigint::operator+(slongdigit_t n) const {return *this+Bigint(n);} +Bigint Bigint::operator-(slongdigit_t n) const {return *this-Bigint(n);} +Bigint Bigint::operator*(slongdigit_t n) const {return *this*Bigint(n);} + +Bigint Bigint::operator<<(int sh) const { + return Bigint(*this)<<=sh; +} + +Bigint Bigint::operator>>(int sh) const { + return Bigint(*this)>>=sh; +} + int depthrecord; pair<Bigint,Bigint> Bigint::divmod(const Bigint &div) const { pair<Bigint,Bigint> res=divmod(div,1); //cerr<<hex<<*this<<' '<<div<<' '; //cerr<<dec<<depthrecord<<endl; + //cerr<<" -> "<<hex<<res.first<<' '<<res.second<<dec<<endl; return res; } @@ -247,13 +270,14 @@ inline void record(int depth){ } pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const { - //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<')'<<endl; + if(depth>100)assert(false); + //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<dec<<") depth="<<depth<<endl; if(div.digits.size()==0)throw domain_error("Bigint divide by zero"); - if(digits.size()==0){record(depth); return make_pair(Bigint(0),Bigint(0));} + if(digits.size()==0){record(depth); return make_pair(Bigint::zero,Bigint::zero);} int cmp=compareAbs(div); - if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint(0));} - if(cmp<0){record(depth); return make_pair(Bigint(0),*this);} + if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint::zero);} + if(cmp<0){record(depth); return make_pair(Bigint::zero,*this);} //now *this is greater in magnitude than the divisor #if 0 @@ -266,9 +290,11 @@ pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const { //cerr<<"guess="<<hex<<guess<<endl; //twoexp now becomes irrelevant #else +#if 0 Bigint guess,quotient; if(digits.size()==div.digits.size()){ quotient=digits.back()/div.digits.back(); + quotient.sign=sign*div.sign; guess=quotient; guess*=div; } else { @@ -277,42 +303,81 @@ pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const { longdigit_t factor=((longdigit_t)1<<digit_bits)*digits.back()+digits[digits.size()-2]; factor/=div.digits.back()+1; quotient=factor; - quotient<<=(digits.size()-div.digits.size()-1)*digit_bits; - guess=quotient; - guess*=div; + quotient<<=(digits.size()-1-div.digits.size())*digit_bits; + quotient.sign=sign*div.sign; + guess=quotient*div; } - // cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<endl; + cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<dec<<endl; +#else + int thisbtc=bitcount(),divbtc=div.bitcount(); + assert(divbtc<=thisbtc); + Bigint quotient,guess; + if(thisbtc<=2*digit_bits){ + //simple integral division + longdigit_t thisnum=(digits.size()==2?((longdigit_t)1<<digit_bits)*digits[1]:0)+digits[0]; + longdigit_t divnum=(div.digits.size()==2?((longdigit_t)1<<digit_bits)*div.digits[1]:0)+div.digits[0]; + if(divnum==1){record(depth); return make_pair(*this,Bigint::zero);} + record(depth); return make_pair( + Bigint(sign*div.sign*(slongdigit_t)(thisnum/divnum)), + Bigint(sign*div.sign*(slongdigit_t)(thisnum%divnum))); + } else if(divbtc>=digit_bits){ //the large case + //take 2 digits of *this and 1 digit of div; quotient gives a good guess + int spill=__builtin_clz(digits.back()); + //cerr<<"spill="<<spill<<endl; + longdigit_t thishead2=((longdigit_t)digits.back()<<(spill+digit_bits))|((longdigit_t)digits[digits.size()-2]<<spill); + if(spill>0)thishead2|=digits[digits.size()-3]>>(digit_bits-spill); + //cerr<<"thishead2="<<hex<<thishead2<<dec<<endl; + longdigit_t divhead=((longdigit_t)div.digits.back()<<digit_bits)|div.digits[div.digits.size()-2]; + divhead>>=digit_bits-__builtin_clz(div.digits.back()); + //cerr<<"divhead="<<hex<<divhead<<dec<<endl; + longdigit_t factor=thishead2/(divhead+1); //+1 to make sure the quotient guess is <= the actual quotient + quotient=factor; + quotient<<=thisbtc-digit_bits-divbtc; //shift amount may be negative if thisbtc and divbtc < digit_bits apart + quotient.sign=sign*div.sign; + guess=quotient*div; + } else { //divbtc<digit_bits, but *this is large + //take 2 digits of *this and all of div + int spill=__builtin_clz(digits.back()); + longdigit_t thishead2=((longdigit_t)digits.back()<<(spill+digit_bits))|(digits[digits.size()-2]<<spill); + if(spill>0)thishead2|=digits[digits.size()-3]>>(digit_bits-spill); + longdigit_t factor=thishead2/(div.digits.back()+1); //+1 to make sure the quotient guess is <= the actual quotient + quotient=factor; + quotient<<=thisbtc-2*digit_bits; + quotient.sign=sign*div.sign; + guess=quotient*div; + } + //cerr<<"guess= "<<hex<<guess<<" quotient="<<quotient<<dec<<endl; +#endif #endif cmp=guess.compareAbs(*this); if(cmp<0){ - Bigint guess2(guess); + /*Bigint guess2(guess); while(true){ guess2<<=1; int cmp=guess2.compareAbs(*this); if(cmp>0)break; guess<<=1; quotient<<=1; - if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} - } + if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);} + }*/ Bigint rest(*this); rest-=guess; //also correct for *this and guess negative pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); dm.first+=quotient; return dm; } - if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);} //then cmp>0, so our guess is too large - Bigint one(1); do { if(quotient.digits[0]&1){ guess-=div; - // quotient-=one; // not necessary, since we shift the bit out anyway + // quotient-=Bigint::one; // not necessary, since we shift the bit out anyway } guess>>=1; quotient>>=1; cmp=guess.compareAbs(*this); } while(cmp>0); - if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));} + if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);} Bigint rest(*this); rest-=guess; pair<Bigint,Bigint> dm=rest.divmod(div,depth+1); @@ -372,7 +437,7 @@ int Bigint::compareAbs(slongdigit_t v) const { int Bigint::bitcount() const { if(digits.size()==0)return 0; - return (digits.size()-1)*8+ilog2(digits.back())+1; + return (digits.size()-1)*digit_bits+ilog2(digits.back())+1; } Bigint::slongdigit_t Bigint::lowdigits() const { @@ -382,6 +447,13 @@ Bigint::slongdigit_t Bigint::lowdigits() const { return ((slongdigit_t)1<<digit_bits)*(digits[1]&mask)+digits[0]; } +bool Bigint::even() const { + return digits.size()==0||(digits[0]&1)==0; +} +bool Bigint::odd() const { + return !even(); +} + vector<char> Bigint::serialise() const { vector<char> v(1+digits.size()*sizeof(digit_t)); v[0]=sign; @@ -446,7 +518,7 @@ istream& operator>>(istream &is,Bigint &b){ return is; } b*=ten; - b+=Bigint(c-'0'); + b+=c-'0'; // cerr<<"b="<<b<<endl; } if(!acted)is.setstate(ios_base::failbit); @@ -39,6 +39,9 @@ public: Bigint& operator+=(const Bigint&); Bigint& operator-=(const Bigint&); Bigint& operator*=(const Bigint&); + Bigint& operator+=(slongdigit_t); + Bigint& operator-=(slongdigit_t); + Bigint& operator*=(slongdigit_t); Bigint& operator<<=(int); Bigint& operator>>=(int); Bigint& negate(); @@ -46,9 +49,12 @@ public: Bigint operator+(const Bigint&) const; Bigint operator-(const Bigint&) const; Bigint operator*(const Bigint&) const; + Bigint operator+(slongdigit_t) const; + Bigint operator-(slongdigit_t) const; + Bigint operator*(slongdigit_t) const; Bigint operator<<(int) const; Bigint operator>>(int) const; - std::pair<Bigint,Bigint> divmod(const Bigint&) const; + std::pair<Bigint,Bigint> divmod(const Bigint&) const; //rounds towards zero; returns {quotient,remainder} bool operator==(const Bigint&) const; bool operator!=(const Bigint&) const; @@ -70,6 +76,8 @@ public: int bitcount() const; slongdigit_t lowdigits() const; + bool even() const; + bool odd() const; std::vector<char> serialise() const; void deserialise(const std::vector<char>&); @@ -79,9 +87,12 @@ public: friend std::ostream& operator<<(std::ostream&,Bigint); digit_t _digit(int idx) const; -}; -Bigint pow(const Bigint &b,const Bigint &ex); + + static Bigint zero; + static Bigint one; + static Bigint mone; +}; std::istream& operator>>(std::istream&,Bigint&); std::ostream& operator<<(std::ostream&,Bigint); @@ -8,6 +8,7 @@ #include <cassert> #include "bigint.h" #include "numalgo.h" +#include "primes.h" #include "rsa.h" using namespace std; @@ -108,6 +109,23 @@ void repl(int argc,char **argv){ } } +void testisqrt(int argc,char **argv){ + int randsize=argc==2?strtol(argv[1],nullptr,10):1; + assert(randsize>=1); + for(int i=0;i<1000;i++){ + Bigint n(rand64()); + for(int j=1;j<randsize;j++){ + n<<=63; + n+=rand64(); + } + // cout<<hex<<n<<dec<<endl; + Bigint root(isqrt(n)); + assert(root*root<=n); + root+=Bigint::one; + assert(root*root>n); + } +} + void performrsa(){ PrivateKey privkey; Bigint p(1000000007),q(3000000019); @@ -115,8 +133,7 @@ void performrsa(){ privkey.pub.exp=65537; { Bigint x; - Bigint one(1); - egcd((p-one)*(q-one),privkey.pub.exp,x,privkey.pexp); + egcd((p-Bigint::one)*(q-Bigint::one),privkey.pub.exp,x,privkey.pexp); } cout<<"d = "<<privkey.pexp<<endl; Bigint msg(123456789); @@ -127,15 +144,10 @@ void performrsa(){ cout<<"msg = "<<msg2<<endl; } -int main(int,char**){ +int main(int argc,char **argv){ // biginttest(); // repl(argc,argv); - performrsa(); - // cout<<Bigint(0)-Bigint(69255535LL)<<endl; - // cout<<Bigint(0)-Bigint(669255535LL)<<endl; - // cout<<Bigint(0)-Bigint(5669255535LL)<<endl; - // cout<<Bigint(0)-Bigint(75669255535LL)<<endl; - // cout<<Bigint(0)-Bigint(775669255535LL)<<endl; - // cout<<Bigint(0)-Bigint(5775669255535LL)<<endl; - // cout<<Bigint(0)-Bigint(45775669255535LL)<<endl; + // performrsa(); + // testisqrt(argc,argv); + fillsmallprimes(); } diff --git a/numalgo.cpp b/numalgo.cpp index 1db0763..f2ebac5 100644 --- a/numalgo.cpp +++ b/numalgo.cpp @@ -1,3 +1,4 @@ +#include <cstdlib> #include <cassert> #include "numalgo.h" @@ -31,7 +32,7 @@ Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y){ Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){ assert(e>=0); assert(m>=1); - if(m==1)return Bigint(0); + if(m==1)return Bigint::zero; Bigint res(1); vector<bool> bits(e.bits()); for(int i=bits.size()-1;i>=0;i--){ @@ -42,9 +43,67 @@ Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){ return res; } +Bigint isqrt(const Bigint &n){ + assert(n>=0); + if(n<=1)return n; + const int maxiter=20; //empirically, this should happen around 3.5 million bits in n. (niter ~= -1.87+1.45ln(bits)) + // __asm("int3\n\t"); + // cout<<"bitcount="<<n.bitcount()<<" maxiter="<<maxiter<<endl; + Bigint x(n); + x>>=n.bitcount()/2; + // cerr<<"isqrt("<<hex<<n<<"); x = "<<x<<dec<<endl; + int iter; + for(iter=0;iter<maxiter;iter++){ + // cerr<<"x="<<hex<<x<<dec<<endl; + Bigint xnext(x*x); // xnext = x - (x*x-n)/(2*x) [Newton's method] + xnext-=n; + xnext>>=1; + xnext=xnext.divmod(x).first; + xnext.negate(); + xnext+=x; + if(xnext==x)break; + x=xnext; + } + cerr<<iter<<endl; + assert(iter<maxiter); + switch((x*x).compare(n)){ + case -1:{ + Bigint x1(x); + x1+=Bigint::one; + assert(x1*x1>n); + return x; + } + case 0: return x; + case 1:{ + x-=Bigint::one; + assert(x*x<=n); + return x; + } + default: assert(false); + } +} + int ilog2(uint64_t i){ assert(i); int l=0; while(i>>=1)l++; return l; } + +Bigint cryptrandom_big(const Bigint &bound){ + const int blocksize=32; + int btc=bound.bitcount(); + int nblocks=btc/blocksize,rest=btc%blocksize; + while(true){ + Bigint res; + for(int i=0;i<nblocks;i++){ + if(i!=0)res<<=blocksize; + res+=arc4random_uniform((uint32_t)(((uint64_t)1<<blocksize)-1)); //make sure we don't shift out of our int + } + if(rest){ + res<<=rest; + res+=arc4random_uniform((uint32_t)1<<rest); + } + if(res<=bound)return res; + } +} @@ -8,4 +8,9 @@ Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y); Bigint expmod(const Bigint &base,const Bigint &exponent,const Bigint &modulus); +// Returns sqrt(n), rounded down if necessary +Bigint isqrt(const Bigint &n); + int ilog2(uint64_t i); + +Bigint cryptrandom_big(const Bigint &upperbound); //Return value in [0,upperbound] diff --git a/primes.cpp b/primes.cpp new file mode 100644 index 0000000..d144e5e --- /dev/null +++ b/primes.cpp @@ -0,0 +1,82 @@ +#include <cstring> +#include <cmath> +#include <cassert> +#include "numalgo.h" +#include "primes.h" + +using namespace std; + +vector<int> smallprimes; +bool smallprimes_inited=false; + +void fillsmallprimes(){ + smallprimes_inited=true; + //TODO: reserve expected amount of space in smallprimes + smallprimes.push_back(2); + const int highbound=65000; + bool composite[highbound/2]; //entries for 3, 5, 7, 9, etc. + memset(composite,0,highbound/2*sizeof(bool)); + int roothighbound=sqrt(highbound); + for(int i=3;i<=highbound;i+=2){ + if(composite[i/2-1])continue; + smallprimes.push_back(i); + if(i>roothighbound)continue; + for(int j=i*i;j<=highbound;j+=2*i){ + composite[j/2-1]=true; + } + } + //for(int p : smallprimes)cerr<<p<<' '; + //cerr<<endl; +} + +pair<Bigint,Bigint> genprimepair(int nbits){ + // for x = nbits/2: + // (2^x)^2 = 2^(2x) + // (2^x + 2^(x-2))^2 = 2^(2x) + 2^(2x-1) + 2^(2x-4) + // ergo: (2^x + lambda*2^(x-2))^2 \in [2^(2x), 2^(2x+1)), for lambda \in [0,1] + // To make sure the primes "differ in length by a few digits" [RSA78], we use x1=x-2 in the first + // prime and x2-x+2 in the second + int x1=nbits/2-2,x2=(nbits+1)/2+2; + assert(x1+x2==nbits); + return make_pair( + randprime(Bigint::one<<x1,(Bigint::one<<x1)+(Bigint::one<<(x1-2))), + randprime(Bigint::one<<x2,(Bigint::one<<x2)+(Bigint::one<<(x2-2)))); +} + +Bigint randprime(const Bigint &biglow,const Bigint &bighigh){ + if(!smallprimes_inited)fillsmallprimes(); + + assert(bighigh>biglow); + static const int maxrangesize=100001; + Bigint diff(bighigh-biglow); + Bigint low,high; //inclusive + if(diff<=maxrangesize){ + low=biglow; + high=bighigh; + } else { + high=low=cryptrandom_big(diff-maxrangesize); + high+=maxrangesize; + } + if(low.even())low+=1; + if(high.even())high-=1; + Bigint sizeb(high-low+1); + assert(sizeb>=0&&sizeb<=maxrangesize); + int nnums=(sizeb.lowdigits()+1)/2; + + bool composite[nnums]; //low, low+2, low+4, ..., high (all odd numbers) + memset(composite,0,nnums*sizeof(bool)); + int nsmallprimes=smallprimes.size(); + for(int i=1;i<nsmallprimes;i++){ + int pr=smallprimes[i]; + int lowoffset=low.divmod(Bigint(pr)).second.lowdigits(); + int startat; + if(lowoffset==0)startat=0; + else if((pr-lowoffset)%2==0)startat=(pr-lowoffset)/2; + else startat=pr-lowoffset; + for(int i=startat;i<nnums;i+=pr){ //skips ahead `2*pr` each time (so `pr` array elements) + composite[i]=true; + } + } +} + +bool bailliePSW(const Bigint&); diff --git a/primes.h b/primes.h new file mode 100644 index 0000000..61e5fc7 --- /dev/null +++ b/primes.h @@ -0,0 +1,19 @@ +#pragma once + +#include <vector> +#include <utility> +#include "bigint.h" + +extern std::vector<int> smallprimes; + +void fillsmallprimes(); + +//for use in RSA (pass target number of bits of N) +std::pair<Bigint,Bigint> genprimepair(int nbits); + +//finds random in range [low,high]; throws domain_error if no prime found +//Will call fillsmallprimes() if not yet done +Bigint randprime(const Bigint &low,const Bigint &high); + +//checks primality +bool bailliePSW(const Bigint&); |