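/*	Benchmark RSA public-key encryption and private-key decryption with
	CTaoCrypt (CyaSSL) and OpenSSL, then cross-check that ciphertext produced
	by each library is decrypted correctly by the other.

	Build (adjust the CyaSSL paths to your tree; the OpenSSL RSA/PEM
	functions live in libcrypto):
         gcc -g -Wall -o benchrsa benchrsa.c -L/home/julien/Code/cyassl/cyassl-rc2-1.0.0/src/.libs/ -I/home/julien/Code/cyassl/cyassl-rc2-1.0.0/ctaocrypt/include/ -lcyassl -lm -DUSE_FAST_MATH -lssl -lcrypto

	Usage: ./benchrsa <pubkey> <privkey> <datasize in bytes>
*/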

#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/time.h>

#include "asn.h"
#include "rsa.h"
#include "random.h"
#include "coding.h"
#include "tfm.h"

/* A (length, pointer) byte buffer.
 * When filled by PemToDer() below, `buffer` points to malloc()ed memory that
 * the caller is responsible for releasing with free(). */
struct buffer {
	word32 length;	/* number of bytes stored in `buffer` */
	byte*  buffer;	/* start of the byte storage */
};
typedef struct buffer buffer;

/* Decode the RSA private key held in PEM file `fileName` into DER bytes
 * stored in `der`. Returns 0 on success, -1 on failure. */
static int PemToDer(const char* fileName, buffer* der);

#define ITERATION 100

/* Elapsed wall-clock time between two gettimeofday() samples, in seconds.
 * Seconds and microseconds are subtracted separately so no intermediate
 * value can overflow a 32-bit long (tv_sec * 1000000 would). */
static double bench_elapsed_seconds(const struct timeval *from,
                                    const struct timeval *to)
{
	return (double)(to->tv_sec - from->tv_sec) +
	       (double)(to->tv_usec - from->tv_usec) / 1000000.0;
}

/* Print len bytes as lowercase hex, 32 bytes per line. The benchmark data is
 * random binary and is not NUL-terminated, so it must never go through %s. */
static void bench_dump_hex(const unsigned char *data, size_t len)
{
	size_t i;

	for (i = 0; i < len; i++) {
		printf("%02x", data[i]);
		if (i % 32 == 31 || i + 1 == len)
			printf("\n");
	}
}

/* Print "OK" if the first len bytes of input and output match; otherwise
 * print "FAILED" followed by hex dumps of both buffers. */
static void bench_report_compare(const unsigned char *input,
                                 const unsigned char *output, size_t len)
{
	if (memcmp(input, output, len) == 0) {
		printf("OK\n");
		return;
	}
	printf("FAILED\n");
	printf("InputData buffer state \n------\n");
	bench_dump_hex(input, len);
	printf("------\n");
	printf("OutputData buffer state \n------\n");
	bench_dump_hex(output, len);
	printf("------\n");
}

/* Read exactly len bytes from stream into buf, accumulating short reads.
 * Returns 0 on success, -1 if EOF or a read error happens first. */
static int bench_read_full(FILE *stream, unsigned char *buf, size_t len)
{
	size_t got = 0;

	while (got < len) {
		size_t n = fread(buf + got, 1, len - got, stream);
		if (n == 0)
			return -1;
		got += n;
	}
	return 0;
}

/* Benchmark RSA with OpenSSL and CTaoCrypt.
 *
 * argv[1]: PEM public key, read with PEM_read_RSA_PUBKEY.
 * argv[2]: PEM private key, read by OpenSSL and (via PemToDer) by CTaoCrypt.
 * argv[3]: datasize in bytes. datasize - RSA_PKCS1_PADDING_SIZE random bytes
 *          are encrypted, so RSA_PKCS1_PADDING_SIZE <= datasize <= modulus size.
 *
 * Each library encrypts with the public key and decrypts with the private key
 * ITERATION times. The average time per call is printed and the round trip
 * is verified. Finally, each library's ciphertext is decrypted by the other.
 * Returns 0 once the benchmark has run (even if a comparison failed), and
 * -1 on a usage or setup error. */
int main (int argc, char *argv[])
{
	int ret = -1;
	int i;

	/* for openssl */
	RSA *rsaPrivKey = NULL;
	RSA *rsaPubKey = NULL;

	/* for ctaocrypt */
	RsaKey key;
	RNG rng;
	buffer der;
	word32 idx = 0;

	/* for everybody; everything cleanup touches is declared and
	 * initialised here, before the first goto */
	char *endptr = NULL;
	long datasize;
	int size = 0;
	int rsakeysize = 0;
	FILE *fp = NULL;
	unsigned char *InputData = NULL;
	unsigned char *WorkBuffer = NULL;
	unsigned char *OutputData = NULL;

	int OpenSSLPublicEncryptSize = 0;
	int OpenSSLPrivateDecryptSize = 0;
	int CTAOCRYPTPublicEncryptSize = 0;
	int CTAOCRYPTPrivateDecryptSize = 0;
	int compatEncSize;
	int compatDecSize;

	struct timeval start;
	struct timeval current;

	der.buffer = 0;
	der.length = 0;

	if (argc < 4) {
		printf("\nusage : ./benchrsa <pubkey> <privkey> <datasize in bytes>\n\nex: ./benchrsa pub.key priv.key 512\n");
		return -1;
	}

	errno = 0;
	datasize = strtol(argv[3], &endptr, 10);
	if (endptr == argv[3] || *endptr != '\0' || errno == ERANGE) {
		printf("Invalid data size: %s\n", argv[3]);
		return -1;
	}

	ERR_load_crypto_strings();

	/* OPENSSL: load the public key */
	fp = fopen(argv[1], "r");
	if (fp == NULL) {
		printf("Cannot open public key file %s\n", argv[1]);
		goto cleanup;
	}
	rsaPubKey = PEM_read_RSA_PUBKEY(fp, NULL, NULL, NULL);
	fclose(fp);
	if (rsaPubKey == NULL) {
		printf("OPENSSL cannot read public key %s\n", argv[1]);
		goto cleanup;
	}

	/* OPENSSL: load the private key */
	fp = fopen(argv[2], "r");
	if (fp == NULL) {
		printf("Cannot open private key file %s\n", argv[2]);
		goto cleanup;
	}
	rsaPrivKey = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
	fclose(fp);
	if (rsaPrivKey == NULL) {
		printf("OPENSSL cannot read private key %s\n", argv[2]);
		goto cleanup;
	}

	/* CTAOCRYPT: convert the private key into DER, then decode both the
	 * public and private parts from it */
	if (PemToDer(argv[2], &der) != 0) {
		printf("CTAOCRYPT cannot convert private key %s to DER\n", argv[2]);
		goto cleanup;
	}
	InitRsaKey(&key, NULL);
	if (RsaPrivateKeyDecode(der.buffer, &idx, &key, der.length) != 0) {
		printf("CTAOCRYPT cannot decode private key %s\n", argv[2]);
		goto cleanup;
	}
	/* NOTE(review): key is never passed to FreeRsaKey(); confirm whether the
	 * math backend in use needs it. */
	if (InitRng(&rng) != 0) {
		printf("CTAOCRYPT cannot initialise the random number generator\n");
		goto cleanup;
	}

	/* mp_unsigned_bin_size(&key.n) returns the size of the RSA modulus in bytes */
	rsakeysize = mp_unsigned_bin_size(&key.n);
	printf("%d bytes (%d bits) rsa modulus loaded\n", rsakeysize, rsakeysize * 8);

	/* Every buffer below is rsakeysize bytes, so all keys must agree on it */
	if (rsakeysize != RSA_size(rsaPrivKey) || rsakeysize != RSA_size(rsaPubKey)) {
		printf("Keys lengths are not equal between OpenSSL and CTAOCRYPT....\n");
		goto cleanup;
	}

	/* With RSA_PKCS1_PADDING the plaintext is datasize - RSA_PKCS1_PADDING_SIZE
	 * bytes, so datasize must cover the padding and fit in the modulus. */
	if (datasize > rsakeysize) {
		printf("Size must be %d bytes at max\n", RSA_size(rsaPrivKey));
		goto cleanup;
	}
	if (datasize < RSA_PKCS1_PADDING_SIZE) {
		printf("Size must be %d bytes at min\n", RSA_PKCS1_PADDING_SIZE);
		goto cleanup;
	}
	size = (int)datasize - RSA_PKCS1_PADDING_SIZE;

	/* RSA_private_decrypt may write up to RSA_size() - RSA_PKCS1_PADDING_SIZE
	 * bytes, so OutputData is sized to the modulus, not to the plaintext. */
	InputData = malloc((size_t)datasize);
	WorkBuffer = malloc((size_t)rsakeysize);
	OutputData = malloc((size_t)rsakeysize);
	if (InputData == NULL || WorkBuffer == NULL || OutputData == NULL) {
		printf("Out of memory\n");
		goto cleanup;
	}

	/* gen random data */
	fp = fopen("/dev/urandom", "r");
	if (fp == NULL) {
		printf("Cannot open /dev/urandom\n");
		goto cleanup;
	}
	if (bench_read_full(fp, InputData, (size_t)size) != 0) {
		fclose(fp);
		printf("Cannot read %d bytes from /dev/urandom\n", size);
		goto cleanup;
	}
	fclose(fp);

	printf("%d bytes of data loaded from PRNG\n", size);

	memset(WorkBuffer, 0, (size_t)rsakeysize);
	memset(OutputData, 0, (size_t)rsakeysize);

	/******************** OpenSSL ******************/

	/* Openssl public key Encryption */
	gettimeofday(&start, NULL);
	for (i = 0; i < ITERATION; i++)
		OpenSSLPublicEncryptSize = RSA_public_encrypt(size, InputData, WorkBuffer, rsaPubKey, RSA_PKCS1_PADDING);
	gettimeofday(&current, NULL);
	printf("OPENSSL RSA_public_encrypt (%d bytes)\t%4.6f s over %d iterations\n", OpenSSLPublicEncryptSize, bench_elapsed_seconds(&start, &current) / ITERATION, ITERATION);

	/* Openssl private key Decryption */
	gettimeofday(&start, NULL);
	for (i = 0; i < ITERATION; i++)
		OpenSSLPrivateDecryptSize = RSA_private_decrypt(rsakeysize, WorkBuffer, OutputData, rsaPrivKey, RSA_PKCS1_PADDING);
	gettimeofday(&current, NULL);
	printf("OPENSSL RSA_private_decrypt (%d bytes)\t%4.6f s over %d iterations\n", OpenSSLPrivateDecryptSize, bench_elapsed_seconds(&start, &current) / ITERATION, ITERATION);

	printf("Compare original to deciphered -> ");
	bench_report_compare(InputData, OutputData, (size_t)size);

	/******************** CTAOCrypt ******************/

	memset(WorkBuffer, 0, (size_t)rsakeysize);
	memset(OutputData, 0, (size_t)rsakeysize);

	/* Ctaocrypt public key encryption */
	gettimeofday(&start, NULL);
	for (i = 0; i < ITERATION; i++)
		CTAOCRYPTPublicEncryptSize = RsaPublicEncrypt(InputData, size, WorkBuffer, rsakeysize, &key, &rng);
	gettimeofday(&current, NULL);
	printf("CTAOCRYPT RsaPublicEncrypt (%d bytes)\t%4.6f s over %d iterations\n", CTAOCRYPTPublicEncryptSize, bench_elapsed_seconds(&start, &current) / ITERATION, ITERATION);

	/* CTAOCRYPT Decryption */
	gettimeofday(&start, NULL);
	for (i = 0; i < ITERATION; i++)
		CTAOCRYPTPrivateDecryptSize = RsaPrivateDecrypt(WorkBuffer, rsakeysize, OutputData, size, &key);
	gettimeofday(&current, NULL);
	printf("CTAOCRYPT RsaPrivateDecrypt (%d bytes)\t%4.6f s over %d iterations\n", CTAOCRYPTPrivateDecryptSize, bench_elapsed_seconds(&start, &current) / ITERATION, ITERATION);

	printf("Compare original to deciphered -> ");
	bench_report_compare(InputData, OutputData, (size_t)size);

	printf("\n--- COMPATIBILITY TESTS ---\n");

	/* Both buffers are cleared before each cross test. Otherwise a failed
	 * encryption or decryption would leave the ciphertext/plaintext of an
	 * earlier round trip in place and the test would wrongly report OK. */
	printf("OpenSSL RSA_public_encrypt, CTAOCRYPT RsaPrivateDecrypt....");
	memset(WorkBuffer, 0, (size_t)rsakeysize);
	memset(OutputData, 0, (size_t)rsakeysize);
	compatEncSize = RSA_public_encrypt(size, InputData, WorkBuffer, rsaPubKey, RSA_PKCS1_PADDING);
	compatDecSize = RsaPrivateDecrypt(WorkBuffer, rsakeysize, OutputData, size, &key);
	if (compatEncSize >= 0 && compatDecSize == size &&
	    memcmp(InputData, OutputData, (size_t)size) == 0)
		printf("OK\n");
	else
		printf("NOK\n");

	printf("CTAOCRYPT RsaPublicEncrypt, OpenSSL RSA_private_decrypt....");
	memset(WorkBuffer, 0, (size_t)rsakeysize);
	memset(OutputData, 0, (size_t)rsakeysize);
	/* Both the ciphertext buffer and its length are the modulus size */
	compatEncSize = RsaPublicEncrypt(InputData, size, WorkBuffer, rsakeysize, &key, &rng);
	compatDecSize = RSA_private_decrypt(rsakeysize, WorkBuffer, OutputData, rsaPrivKey, RSA_PKCS1_PADDING);
	if (compatEncSize >= 0 && compatDecSize == size &&
	    memcmp(InputData, OutputData, (size_t)size) == 0)
		printf("OK\n");
	else
		printf("NOK\n");

	ret = 0;

cleanup:
	free(OutputData);
	free(WorkBuffer);
	free(InputData);
	free(der.buffer);
	/* RSA_free(NULL) is a no-op */
	RSA_free(rsaPubKey);
	RSA_free(rsaPrivKey);

	return ret;
}

/* Decode the first "-----BEGIN RSA PRIVATE KEY-----" ...
 * "-----END RSA PRIVATE KEY-----" section of a PEM file into DER.
 *
 * fileName: path of the PEM file.
 * der:      on success, der->buffer is a malloc()ed buffer holding the decoded
 *           bytes and der->length is their count. The caller must free()
 *           der->buffer. On failure der->buffer is set to NULL and
 *           der->length to 0, so there is nothing to free.
 *
 * Returns 0 on success, -1 on failure (unreadable file, missing header or
 * footer, empty body, I/O error, out of memory, or base64 decode error).
 *
 * Every line between header and footer is passed to Base64Decode() as-is.
 * PEM header lines (such as those of encrypted keys) are not skipped, and
 * lines longer than the 80-byte line buffer are read in pieces. */
static int PemToDer(const char* fileName, buffer* der)
{
	static const char header[] = "-----BEGIN RSA PRIVATE KEY-----";
	static const char footer[] = "-----END RSA PRIVATE KEY-----";

	char  line[80];
	long  begin    = -1;
	long  end      = -1;
	long  sz;
	int   foundEnd = 0;
	int   ret      = -1;
	byte* tmp      = NULL;
	FILE* file;

	der->buffer = NULL;
	der->length = 0;

	file = fopen(fileName, "rb");
	if (file == NULL)
		return -1;

	/* begin = offset of the first byte after the header line */
	while (fgets(line, sizeof(line), file)) {
		if (strncmp(header, line, sizeof(header) - 1) == 0) {
			begin = ftell(file);
			break;
		}
	}
	if (begin < 0)
		goto out;

	/* end = offset just past the last body line; starting at begin means an
	 * empty body gives sz == 0 instead of a negative size */
	end = begin;
	while (fgets(line, sizeof(line), file)) {
		if (strncmp(footer, line, sizeof(footer) - 1) == 0) {
			foundEnd = 1;
			break;
		}
		end = ftell(file);
	}
	if (!foundEnd || end < 0)
		goto out;

	sz = end - begin;
	if (sz <= 0)
		goto out;

	tmp = malloc((size_t)sz);
	if (tmp == NULL)
		goto out;

	if (fseek(file, begin, SEEK_SET) != 0 ||
	    fread(tmp, (size_t)sz, 1, file) != 1)
		goto out;

	/* base64 decodes to fewer bytes than it consumes, so sz bytes suffice */
	der->buffer = malloc((size_t)sz);
	if (der->buffer == NULL)
		goto out;

	der->length = (word32)sz;
	if (Base64Decode(tmp, (word32)sz, der->buffer, &der->length) < 0) {
		free(der->buffer);
		der->buffer = NULL;
		der->length = 0;
		goto out;
	}
	ret = 0;

out:
	free(tmp);
	fclose(file);
	return ret;
}

