diff options
author | tomsmeding <tom.smeding@gmail.com> | 2016-10-07 20:33:18 +0200 |
---|---|---|
committer | tomsmeding <tom.smeding@gmail.com> | 2016-10-07 20:33:18 +0200 |
commit | 241518f369efce64046be36a15fcb722b00e9477 (patch) | |
tree | 1aaa6e56e027df35503c497b3c1ef5fa6e8e7916 /aes.cpp | |
parent | 8ee3380a6b116778ccd1a895802465884f58a9b9 (diff) |
Working AES encrypt and decrypt!
Diffstat (limited to 'aes.cpp')
-rw-r--r-- | aes.cpp | 265 |
1 files changed, 237 insertions, 28 deletions
@@ -1,79 +1,288 @@ +#include <iostream> +#include <iomanip> +#include <stdexcept> #include <cstring> #include <cassert> #include "aes.h" +#include "gf28.h" +#include "rng.h" using namespace std; namespace AES{ - //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf + //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf (AES) + //https://tools.ietf.org/html/rfc3602 (CBC) - void subWord(uint8_t *word); + //State is represented in bytes, as is the key schedule (which is represented in words in the AES spec). + //State is in column-major order. - void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds); + uint32_t roundconstant[10]={}; + uint8_t sbox[256]={}; + uint8_t invsbox[256]={}; - void addRoundKey(uint8_t *state,const uint8_t *roundkey); + void printstate(const uint8_t *state){ + for(int i=0;i<16;i++){ + cout<<setw(2)<<setfill('0')<<hex<<(int)state[i]<<dec; + } + cout<<endl; + } + + void initTables(){ + //generated tables have been checked with AES spec + int term=0x01; + for(int i=0;i<10;i++){ + roundconstant[i]=term<<24; + term=GF28::multiply(term,0x02); + } + + for(int i=0;i<256;i++){ + uint8_t inv=(uint8_t)GF28(i).inverse(); + uint8_t res=0; + for(int j=0;j<8;j++){ + uint8_t bit=((inv>>j)&1)^ + ((inv>>((j+4)%8))&1)^ + ((inv>>((j+5)%8))&1)^ + ((inv>>((j+6)%8))&1)^ + ((inv>>((j+7)%8))&1); + res|=bit<<j; + } + res^=0x63; + sbox[i]=res; + invsbox[res]=i; + } + } + + uint32_t subWord(uint32_t word){ + return (sbox[word>>24]<<24)|(sbox[(word>>16)&0xff]<<16)|(sbox[(word>>8)&0xff]<<8)|sbox[word&0xff]; + } + + uint32_t rotWord(uint32_t word){ + return (word<<8)|(word>>24); + } + + void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds){ + memcpy(keysched,key,4*keylen); + for(int i=keylen;i<4*(numrounds+1);i++){ + uint32_t temp=(keysched[4*i-4]<<24)|(keysched[4*i-3]<<16)|(keysched[4*i-2]<<8)|keysched[4*i-1]; + if(i%keylen==0){ + temp=subWord(rotWord(temp))^roundconstant[i/keylen-1]; + } else if(keylen>6&&i%keylen==4){ + temp=subWord(temp); + } + keysched[4*i+0]=keysched[4*(i-keylen)+0]^(temp>>24); + keysched[4*i+1]=keysched[4*(i-keylen)+1]^((temp>>16)&0xff); + keysched[4*i+2]=keysched[4*(i-keylen)+2]^((temp>>8)&0xff); + keysched[4*i+3]=keysched[4*(i-keylen)+3]^(temp&0xff); + } + } + + void addRoundKey(uint8_t *state,const uint8_t *roundkey){ + for(int i=0;i<16;i++)state[i]^=roundkey[i]; + } void subBytes(uint8_t *state){ - subWord(state); - subWord(state+4); - subWord(state+8); - subWord(state+12); + for(int i=0;i<16;i++)state[i]=sbox[state[i]]; } - void shiftRows(uint8_t *state); + void shiftRows(uint8_t *state){ + uint8_t t=state[1]; state[1]=state[5]; state[5]=state[9]; state[9]=state[13]; state[13]=t; + swap(state[2],state[10]); swap(state[6],state[14]); + t=state[3]; state[3]=state[15]; state[15]=state[11]; state[11]=state[7]; state[7]=t; + } - void mixColumns(uint8_t *state); + void mixColumns(uint8_t *state){ + for(int i=0;i<4;i++){ + uint8_t a=GF28::multiply(0x02,state[4*i+0]) ^ GF28::multiply(0x03,state[4*i+1]) ^ state[4*i+2] ^ state[4*i+3]; + uint8_t b=state[4*i+0] ^ GF28::multiply(0x02,state[4*i+1]) ^ GF28::multiply(0x03,state[4*i+2]) ^ state[4*i+3]; + uint8_t c=state[4*i+0] ^ state[4*i+1] ^ GF28::multiply(0x02,state[4*i+2]) ^ GF28::multiply(0x03,state[4*i+3]); + uint8_t d=GF28::multiply(0x03,state[4*i+0]) ^ state[4*i+1] ^ state[4*i+2] ^ GF28::multiply(0x02,state[4*i+3]); + state[4*i+0]=a; + state[4*i+1]=b; + state[4*i+2]=c; + state[4*i+3]=d; + } + } + void invShiftRows(uint8_t *state){ + uint8_t t=state[1]; state[1]=state[13]; state[13]=state[9]; state[9]=state[5]; state[5]=t; + swap(state[2],state[10]); swap(state[6],state[14]); + t=state[3]; state[3]=state[7]; state[7]=state[11]; state[11]=state[15]; state[15]=t; + } - void encryptBlock(uint8_t *state,const uint8_t *key,const uint8_t *data,int keylen,int numrounds){ - uint8_t keysched[16*(numrounds+1)]; - keyExpand(keysched,key,keylen,numrounds); + void invSubBytes(uint8_t *state){ + for(int i=0;i<16;i++)state[i]=invsbox[state[i]]; + } + + void invMixColumns(uint8_t *state){ + for(int i=0;i<4;i++){ + uint8_t a=GF28::multiply(0x0e,state[4*i+0])^ + GF28::multiply(0x0b,state[4*i+1])^ + GF28::multiply(0x0d,state[4*i+2])^ + GF28::multiply(0x09,state[4*i+3]); + uint8_t b=GF28::multiply(0x09,state[4*i+0])^ + GF28::multiply(0x0e,state[4*i+1])^ + GF28::multiply(0x0b,state[4*i+2])^ + GF28::multiply(0x0d,state[4*i+3]); + uint8_t c=GF28::multiply(0x0d,state[4*i+0])^ + GF28::multiply(0x09,state[4*i+1])^ + GF28::multiply(0x0e,state[4*i+2])^ + GF28::multiply(0x0b,state[4*i+3]); + uint8_t d=GF28::multiply(0x0b,state[4*i+0])^ + GF28::multiply(0x0d,state[4*i+1])^ + GF28::multiply(0x09,state[4*i+2])^ + GF28::multiply(0x0e,state[4*i+3]); + state[4*i+0]=a; + state[4*i+1]=b; + state[4*i+2]=c; + state[4*i+3]=d; + } + } + + + void encryptBlock(uint8_t *state,const uint8_t *keysched,const uint8_t *data,int numrounds){ memcpy(state,data,16); addRoundKey(state,keysched); for(int round=0;round<numrounds-1;round++){ + //cout<<"round["<<setw(2)<<setfill(' ')<<round+1<<"].start "; printstate(state); subBytes(state); shiftRows(state); mixColumns(state); addRoundKey(state,keysched+16*(round+1)); } + //cout<<"round["<<setw(2)<<setfill(' ')<<numrounds<<"].start "; printstate(state); subBytes(state); shiftRows(state); addRoundKey(state,keysched+16*numrounds); } - string encrypt(const string &key,const string &data,int numrounds){ + void decryptBlock(uint8_t *state,const uint8_t *keysched,const uint8_t *data,int numrounds){ + memcpy(state,data,16); + + addRoundKey(state,keysched+16*numrounds); + for(int round=numrounds-2;round>=0;round--){ + //cout<<"round["<<setw(2)<<setfill(' ')<<round+1<<"].start "; printstate(state); + invShiftRows(state); + invSubBytes(state); + addRoundKey(state,keysched+16*(round+1)); + invMixColumns(state); + } + //cout<<"round["<<setw(2)<<setfill(' ')<<numrounds<<"].start "; printstate(state); + invShiftRows(state); + invSubBytes(state); + addRoundKey(state,keysched); + } + + string encryptCBC(const string &key,const string &data,int numrounds){ + if(roundconstant[0]==0)initTables(); + int sz=data.size(); - if(sz==0)return {}; + if(sz==0)return {}; //if nothing to encrypt, don't even give an IV int blocks=sz/16; - int padding=sz%16==0?16:16-sz%16; + int padding=16-sz%16; string res; assert((sz+padding)%16==0); - res.reserve(sz+padding); - uint8_t buf[16]; + res.reserve(16+sz+padding); + + res.resize(16); + CryptoRng crng; + *(uint32_t*)&res[0]=crng.get(); //IV + *(uint32_t*)&res[4]=crng.get(); //endianness doesn't matter, since the data is random anyway + *(uint32_t*)&res[8]=crng.get(); + *(uint32_t*)&res[12]=crng.get(); + + uint8_t keysched[16*(numrounds+1)]; + keyExpand(keysched,(const uint8_t*)key.data(),key.size()/4,numrounds); + + uint8_t buf[16],inbuf[16]; for(int i=0;i<blocks;i++){ - encryptBlock(buf,(const uint8_t*)key.data(),(const uint8_t*)data.data()+16*i,key.size(),numrounds); + memcpy(inbuf,data.data()+16*i,16); + for(int j=0;j<16;j++)inbuf[j]^=res[res.size()-16+j]; //the CBC xor step + encryptBlock(buf,keysched,inbuf,numrounds); res.insert(res.size(),(char*)buf,16); } - uint8_t inbuf[16]; + if(padding<16)memcpy(inbuf,data.data()+16*blocks,16-padding); memset(inbuf+16-padding,padding,padding); - encryptBlock(buf,(const uint8_t*)key.data(),inbuf,key.size(),numrounds); + for(int j=0;j<16;j++)inbuf[j]^=res[res.size()-16+j]; //the CBC xor step + encryptBlock(buf,keysched,inbuf,numrounds); res.insert(res.size(),(char*)buf,16); return res; } - string decrypt(const string &key,const string &data,int numrounds); + string decryptCBC(const string &key,const string &data,int numrounds){ + if(roundconstant[0]==0)initTables(); + + if(data.size()==0)return {}; + if(data.size()%16!=0)throw invalid_argument("AES encrypted data not multiple of block size"); + int blocks=data.size()/16-1; //the IV is not counted as a block + string res(16*blocks,'\0'); + + uint8_t keysched[16*(numrounds+1)]; + keyExpand(keysched,(const uint8_t*)key.data(),key.size()/4,numrounds); + + for(int i=blocks-1;i>=0;i--){ + decryptBlock((uint8_t*)&res[16*i],keysched,(const uint8_t*)data.data()+(16+16*i),numrounds); + for(int j=0;j<16;j++)res[16*i+j]^=data[16*i+j]; //CBC: xor with the previous (remember the IV taking up space) + } + int padsize=res.back(); + if(padsize>16||padsize<0)throw invalid_argument("Malformed AES padding"); + res.resize(res.size()-padsize); + return res; + } string encrypt(const string &key,const string &data,Algorithm algo){ - assert(key.size()==4+2*algo); - return encrypt(key,data,10+2*algo); + int increment; + switch(algo){ + case AES_128_CBC: increment=0; break; + case AES_192_CBC: increment=1; break; + case AES_256_CBC: increment=2; break; + default: assert(false); + } + assert((int)key.size()==4*(4+2*increment)); + return encryptCBC(key,data,10+2*increment); + } + + string decrypt(const string &key,const string &data,Algorithm algo){ + int increment; + switch(algo){ + case AES_128_CBC: increment=0; break; + case AES_192_CBC: increment=1; break; + case AES_256_CBC: increment=2; break; + default: assert(false); + } + assert((int)key.size()==4*(4+2*increment)); + return decryptCBC(key,data,10+2*increment); } - string decrypt(const string &key,const string &data,Algorithm algo)/*{ - assert(key.size()==4+2*algo); - return decrypt(key,data,10+2*algo); - }*/; + void test(){ + #if 0 + // Test encryption + initTables(); + const int numrounds=10; + uint8_t keysched[16*(numrounds+1)]; + uint8_t plaintext[16]={0x32,0x43,0xf6,0xa8,0x88,0x5a,0x30,0x8d,0x31,0x31,0x98,0xa2,0xe0,0x37,0x07,0x34}; + const int keylen=4; + uint8_t key[4*keylen]={0x2b,0x7e,0x15,0x16,0x28,0xae,0xd2,0xa6,0xab,0xf7,0x15,0x88,0x09,0xcf,0x4f,0x3c}; + keyExpand(keysched,key,keylen,numrounds); + uint8_t dest[16]; + encryptBlock(dest,keysched,plaintext,numrounds); + printstate(dest); + #endif + #if 1 + // Test decryption + initTables(); + const int numrounds=14; + uint8_t keysched[16*(numrounds+1)]; + uint8_t plaintext[16]={0x8e,0xa2,0xb7,0xca,0x51,0x67,0x45,0xbf,0xea,0xfc,0x49,0x90,0x4b,0x49,0x60,0x89}; + const int keylen=8; + uint8_t key[4*keylen]={0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f}; + keyExpand(keysched,key,keylen,numrounds); + uint8_t dest[16]; + decryptBlock(dest,keysched,plaintext,numrounds); + printstate(dest); + #endif + } } |