aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bigint.cpp106
-rw-r--r--bigint.h17
-rw-r--r--main.cpp34
-rw-r--r--numalgo.cpp61
-rw-r--r--numalgo.h5
-rw-r--r--primes.cpp82
-rw-r--r--primes.h19
7 files changed, 292 insertions, 32 deletions
diff --git a/bigint.cpp b/bigint.cpp
index 5182d07..644ffb3 100644
--- a/bigint.cpp
+++ b/bigint.cpp
@@ -7,6 +7,10 @@
using namespace std;
+Bigint Bigint::zero(0);
+Bigint Bigint::one(1);
+Bigint Bigint::mone(-1);
+
Bigint::Bigint()
:sign(1){}
@@ -163,6 +167,11 @@ Bigint& Bigint::operator*=(const Bigint &o){
return *this;
}
+//TODO: optimise these functions
+Bigint& Bigint::operator+=(slongdigit_t n){return *this+=Bigint(n);}
+Bigint& Bigint::operator-=(slongdigit_t n){return *this-=Bigint(n);}
+Bigint& Bigint::operator*=(slongdigit_t n){return *this*=Bigint(n);}
+
Bigint& Bigint::operator<<=(int sh){
if(sh==0)return *this;
if(digits.size()==0)return *this;
@@ -233,12 +242,26 @@ Bigint Bigint::operator*(const Bigint &o) const {
return product(*this,o);
}
+//TODO: optimise these functions
+Bigint Bigint::operator+(slongdigit_t n) const {return *this+Bigint(n);}
+Bigint Bigint::operator-(slongdigit_t n) const {return *this-Bigint(n);}
+Bigint Bigint::operator*(slongdigit_t n) const {return *this*Bigint(n);}
+
+Bigint Bigint::operator<<(int sh) const {
+ return Bigint(*this)<<=sh;
+}
+
+Bigint Bigint::operator>>(int sh) const {
+ return Bigint(*this)>>=sh;
+}
+
int depthrecord;
pair<Bigint,Bigint> Bigint::divmod(const Bigint &div) const {
pair<Bigint,Bigint> res=divmod(div,1);
//cerr<<hex<<*this<<' '<<div<<' ';
//cerr<<dec<<depthrecord<<endl;
+ //cerr<<" -> "<<hex<<res.first<<' '<<res.second<<dec<<endl;
return res;
}
@@ -247,13 +270,14 @@ inline void record(int depth){
}
pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const {
- //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<')'<<endl;
+ if(depth>100)assert(false);
+ //cerr<<"divmod("<<hex<<*this<<','<<hex<<div<<dec<<") depth="<<depth<<endl;
if(div.digits.size()==0)throw domain_error("Bigint divide by zero");
- if(digits.size()==0){record(depth); return make_pair(Bigint(0),Bigint(0));}
+ if(digits.size()==0){record(depth); return make_pair(Bigint::zero,Bigint::zero);}
int cmp=compareAbs(div);
- if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint(0));}
- if(cmp<0){record(depth); return make_pair(Bigint(0),*this);}
+ if(cmp==0){record(depth); return make_pair(Bigint(sign*div.sign),Bigint::zero);}
+ if(cmp<0){record(depth); return make_pair(Bigint::zero,*this);}
//now *this is greater in magnitude than the divisor
#if 0
@@ -266,9 +290,11 @@ pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const {
//cerr<<"guess="<<hex<<guess<<endl;
//twoexp now becomes irrelevant
#else
+#if 0
Bigint guess,quotient;
if(digits.size()==div.digits.size()){
quotient=digits.back()/div.digits.back();
+ quotient.sign=sign*div.sign;
guess=quotient;
guess*=div;
} else {
@@ -277,42 +303,81 @@ pair<Bigint,Bigint> Bigint::divmod(const Bigint &div,int depth) const {
longdigit_t factor=((longdigit_t)1<<digit_bits)*digits.back()+digits[digits.size()-2];
factor/=div.digits.back()+1;
quotient=factor;
- quotient<<=(digits.size()-div.digits.size()-1)*digit_bits;
- guess=quotient;
- guess*=div;
+ quotient<<=(digits.size()-1-div.digits.size())*digit_bits;
+ quotient.sign=sign*div.sign;
+ guess=quotient*div;
}
- // cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<endl;
+ cerr<<"guess="<<hex<<guess<<" quotient="<<quotient<<dec<<endl;
+#else
+ int thisbtc=bitcount(),divbtc=div.bitcount();
+ assert(divbtc<=thisbtc);
+ Bigint quotient,guess;
+ if(thisbtc<=2*digit_bits){
+ //simple integral division
+ longdigit_t thisnum=(digits.size()==2?((longdigit_t)1<<digit_bits)*digits[1]:0)+digits[0];
+ longdigit_t divnum=(div.digits.size()==2?((longdigit_t)1<<digit_bits)*div.digits[1]:0)+div.digits[0];
+ if(divnum==1){record(depth); return make_pair(*this,Bigint::zero);}
+ record(depth); return make_pair(
+ Bigint(sign*div.sign*(slongdigit_t)(thisnum/divnum)),
+ Bigint(sign*div.sign*(slongdigit_t)(thisnum%divnum)));
+ } else if(divbtc>=digit_bits){ //the large case
+ //take 2 digits of *this and 1 digit of div; quotient gives a good guess
+ int spill=__builtin_clz(digits.back());
+ //cerr<<"spill="<<spill<<endl;
+ longdigit_t thishead2=((longdigit_t)digits.back()<<(spill+digit_bits))|((longdigit_t)digits[digits.size()-2]<<spill);
+ if(spill>0)thishead2|=digits[digits.size()-3]>>(digit_bits-spill);
+ //cerr<<"thishead2="<<hex<<thishead2<<dec<<endl;
+ longdigit_t divhead=((longdigit_t)div.digits.back()<<digit_bits)|div.digits[div.digits.size()-2];
+ divhead>>=digit_bits-__builtin_clz(div.digits.back());
+ //cerr<<"divhead="<<hex<<divhead<<dec<<endl;
+ longdigit_t factor=thishead2/(divhead+1); //+1 to make sure the quotient guess is <= the actual quotient
+ quotient=factor;
+ quotient<<=thisbtc-digit_bits-divbtc; //shift amount may be negative if thisbtc and divbtc < digit_bits apart
+ quotient.sign=sign*div.sign;
+ guess=quotient*div;
+ } else { //divbtc<digit_bits, but *this is large
+ //take 2 digits of *this and all of div
+ int spill=__builtin_clz(digits.back());
+ longdigit_t thishead2=((longdigit_t)digits.back()<<(spill+digit_bits))|(digits[digits.size()-2]<<spill);
+ if(spill>0)thishead2|=digits[digits.size()-3]>>(digit_bits-spill);
+ longdigit_t factor=thishead2/(div.digits.back()+1); //+1 to make sure the quotient guess is <= the actual quotient
+ quotient=factor;
+ quotient<<=thisbtc-2*digit_bits;
+ quotient.sign=sign*div.sign;
+ guess=quotient*div;
+ }
+ //cerr<<"guess= "<<hex<<guess<<" quotient="<<quotient<<dec<<endl;
+#endif
#endif
cmp=guess.compareAbs(*this);
if(cmp<0){
- Bigint guess2(guess);
+ /*Bigint guess2(guess);
while(true){
guess2<<=1;
int cmp=guess2.compareAbs(*this);
if(cmp>0)break;
guess<<=1;
quotient<<=1;
- if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
- }
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);}
+ }*/
Bigint rest(*this);
rest-=guess; //also correct for *this and guess negative
pair<Bigint,Bigint> dm=rest.divmod(div,depth+1);
dm.first+=quotient;
return dm;
}
- if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);}
//then cmp>0, so our guess is too large
- Bigint one(1);
do {
if(quotient.digits[0]&1){
guess-=div;
- // quotient-=one; // not necessary, since we shift the bit out anyway
+ // quotient-=Bigint::one; // not necessary, since we shift the bit out anyway
}
guess>>=1;
quotient>>=1;
cmp=guess.compareAbs(*this);
} while(cmp>0);
- if(cmp==0){record(depth); return make_pair(quotient,Bigint(0));}
+ if(cmp==0){record(depth); return make_pair(quotient,Bigint::zero);}
Bigint rest(*this);
rest-=guess;
pair<Bigint,Bigint> dm=rest.divmod(div,depth+1);
@@ -372,7 +437,7 @@ int Bigint::compareAbs(slongdigit_t v) const {
int Bigint::bitcount() const {
if(digits.size()==0)return 0;
- return (digits.size()-1)*8+ilog2(digits.back())+1;
+ return (digits.size()-1)*digit_bits+ilog2(digits.back())+1;
}
Bigint::slongdigit_t Bigint::lowdigits() const {
@@ -382,6 +447,13 @@ Bigint::slongdigit_t Bigint::lowdigits() const {
return ((slongdigit_t)1<<digit_bits)*(digits[1]&mask)+digits[0];
}
+bool Bigint::even() const {
+ return digits.size()==0||(digits[0]&1)==0;
+}
+bool Bigint::odd() const {
+ return !even();
+}
+
vector<char> Bigint::serialise() const {
vector<char> v(1+digits.size()*sizeof(digit_t));
v[0]=sign;
@@ -446,7 +518,7 @@ istream& operator>>(istream &is,Bigint &b){
return is;
}
b*=ten;
- b+=Bigint(c-'0');
+ b+=c-'0';
// cerr<<"b="<<b<<endl;
}
if(!acted)is.setstate(ios_base::failbit);
diff --git a/bigint.h b/bigint.h
index 6579054..213bf5c 100644
--- a/bigint.h
+++ b/bigint.h
@@ -39,6 +39,9 @@ public:
Bigint& operator+=(const Bigint&);
Bigint& operator-=(const Bigint&);
Bigint& operator*=(const Bigint&);
+ Bigint& operator+=(slongdigit_t);
+ Bigint& operator-=(slongdigit_t);
+ Bigint& operator*=(slongdigit_t);
Bigint& operator<<=(int);
Bigint& operator>>=(int);
Bigint& negate();
@@ -46,9 +49,12 @@ public:
Bigint operator+(const Bigint&) const;
Bigint operator-(const Bigint&) const;
Bigint operator*(const Bigint&) const;
+ Bigint operator+(slongdigit_t) const;
+ Bigint operator-(slongdigit_t) const;
+ Bigint operator*(slongdigit_t) const;
Bigint operator<<(int) const;
Bigint operator>>(int) const;
- std::pair<Bigint,Bigint> divmod(const Bigint&) const;
+ std::pair<Bigint,Bigint> divmod(const Bigint&) const; //rounds towards zero; returns {quotient,remainder}
bool operator==(const Bigint&) const;
bool operator!=(const Bigint&) const;
@@ -70,6 +76,8 @@ public:
int bitcount() const;
slongdigit_t lowdigits() const;
+ bool even() const;
+ bool odd() const;
std::vector<char> serialise() const;
void deserialise(const std::vector<char>&);
@@ -79,9 +87,12 @@ public:
friend std::ostream& operator<<(std::ostream&,Bigint);
digit_t _digit(int idx) const;
-};
-Bigint pow(const Bigint &b,const Bigint &ex);
+
+ static Bigint zero;
+ static Bigint one;
+ static Bigint mone;
+};
std::istream& operator>>(std::istream&,Bigint&);
std::ostream& operator<<(std::ostream&,Bigint);
diff --git a/main.cpp b/main.cpp
index ce1573f..2d652b4 100644
--- a/main.cpp
+++ b/main.cpp
@@ -8,6 +8,7 @@
#include <cassert>
#include "bigint.h"
#include "numalgo.h"
+#include "primes.h"
#include "rsa.h"
using namespace std;
@@ -108,6 +109,23 @@ void repl(int argc,char **argv){
}
}
+void testisqrt(int argc,char **argv){
+ int randsize=argc==2?strtol(argv[1],nullptr,10):1;
+ assert(randsize>=1);
+ for(int i=0;i<1000;i++){
+ Bigint n(rand64());
+ for(int j=1;j<randsize;j++){
+ n<<=63;
+ n+=rand64();
+ }
+ // cout<<hex<<n<<dec<<endl;
+ Bigint root(isqrt(n));
+ assert(root*root<=n);
+ root+=Bigint::one;
+ assert(root*root>n);
+ }
+}
+
void performrsa(){
PrivateKey privkey;
Bigint p(1000000007),q(3000000019);
@@ -115,8 +133,7 @@ void performrsa(){
privkey.pub.exp=65537;
{
Bigint x;
- Bigint one(1);
- egcd((p-one)*(q-one),privkey.pub.exp,x,privkey.pexp);
+ egcd((p-Bigint::one)*(q-Bigint::one),privkey.pub.exp,x,privkey.pexp);
}
cout<<"d = "<<privkey.pexp<<endl;
Bigint msg(123456789);
@@ -127,15 +144,10 @@ void performrsa(){
cout<<"msg = "<<msg2<<endl;
}
-int main(int,char**){
+int main(int argc,char **argv){
// biginttest();
// repl(argc,argv);
- performrsa();
- // cout<<Bigint(0)-Bigint(69255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(669255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(5669255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(75669255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(775669255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(5775669255535LL)<<endl;
- // cout<<Bigint(0)-Bigint(45775669255535LL)<<endl;
+ // performrsa();
+ // testisqrt(argc,argv);
+ fillsmallprimes();
}
diff --git a/numalgo.cpp b/numalgo.cpp
index 1db0763..f2ebac5 100644
--- a/numalgo.cpp
+++ b/numalgo.cpp
@@ -1,3 +1,4 @@
+#include <cstdlib>
#include <cassert>
#include "numalgo.h"
@@ -31,7 +32,7 @@ Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y){
Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){
assert(e>=0);
assert(m>=1);
- if(m==1)return Bigint(0);
+ if(m==1)return Bigint::zero;
Bigint res(1);
vector<bool> bits(e.bits());
for(int i=bits.size()-1;i>=0;i--){
@@ -42,9 +43,67 @@ Bigint expmod(const Bigint &b,const Bigint &e,const Bigint &m){
return res;
}
+Bigint isqrt(const Bigint &n){
+ assert(n>=0);
+ if(n<=1)return n;
+ const int maxiter=20; //empirically, this should happen around 3.5 million bits in n. (niter ~= -1.87+1.45ln(bits))
+ // __asm("int3\n\t");
+ // cout<<"bitcount="<<n.bitcount()<<" maxiter="<<maxiter<<endl;
+ Bigint x(n);
+ x>>=n.bitcount()/2;
+ // cerr<<"isqrt("<<hex<<n<<"); x = "<<x<<dec<<endl;
+ int iter;
+ for(iter=0;iter<maxiter;iter++){
+ // cerr<<"x="<<hex<<x<<dec<<endl;
+ Bigint xnext(x*x); // xnext = x - (x*x-n)/(2*x) [Newton's method]
+ xnext-=n;
+ xnext>>=1;
+ xnext=xnext.divmod(x).first;
+ xnext.negate();
+ xnext+=x;
+ if(xnext==x)break;
+ x=xnext;
+ }
+ cerr<<iter<<endl;
+ assert(iter<maxiter);
+ switch((x*x).compare(n)){
+ case -1:{
+ Bigint x1(x);
+ x1+=Bigint::one;
+ assert(x1*x1>n);
+ return x;
+ }
+ case 0: return x;
+ case 1:{
+ x-=Bigint::one;
+ assert(x*x<=n);
+ return x;
+ }
+ default: assert(false);
+ }
+}
+
int ilog2(uint64_t i){
assert(i);
int l=0;
while(i>>=1)l++;
return l;
}
+
+Bigint cryptrandom_big(const Bigint &bound){
+ const int blocksize=32;
+ int btc=bound.bitcount();
+ int nblocks=btc/blocksize,rest=btc%blocksize;
+ while(true){
+ Bigint res;
+ for(int i=0;i<nblocks;i++){
+ if(i!=0)res<<=blocksize;
+ res+=arc4random_uniform((uint32_t)(((uint64_t)1<<blocksize)-1)); //make sure we don't shift out of our int
+ }
+ if(rest){
+ res<<=rest;
+ res+=arc4random_uniform((uint32_t)1<<rest);
+ }
+ if(res<=bound)return res;
+ }
+}
diff --git a/numalgo.h b/numalgo.h
index 71e06f0..997a67c 100644
--- a/numalgo.h
+++ b/numalgo.h
@@ -8,4 +8,9 @@ Bigint egcd(const Bigint &a,const Bigint &b,Bigint &x,Bigint &y);
Bigint expmod(const Bigint &base,const Bigint &exponent,const Bigint &modulus);
+// Returns sqrt(n), rounded down if necessary
+Bigint isqrt(const Bigint &n);
+
int ilog2(uint64_t i);
+
+Bigint cryptrandom_big(const Bigint &upperbound); //Return value in [0,upperbound]
diff --git a/primes.cpp b/primes.cpp
new file mode 100644
index 0000000..d144e5e
--- /dev/null
+++ b/primes.cpp
@@ -0,0 +1,82 @@
+#include <cstring>
+#include <cmath>
+#include <cassert>
+#include "numalgo.h"
+#include "primes.h"
+
+using namespace std;
+
+vector<int> smallprimes;
+bool smallprimes_inited=false;
+
+void fillsmallprimes(){
+ smallprimes_inited=true;
+ //TODO: reserve expected amount of space in smallprimes
+ smallprimes.push_back(2);
+ const int highbound=65000;
+ bool composite[highbound/2]; //entries for 3, 5, 7, 9, etc.
+ memset(composite,0,highbound/2*sizeof(bool));
+ int roothighbound=sqrt(highbound);
+ for(int i=3;i<=highbound;i+=2){
+ if(composite[i/2-1])continue;
+ smallprimes.push_back(i);
+ if(i>roothighbound)continue;
+ for(int j=i*i;j<=highbound;j+=2*i){
+ composite[j/2-1]=true;
+ }
+ }
+ //for(int p : smallprimes)cerr<<p<<' ';
+ //cerr<<endl;
+}
+
+pair<Bigint,Bigint> genprimepair(int nbits){
+ // for x = nbits/2:
+ // (2^x)^2 = 2^(2x)
+ // (2^x + 2^(x-2))^2 = 2^(2x) + 2^(2x-1) + 2^(2x-4)
+ // ergo: (2^x + lambda*2^(x-2))^2 \in [2^(2x), 2^(2x+1)), for lambda \in [0,1]
+ // To make sure the primes "differ in length by a few digits" [RSA78], we use x1=x-2 in the first
+ // prime and x2-x+2 in the second
+ int x1=nbits/2-2,x2=(nbits+1)/2+2;
+ assert(x1+x2==nbits);
+ return make_pair(
+ randprime(Bigint::one<<x1,(Bigint::one<<x1)+(Bigint::one<<(x1-2))),
+ randprime(Bigint::one<<x2,(Bigint::one<<x2)+(Bigint::one<<(x2-2))));
+}
+
+Bigint randprime(const Bigint &biglow,const Bigint &bighigh){
+ if(!smallprimes_inited)fillsmallprimes();
+
+ assert(bighigh>biglow);
+ static const int maxrangesize=100001;
+ Bigint diff(bighigh-biglow);
+ Bigint low,high; //inclusive
+ if(diff<=maxrangesize){
+ low=biglow;
+ high=bighigh;
+ } else {
+ high=low=cryptrandom_big(diff-maxrangesize);
+ high+=maxrangesize;
+ }
+ if(low.even())low+=1;
+ if(high.even())high-=1;
+ Bigint sizeb(high-low+1);
+ assert(sizeb>=0&&sizeb<=maxrangesize);
+ int nnums=(sizeb.lowdigits()+1)/2;
+
+ bool composite[nnums]; //low, low+2, low+4, ..., high (all odd numbers)
+ memset(composite,0,nnums*sizeof(bool));
+ int nsmallprimes=smallprimes.size();
+ for(int i=1;i<nsmallprimes;i++){
+ int pr=smallprimes[i];
+ int lowoffset=low.divmod(Bigint(pr)).second.lowdigits();
+ int startat;
+ if(lowoffset==0)startat=0;
+ else if((pr-lowoffset)%2==0)startat=(pr-lowoffset)/2;
+ else startat=pr-lowoffset;
+ for(int i=startat;i<nnums;i+=pr){ //skips ahead `2*pr` each time (so `pr` array elements)
+ composite[i]=true;
+ }
+ }
+}
+
+bool bailliePSW(const Bigint&);
diff --git a/primes.h b/primes.h
new file mode 100644
index 0000000..61e5fc7
--- /dev/null
+++ b/primes.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <vector>
+#include <utility>
+#include "bigint.h"
+
+extern std::vector<int> smallprimes;
+
+void fillsmallprimes();
+
+//for use in RSA (pass target number of bits of N)
+std::pair<Bigint,Bigint> genprimepair(int nbits);
+
+//finds random in range [low,high]; throws domain_error if no prime found
+//Will call fillsmallprimes() if not yet done
+Bigint randprime(const Bigint &low,const Bigint &high);
+
+//checks primality
+bool bailliePSW(const Bigint&);