From 241518f369efce64046be36a15fcb722b00e9477 Mon Sep 17 00:00:00 2001 From: tomsmeding Date: Fri, 7 Oct 2016 20:33:18 +0200 Subject: Working AES encrypt and decrypt! --- aes.cpp | 265 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++------- aes.h | 10 ++- gf28.cpp | 119 ++++++++++++++++++++++++++++ gf28.h | 39 ++++++++++ main.cpp | 9 ++- rng.cpp | 4 +- rng.h | 2 +- rsa.cpp | 2 +- 8 files changed, 413 insertions(+), 37 deletions(-) create mode 100644 gf28.cpp create mode 100644 gf28.h diff --git a/aes.cpp b/aes.cpp index 13659e6..e226a3c 100644 --- a/aes.cpp +++ b/aes.cpp @@ -1,79 +1,288 @@ +#include +#include +#include #include #include #include "aes.h" +#include "gf28.h" +#include "rng.h" using namespace std; namespace AES{ - //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf + //http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf (AES) + //https://tools.ietf.org/html/rfc3602 (CBC) - void subWord(uint8_t *word); + //State is represented in bytes, as is the key schedule (which is represented in words in the AES spec). + //State is in column-major order. - void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds); + uint32_t roundconstant[10]={}; + uint8_t sbox[256]={}; + uint8_t invsbox[256]={}; - void addRoundKey(uint8_t *state,const uint8_t *roundkey); + void printstate(const uint8_t *state){ + for(int i=0;i<16;i++){ + cout<>j)&1)^ + ((inv>>((j+4)%8))&1)^ + ((inv>>((j+5)%8))&1)^ + ((inv>>((j+6)%8))&1)^ + ((inv>>((j+7)%8))&1); + res|=bit<>24]<<24)|(sbox[(word>>16)&0xff]<<16)|(sbox[(word>>8)&0xff]<<8)|sbox[word&0xff]; + } + + uint32_t rotWord(uint32_t word){ + return (word<<8)|(word>>24); + } + + void keyExpand(uint8_t *keysched,const uint8_t *key,int keylen,int numrounds){ + memcpy(keysched,key,4*keylen); + for(int i=keylen;i<4*(numrounds+1);i++){ + uint32_t temp=(keysched[4*i-4]<<24)|(keysched[4*i-3]<<16)|(keysched[4*i-2]<<8)|keysched[4*i-1]; + if(i%keylen==0){ + temp=subWord(rotWord(temp))^roundconstant[i/keylen-1]; + } else if(keylen>6&&i%keylen==4){ + temp=subWord(temp); + } + keysched[4*i+0]=keysched[4*(i-keylen)+0]^(temp>>24); + keysched[4*i+1]=keysched[4*(i-keylen)+1]^((temp>>16)&0xff); + keysched[4*i+2]=keysched[4*(i-keylen)+2]^((temp>>8)&0xff); + keysched[4*i+3]=keysched[4*(i-keylen)+3]^(temp&0xff); + } + } + + void addRoundKey(uint8_t *state,const uint8_t *roundkey){ + for(int i=0;i<16;i++)state[i]^=roundkey[i]; + } void subBytes(uint8_t *state){ - subWord(state); - subWord(state+4); - subWord(state+8); - subWord(state+12); + for(int i=0;i<16;i++)state[i]=sbox[state[i]]; } - void shiftRows(uint8_t *state); + void shiftRows(uint8_t *state){ + uint8_t t=state[1]; state[1]=state[5]; state[5]=state[9]; state[9]=state[13]; state[13]=t; + swap(state[2],state[10]); swap(state[6],state[14]); + t=state[3]; state[3]=state[15]; state[15]=state[11]; state[11]=state[7]; state[7]=t; + } - void mixColumns(uint8_t *state); + void mixColumns(uint8_t *state){ + for(int i=0;i<4;i++){ + uint8_t a=GF28::multiply(0x02,state[4*i+0]) ^ GF28::multiply(0x03,state[4*i+1]) ^ state[4*i+2] ^ state[4*i+3]; + uint8_t b=state[4*i+0] ^ GF28::multiply(0x02,state[4*i+1]) ^ GF28::multiply(0x03,state[4*i+2]) ^ state[4*i+3]; + uint8_t c=state[4*i+0] ^ state[4*i+1] ^ GF28::multiply(0x02,state[4*i+2]) ^ GF28::multiply(0x03,state[4*i+3]); + uint8_t d=GF28::multiply(0x03,state[4*i+0]) ^ state[4*i+1] ^ state[4*i+2] ^ GF28::multiply(0x02,state[4*i+3]); + state[4*i+0]=a; + state[4*i+1]=b; + state[4*i+2]=c; + state[4*i+3]=d; + } + } + void invShiftRows(uint8_t *state){ + uint8_t t=state[1]; state[1]=state[13]; state[13]=state[9]; state[9]=state[5]; state[5]=t; + swap(state[2],state[10]); swap(state[6],state[14]); + t=state[3]; state[3]=state[7]; state[7]=state[11]; state[11]=state[15]; state[15]=t; + } - void encryptBlock(uint8_t *state,const uint8_t *key,const uint8_t *data,int keylen,int numrounds){ - uint8_t keysched[16*(numrounds+1)]; - keyExpand(keysched,key,keylen,numrounds); + void invSubBytes(uint8_t *state){ + for(int i=0;i<16;i++)state[i]=invsbox[state[i]]; + } + + void invMixColumns(uint8_t *state){ + for(int i=0;i<4;i++){ + uint8_t a=GF28::multiply(0x0e,state[4*i+0])^ + GF28::multiply(0x0b,state[4*i+1])^ + GF28::multiply(0x0d,state[4*i+2])^ + GF28::multiply(0x09,state[4*i+3]); + uint8_t b=GF28::multiply(0x09,state[4*i+0])^ + GF28::multiply(0x0e,state[4*i+1])^ + GF28::multiply(0x0b,state[4*i+2])^ + GF28::multiply(0x0d,state[4*i+3]); + uint8_t c=GF28::multiply(0x0d,state[4*i+0])^ + GF28::multiply(0x09,state[4*i+1])^ + GF28::multiply(0x0e,state[4*i+2])^ + GF28::multiply(0x0b,state[4*i+3]); + uint8_t d=GF28::multiply(0x0b,state[4*i+0])^ + GF28::multiply(0x0d,state[4*i+1])^ + GF28::multiply(0x09,state[4*i+2])^ + GF28::multiply(0x0e,state[4*i+3]); + state[4*i+0]=a; + state[4*i+1]=b; + state[4*i+2]=c; + state[4*i+3]=d; + } + } + + + void encryptBlock(uint8_t *state,const uint8_t *keysched,const uint8_t *data,int numrounds){ memcpy(state,data,16); addRoundKey(state,keysched); for(int round=0;round=0;round--){ + //cout<<"round["<=0;i--){ + decryptBlock((uint8_t*)&res[16*i],keysched,(const uint8_t*)data.data()+(16+16*i),numrounds); + for(int j=0;j<16;j++)res[16*i+j]^=data[16*i+j]; //CBC: xor with the previous (remember the IV taking up space) + } + int padsize=res.back(); + if(padsize>16||padsize<0)throw invalid_argument("Malformed AES padding"); + res.resize(res.size()-padsize); + return res; + } string encrypt(const string &key,const string &data,Algorithm algo){ - assert(key.size()==4+2*algo); - return encrypt(key,data,10+2*algo); + int increment; + switch(algo){ + case AES_128_CBC: increment=0; break; + case AES_192_CBC: increment=1; break; + case AES_256_CBC: increment=2; break; + default: assert(false); + } + assert((int)key.size()==4*(4+2*increment)); + return encryptCBC(key,data,10+2*increment); + } + + string decrypt(const string &key,const string &data,Algorithm algo){ + int increment; + switch(algo){ + case AES_128_CBC: increment=0; break; + case AES_192_CBC: increment=1; break; + case AES_256_CBC: increment=2; break; + default: assert(false); + } + assert((int)key.size()==4*(4+2*increment)); + return decryptCBC(key,data,10+2*increment); } - string decrypt(const string &key,const string &data,Algorithm algo)/*{ - assert(key.size()==4+2*algo); - return decrypt(key,data,10+2*algo); - }*/; + void test(){ + #if 0 + // Test encryption + initTables(); + const int numrounds=10; + uint8_t keysched[16*(numrounds+1)]; + uint8_t plaintext[16]={0x32,0x43,0xf6,0xa8,0x88,0x5a,0x30,0x8d,0x31,0x31,0x98,0xa2,0xe0,0x37,0x07,0x34}; + const int keylen=4; + uint8_t key[4*keylen]={0x2b,0x7e,0x15,0x16,0x28,0xae,0xd2,0xa6,0xab,0xf7,0x15,0x88,0x09,0xcf,0x4f,0x3c}; + keyExpand(keysched,key,keylen,numrounds); + uint8_t dest[16]; + encryptBlock(dest,keysched,plaintext,numrounds); + printstate(dest); + #endif + #if 1 + // Test decryption + initTables(); + const int numrounds=14; + uint8_t keysched[16*(numrounds+1)]; + uint8_t plaintext[16]={0x8e,0xa2,0xb7,0xca,0x51,0x67,0x45,0xbf,0xea,0xfc,0x49,0x90,0x4b,0x49,0x60,0x89}; + const int keylen=8; + uint8_t key[4*keylen]={0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f}; + keyExpand(keysched,key,keylen,numrounds); + uint8_t dest[16]; + decryptBlock(dest,keysched,plaintext,numrounds); + printstate(dest); + #endif + } } diff --git a/aes.h b/aes.h index 012c9ed..b2064f2 100644 --- a/aes.h +++ b/aes.h @@ -5,12 +5,16 @@ namespace AES{ enum Algorithm{ - AES_128, - AES_192, - AES_256, + AES_128_CBC, + AES_192_CBC, + AES_256_CBC, }; std::string encrypt(const std::string &key,const std::string &data,Algorithm algo); + + //throws invalid_argument for an invalid ciphertext (length not a multiple of block size, or padding malformed) std::string decrypt(const std::string &key,const std::string &data,Algorithm algo); + void test(); + } diff --git a/gf28.cpp b/gf28.cpp new file mode 100644 index 0000000..e65c20a --- /dev/null +++ b/gf28.cpp @@ -0,0 +1,119 @@ +#include +#include "gf28.h" + +using namespace std; + +int GF28::reduce(int v,int m){ + assert(m); + while(true){ + int sh=__builtin_clz(m)-__builtin_clz(v); + if(sh<0)break; + v^=m<>=1; + } + return res; +} + +GF28::GF28() + :value(0){} + +GF28::GF28(int v) + :value(reduce(v,modulus)){} + +GF28::operator uint8_t() const { + return value; +} + +GF28& GF28::operator+=(GF28 o){ + value^=o.value; + return *this; +} + +GF28& GF28::operator-=(GF28 o){ + value^=o.value; + return *this; +} + +GF28& GF28::operator<<=(int n){ //multiplication by x^n + assert(n>=0); + value<<=n; + if(value&0x100)value^=modulus; + return *this; +} + +GF28 GF28::operator+(GF28 o) const { + return GF28(value^o.value); +} + +GF28 GF28::operator-(GF28 o) const { + return GF28(value^o.value); +} + +GF28 GF28::operator*(GF28 o) const { + if(value==0||o.value==0)return GF28(0); + GF28 res; + GF28 addend(*this); + while(o.value){ + if(o.value&1)res+=addend; + addend<<=1; + o.value>>=1; + } + return res; +} +GF28 GF28::operator<<(int n) const { + return GF28(*this)<<=n; +} + +bool GF28::operator==(GF28 o) const { + return value==o.value; +} + +GF28 GF28::inverse() const { + if(value==0)return *this; + int x=1,y=0,x2=0,y2=1,r=modulus,r2=value; + while(r2!=0){ + assert(r!=0); + int ex=__builtin_clz(r2)-__builtin_clz(r); + if(ex<0){ + swap(x,x2); + swap(y,y2); + swap(r,r2); + } else { + int xn=x^(x2<>=1,i--){ + if(p.value&m){ + if(!first)os<<'+'; + first=false; + if(i==0)os<<'1'; + else os<<"x^"< +#include + +//The GF(2^8) field used in AES + +class GF28{ + int value; + + static int reduce(int v,int m); + +public: + static const int modulus=0x11b; + + static uint8_t multiply(uint8_t x,uint8_t y); //for when the class is overkill + + GF28(); + explicit GF28(int v); + + explicit operator uint8_t() const; + + GF28& operator+=(GF28 o); + GF28& operator-=(GF28 o); + GF28& operator<<=(int n); //multiplication by x^n + + GF28 operator+(GF28 o) const; + GF28 operator-(GF28 o) const; + GF28 operator*(GF28 o) const; + GF28 operator<<(int n) const; //multiplication by x^n + + bool operator==(GF28 o) const; + + GF28 inverse() const; + + friend std::ostream& operator<<(std::ostream&,GF28); +}; + +std::ostream& operator<<(std::ostream &os,GF28 p); diff --git a/main.cpp b/main.cpp index f818c48..2cd31ec 100644 --- a/main.cpp +++ b/main.cpp @@ -6,6 +6,7 @@ #include #include #include +#include "aes.h" #include "base64.h" #include "bigint.h" #include "numalgo.h" @@ -211,12 +212,16 @@ int main(int argc,char **argv){ fwrite(data,1,4,stdout); }*/ - string s; + /*string s; while(true){ char c=cin.get(); if(!cin)break; s.push_back(c); } cout<