// path: root/rsa.cpp
// blob: 94c97820932193f06acbd19faefc8419cf5aea86
#include <algorithm>
#include <stdexcept>
#include <cstdint>
#include <cassert>
#include "base64.h"
#include "numalgo.h"
#include "primes.h"
#include "rng.h"
#include "rsa.h"

using namespace std;

namespace RSA{

	/*
	 * Textbook RSA encryption: returns msg^e mod n for the public key (n = pubkey.mod,
	 * e = pubkey.exp). No padding is applied; the value is exponentiated as-is.
	 * Precondition (checked by assert only, i.e. not in NDEBUG builds):
	 *   1 < msg < pubkey.mod
	 */
	Bigint encrypt(Bigint msg,const Key &pubkey){
		assert(msg>1&&msg<pubkey.mod);
		return expmod(msg,pubkey.exp,pubkey.mod);
	}

	/*
	 * Textbook RSA decryption: returns encr^d mod n for the private key (n = privkey.mod,
	 * d = privkey.exp). The input is not range-checked here.
	 */
	Bigint decrypt(Bigint encr,const Key &privkey){
		return expmod(encr,privkey.exp,privkey.mod);
	}

	/*
	 * Generates an RSA key pair from primes drawn via genprimepair(rng, nbits).
	 * NOTE(review): confirm in primes.h whether nbits sizes each prime or the modulus.
	 *
	 * The public exponent is fixed at 65537. The private exponent is its inverse
	 * modulo phi = (p-1)(q-1). If 65537 is not invertible mod phi, a fresh prime
	 * pair is drawn and the process repeats.
	 *
	 * Returns {public key, private key}; both share the modulus p*q.
	 */
	pair<Key,Key> genkeys(int nbits,Rng &rng){
		while(true){
			pair<Bigint,Bigint> pq=genprimepair(rng,nbits);
			Key pubkey,privkey;
			pubkey.mod=privkey.mod=pq.first*pq.second;
			pubkey.exp=65537;
			Bigint phi((pq.first-Bigint::one)*(pq.second-Bigint::one));
			Bigint x; // Bezout coefficient of phi; required by egcd but unused
			// egcd must run unconditionally: it writes the private exponent through its
			// out-parameter, so it cannot live inside assert() (compiled out under NDEBUG).
			if(!(egcd(phi,pubkey.exp,x,privkey.exp)==1)){
				continue; // 65537 not invertible mod phi: try another prime pair
			}
			// Reduce the Bezout coefficient modulo phi (presumably mapping a possibly
			// negative coefficient into [0, phi) — confirm divmod's sign convention).
			privkey.exp=privkey.exp.divmod(phi).second;
			return make_pair(pubkey,privkey);
		}
	}

	/* Generates a key pair using a freshly constructed CryptoRng. */
	pair<Key,Key> genkeys(int nbits){
		CryptoRng rng;
		return genkeys(nbits,rng);
	}

	/*
	 * Generates a key pair using a KeyRng seeded from `password`.
	 * NOTE(review): presumably deterministic per password — confirm in rng.h.
	 */
	pair<Key,Key> genkeys(int nbits,const string &password){
		KeyRng rng(password);
		return genkeys(nbits,rng);
	}

	/*
	 * Serialises a key as Base64 of the byte string:
	 *   [4 bytes]  little-endian length L of the modulus mantissa
	 *   [L bytes]  modulus mantissa
	 *   [rest]     exponent mantissa
	 * importKey parses this format.
	 */
	string exportKey(const Key &key){
		string modser=key.mod.serialiseMantissa();
		const uint32_t modlen=static_cast<uint32_t>(modser.size());
		string header;
		for(int shift=0;shift<32;shift+=8){
			header+=static_cast<char>((modlen>>shift)&0xff);
		}
		return Base64::encode(header+modser+key.exp.serialiseMantissa());
	}

	/*
	 * Parses a key in the format written by exportKey.
	 * Throws invalid_argument("Invalid key string length") if the decoded data has
	 * no room for anything beyond the 4-byte length header.
	 * Throws invalid_argument("Key string incomplete") if the declared modulus length
	 * leaves no bytes for the exponent.
	 */
	Key importKey(const string &repr){
		string deser=Base64::decode(repr);
		if(deser.size()<=4)throw invalid_argument("Invalid key string length");
		// Assemble the little-endian length in unsigned arithmetic. Then a high byte
		// >= 0x80 cannot overflow a signed int or produce a negative length that
		// would defeat the bounds check below.
		uint32_t modlen=0;
		for(int i=3;i>=0;i--){
			modlen=(modlen<<8)|static_cast<uint8_t>(deser[i]);
		}
		const string body=deser.substr(4);
		// The exponent is everything after the modulus and must be non-empty.
		if(modlen>=body.size())throw invalid_argument("Key string incomplete");
		Key key;
		key.mod.deserialiseMantissa(body.substr(0,modlen));
		key.exp.deserialiseMantissa(body.substr(modlen));
		return key;
	}

}