aboutsummaryrefslogtreecommitdiff
path: root/aes.cpp
diff options
context:
space:
mode:
authortomsmeding <tom.smeding@gmail.com>2016-10-07 20:33:18 +0200
committertomsmeding <tom.smeding@gmail.com>2016-10-07 20:33:18 +0200
commit241518f369efce64046be36a15fcb722b00e9477 (patch)
tree1aaa6e56e027df35503c497b3c1ef5fa6e8e7916 /aes.cpp
parent8ee3380a6b116778ccd1a895802465884f58a9b9 (diff)
Working AES encrypt and decrypt!
Diffstat (limited to 'aes.cpp')
-rw-r--r--aes.cpp265
1 files changed, 237 insertions, 28 deletions
diff --git a/aes.cpp b/aes.cpp
index 13659e6..e226a3c 100644
--- a/aes.cpp
+++ b/aes.cpp
@@ -1,79 +1,288 @@
+#include <iostream>
+#include <iomanip>
+#include <stdexcept>
#include <cstring>
#include <cassert>
#include "aes.h"
+#include "gf28.h"
+#include "rng.h"
using namespace std;
namespace AES{
- //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf
+ //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf (AES)
+ //https://tools.ietf.org/html/rfc3602 (CBC)
- void subWord(uint8_t *word);
+ //State is represented in bytes, as is the key schedule (which is represented in words in the AES spec).
+ //State is in column-major order.
- void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds);
+ uint32_t roundconstant[10]={};
+ uint8_t sbox[256]={};
+ uint8_t invsbox[256]={};
- void addRoundKey(uint8_t *state,const uint8_t *roundkey);
+ void printstate(const uint8_t *state){
+ for(int i=0;i<16;i++){
+ cout<<setw(2)<<setfill('0')<<hex<<(int)state[i]<<dec;
+ }
+ cout<<endl;
+ }
+
+ void initTables(){
+ //generated tables have been checked with AES spec
+ int term=0x01;
+ for(int i=0;i<10;i++){
+ roundconstant[i]=term<<24;
+ term=GF28::multiply(term,0x02);
+ }
+
+ for(int i=0;i<256;i++){
+ uint8_t inv=(uint8_t)GF28(i).inverse();
+ uint8_t res=0;
+ for(int j=0;j<8;j++){
+ uint8_t bit=((inv>>j)&1)^
+ ((inv>>((j+4)%8))&1)^
+ ((inv>>((j+5)%8))&1)^
+ ((inv>>((j+6)%8))&1)^
+ ((inv>>((j+7)%8))&1);
+ res|=bit<<j;
+ }
+ res^=0x63;
+ sbox[i]=res;
+ invsbox[res]=i;
+ }
+ }
+
+ uint32_t subWord(uint32_t word){
+ return (sbox[word>>24]<<24)|(sbox[(word>>16)&0xff]<<16)|(sbox[(word>>8)&0xff]<<8)|sbox[word&0xff];
+ }
+
+ uint32_t rotWord(uint32_t word){
+ return (word<<8)|(word>>24);
+ }
+
+ void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds){
+ memcpy(keysched,key,4*keylen);
+ for(int i=keylen;i<4*(numrounds+1);i++){
+ uint32_t temp=(keysched[4*i-4]<<24)|(keysched[4*i-3]<<16)|(keysched[4*i-2]<<8)|keysched[4*i-1];
+ if(i%keylen==0){
+ temp=subWord(rotWord(temp))^roundconstant[i/keylen-1];
+ } else if(keylen>6&&i%keylen==4){
+ temp=subWord(temp);
+ }
+ keysched[4*i+0]=keysched[4*(i-keylen)+0]^(temp>>24);
+ keysched[4*i+1]=keysched[4*(i-keylen)+1]^((temp>>16)&0xff);
+ keysched[4*i+2]=keysched[4*(i-keylen)+2]^((temp>>8)&0xff);
+ keysched[4*i+3]=keysched[4*(i-keylen)+3]^(temp&0xff);
+ }
+ }
+
+ void addRoundKey(uint8_t *state,const uint8_t *roundkey){
+ for(int i=0;i<16;i++)state[i]^=roundkey[i];
+ }
void subBytes(uint8_t *state){
- subWord(state);
- subWord(state+4);
- subWord(state+8);
- subWord(state+12);
+ for(int i=0;i<16;i++)state[i]=sbox[state[i]];
}
- void shiftRows(uint8_t *state);
+ void shiftRows(uint8_t *state){
+ uint8_t t=state[1]; state[1]=state[5]; state[5]=state[9]; state[9]=state[13]; state[13]=t;
+ swap(state[2],state[10]); swap(state[6],state[14]);
+ t=state[3]; state[3]=state[15]; state[15]=state[11]; state[11]=state[7]; state[7]=t;
+ }
- void mixColumns(uint8_t *state);
+ void mixColumns(uint8_t *state){
+ for(int i=0;i<4;i++){
+ uint8_t a=GF28::multiply(0x02,state[4*i+0]) ^ GF28::multiply(0x03,state[4*i+1]) ^ state[4*i+2] ^ state[4*i+3];
+ uint8_t b=state[4*i+0] ^ GF28::multiply(0x02,state[4*i+1]) ^ GF28::multiply(0x03,state[4*i+2]) ^ state[4*i+3];
+ uint8_t c=state[4*i+0] ^ state[4*i+1] ^ GF28::multiply(0x02,state[4*i+2]) ^ GF28::multiply(0x03,state[4*i+3]);
+ uint8_t d=GF28::multiply(0x03,state[4*i+0]) ^ state[4*i+1] ^ state[4*i+2] ^ GF28::multiply(0x02,state[4*i+3]);
+ state[4*i+0]=a;
+ state[4*i+1]=b;
+ state[4*i+2]=c;
+ state[4*i+3]=d;
+ }
+ }
+ void invShiftRows(uint8_t *state){
+ uint8_t t=state[1]; state[1]=state[13]; state[13]=state[9]; state[9]=state[5]; state[5]=t;
+ swap(state[2],state[10]); swap(state[6],state[14]);
+ t=state[3]; state[3]=state[7]; state[7]=state[11]; state[11]=state[15]; state[15]=t;
+ }
- void encryptBlock(uint8_t *state,const uint8_t *key,const uint8_t *data,int keylen,int numrounds){
- uint8_t keysched[16*(numrounds+1)];
- keyExpand(keysched,key,keylen,numrounds);
+ void invSubBytes(uint8_t *state){
+ for(int i=0;i<16;i++)state[i]=invsbox[state[i]];
+ }
+
+ void invMixColumns(uint8_t *state){
+ for(int i=0;i<4;i++){
+ uint8_t a=GF28::multiply(0x0e,state[4*i+0])^
+ GF28::multiply(0x0b,state[4*i+1])^
+ GF28::multiply(0x0d,state[4*i+2])^
+ GF28::multiply(0x09,state[4*i+3]);
+ uint8_t b=GF28::multiply(0x09,state[4*i+0])^
+ GF28::multiply(0x0e,state[4*i+1])^
+ GF28::multiply(0x0b,state[4*i+2])^
+ GF28::multiply(0x0d,state[4*i+3]);
+ uint8_t c=GF28::multiply(0x0d,state[4*i+0])^
+ GF28::multiply(0x09,state[4*i+1])^
+ GF28::multiply(0x0e,state[4*i+2])^
+ GF28::multiply(0x0b,state[4*i+3]);
+ uint8_t d=GF28::multiply(0x0b,state[4*i+0])^
+ GF28::multiply(0x0d,state[4*i+1])^
+ GF28::multiply(0x09,state[4*i+2])^
+ GF28::multiply(0x0e,state[4*i+3]);
+ state[4*i+0]=a;
+ state[4*i+1]=b;
+ state[4*i+2]=c;
+ state[4*i+3]=d;
+ }
+ }
+
+
+ void encryptBlock(uint8_t *state,const uint8_t *keysched,const uint8_t *data,int numrounds){
memcpy(state,data,16);
addRoundKey(state,keysched);
for(int round=0;round<numrounds-1;round++){
+ //cout<<"round["<<setw(2)<<setfill(' ')<<round+1<<"].start "; printstate(state);
subBytes(state);
shiftRows(state);
mixColumns(state);
addRoundKey(state,keysched+16*(round+1));
}
+ //cout<<"round["<<setw(2)<<setfill(' ')<<numrounds<<"].start "; printstate(state);
subBytes(state);
shiftRows(state);
addRoundKey(state,keysched+16*numrounds);
}
- string encrypt(const string &key,const string &data,int numrounds){
+ void decryptBlock(uint8_t *state,const uint8_t *keysched,const uint8_t *data,int numrounds){
+ memcpy(state,data,16);
+
+ addRoundKey(state,keysched+16*numrounds);
+ for(int round=numrounds-2;round>=0;round--){
+ //cout<<"round["<<setw(2)<<setfill(' ')<<round+1<<"].start "; printstate(state);
+ invShiftRows(state);
+ invSubBytes(state);
+ addRoundKey(state,keysched+16*(round+1));
+ invMixColumns(state);
+ }
+ //cout<<"round["<<setw(2)<<setfill(' ')<<numrounds<<"].start "; printstate(state);
+ invShiftRows(state);
+ invSubBytes(state);
+ addRoundKey(state,keysched);
+ }
+
+ string encryptCBC(const string &key,const string &data,int numrounds){
+ if(roundconstant[0]==0)initTables();
+
int sz=data.size();
- if(sz==0)return {};
+ if(sz==0)return {}; //if nothing to encrypt, don't even give an IV
int blocks=sz/16;
- int padding=sz%16==0?16:16-sz%16;
+ int padding=16-sz%16;
string res;
assert((sz+padding)%16==0);
- res.reserve(sz+padding);
- uint8_t buf[16];
+ res.reserve(16+sz+padding);
+
+ res.resize(16);
+ CryptoRng crng;
+ *(uint32_t*)&res[0]=crng.get(); //IV
+ *(uint32_t*)&res[4]=crng.get(); //endianness doesn't matter, since the data is random anyway
+ *(uint32_t*)&res[8]=crng.get();
+ *(uint32_t*)&res[12]=crng.get();
+
+ uint8_t keysched[16*(numrounds+1)];
+ keyExpand(keysched,(const uint8_t*)key.data(),key.size()/4,numrounds);
+
+ uint8_t buf[16],inbuf[16];
for(int i=0;i<blocks;i++){
- encryptBlock(buf,(const uint8_t*)key.data(),(const uint8_t*)data.data()+16*i,key.size(),numrounds);
+ memcpy(inbuf,data.data()+16*i,16);
+ for(int j=0;j<16;j++)inbuf[j]^=res[res.size()-16+j]; //the CBC xor step
+ encryptBlock(buf,keysched,inbuf,numrounds);
res.insert(res.size(),(char*)buf,16);
}
- uint8_t inbuf[16];
+
if(padding<16)memcpy(inbuf,data.data()+16*blocks,16-padding);
memset(inbuf+16-padding,padding,padding);
- encryptBlock(buf,(const uint8_t*)key.data(),inbuf,key.size(),numrounds);
+ for(int j=0;j<16;j++)inbuf[j]^=res[res.size()-16+j]; //the CBC xor step
+ encryptBlock(buf,keysched,inbuf,numrounds);
res.insert(res.size(),(char*)buf,16);
return res;
}
- string decrypt(const string &key,const string &data,int numrounds);
+ string decryptCBC(const string &key,const string &data,int numrounds){
+ if(roundconstant[0]==0)initTables();
+
+ if(data.size()==0)return {};
+ if(data.size()%16!=0)throw invalid_argument("AES encrypted data not multiple of block size");
+ int blocks=data.size()/16-1; //the IV is not counted as a block
+ string res(16*blocks,'\0');
+
+ uint8_t keysched[16*(numrounds+1)];
+ keyExpand(keysched,(const uint8_t*)key.data(),key.size()/4,numrounds);
+
+ for(int i=blocks-1;i>=0;i--){
+ decryptBlock((uint8_t*)&res[16*i],keysched,(const uint8_t*)data.data()+(16+16*i),numrounds);
+ for(int j=0;j<16;j++)res[16*i+j]^=data[16*i+j]; //CBC: xor with the previous (remember the IV taking up space)
+ }
+ int padsize=res.back();
+ if(padsize>16||padsize<0)throw invalid_argument("Malformed AES padding");
+ res.resize(res.size()-padsize);
+ return res;
+ }
string encrypt(const string &key,const string &data,Algorithm algo){
- assert(key.size()==4+2*algo);
- return encrypt(key,data,10+2*algo);
+ int increment;
+ switch(algo){
+ case AES_128_CBC: increment=0; break;
+ case AES_192_CBC: increment=1; break;
+ case AES_256_CBC: increment=2; break;
+ default: assert(false);
+ }
+ assert((int)key.size()==4*(4+2*increment));
+ return encryptCBC(key,data,10+2*increment);
+ }
+
+ string decrypt(const string &key,const string &data,Algorithm algo){
+ int increment;
+ switch(algo){
+ case AES_128_CBC: increment=0; break;
+ case AES_192_CBC: increment=1; break;
+ case AES_256_CBC: increment=2; break;
+ default: assert(false);
+ }
+ assert((int)key.size()==4*(4+2*increment));
+ return decryptCBC(key,data,10+2*increment);
}
- string decrypt(const string &key,const string &data,Algorithm algo)/*{
- assert(key.size()==4+2*algo);
- return decrypt(key,data,10+2*algo);
- }*/;
+ void test(){
+ #if 0
+ // Test encryption
+ initTables();
+ const int numrounds=10;
+ uint8_t keysched[16*(numrounds+1)];
+ uint8_t plaintext[16]={0x32,0x43,0xf6,0xa8,0x88,0x5a,0x30,0x8d,0x31,0x31,0x98,0xa2,0xe0,0x37,0x07,0x34};
+ const int keylen=4;
+ uint8_t key[4*keylen]={0x2b,0x7e,0x15,0x16,0x28,0xae,0xd2,0xa6,0xab,0xf7,0x15,0x88,0x09,0xcf,0x4f,0x3c};
+ keyExpand(keysched,key,keylen,numrounds);
+ uint8_t dest[16];
+ encryptBlock(dest,keysched,plaintext,numrounds);
+ printstate(dest);
+ #endif
+ #if 1
+ // Test decryption
+ initTables();
+ const int numrounds=14;
+ uint8_t keysched[16*(numrounds+1)];
+ uint8_t plaintext[16]={0x8e,0xa2,0xb7,0xca,0x51,0x67,0x45,0xbf,0xea,0xfc,0x49,0x90,0x4b,0x49,0x60,0x89};
+ const int keylen=8;
+ uint8_t key[4*keylen]={0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f};
+ keyExpand(keysched,key,keylen,numrounds);
+ uint8_t dest[16];
+ decryptBlock(dest,keysched,plaintext,numrounds);
+ printstate(dest);
+ #endif
+ }
}