Skip to content

Commit

Permalink
Merge pull request #1021 from edubart/openssl-gmp
Browse files Browse the repository at this point in the history
Support for OpenSSL with a fallback to GMP
  • Loading branch information
iryont authored Mar 31, 2020
2 parents 066df4d + 88e271c commit 63f5351
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 155 deletions.
16 changes: 12 additions & 4 deletions src/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,25 +219,34 @@ endif()
message(STATUS "LuaJIT: " ${LUAJIT})

find_package(PhysFS REQUIRED)
find_package(OpenSSL REQUIRED)
find_package(ZLIB REQUIRED)

set(framework_LIBRARIES ${framework_LIBRARIES}
${Boost_LIBRARIES}
${LUA_LIBRARY}
${PHYSFS_LIBRARY}
${OPENSSL_LIBRARIES}
${ZLIB_LIBRARY}
)

set(framework_INCLUDE_DIRS ${framework_INCLUDE_DIRS}
${Boost_INCLUDE_DIRS}
${LUA_INCLUDE_DIR}
${PHYSFS_INCLUDE_DIR}
${OPENSSL_INCLUDE_DIR}
${framework_INCLUDE_DIRS}
)

find_package(OpenSSL QUIET)

if(NOT OPENSSL_FOUND)
find_package(GMP REQUIRED)
set(framework_LIBRARIES ${framework_LIBRARIES} ${GMP_LIBRARY})
set(framework_INCLUDE_DIRS ${framework_INCLUDE_DIRS} ${GMP_INCLUDE_DIR})
set(framework_DEFINITIONS ${framework_DEFINITIONS} -DUSE_GMP)
else()
set(framework_LIBRARIES ${framework_LIBRARIES} ${OPENSSL_LIBRARIES})
set(framework_INCLUDE_DIRS ${framework_INCLUDE_DIRS} ${OPENSSL_INCLUDE_DIR})
endif()

if(CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
message(STATUS "Debug information: ON")
else()
Expand Down Expand Up @@ -550,4 +559,3 @@ endif()

include_directories(${framework_INCLUDE_DIRS})
add_definitions(${framework_DEFINITIONS})

4 changes: 0 additions & 4 deletions src/framework/luafunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,8 @@ void Application::registerLuaFunctions()
g_lua.bindSingletonFunction("g_crypt", "getMachineUUID", &Crypt::getMachineUUID, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "encrypt", &Crypt::encrypt, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "decrypt", &Crypt::decrypt, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "sha1Encode", &Crypt::sha1Encode, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "md5Encode", &Crypt::md5Encode, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "rsaGenerateKey", &Crypt::rsaGenerateKey, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "rsaSetPublicKey", &Crypt::rsaSetPublicKey, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "rsaSetPrivateKey", &Crypt::rsaSetPrivateKey, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "rsaCheckKey", &Crypt::rsaCheckKey, &g_crypt);
g_lua.bindSingletonFunction("g_crypt", "rsaGetSize", &Crypt::rsaGetSize, &g_crypt);

// Clock
Expand Down
216 changes: 76 additions & 140 deletions src/framework/util/crypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@

#include <boost/functional/hash.hpp>

#ifndef USE_GMP
#include <openssl/rsa.h>
#include <openssl/sha.h>
#include <openssl/md5.h>
#include <openssl/bn.h>
#include <openssl/err.h>
#endif

static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static inline bool is_base64(unsigned char c) { return (isalnum(c) || (c == '+') || (c == '/')); }
Expand All @@ -45,12 +45,28 @@ Crypt g_crypt;

Crypt::Crypt()
{
#ifdef USE_GMP
mpz_init(m_p);
mpz_init(m_q);
mpz_init(m_d);
mpz_init(m_e);
mpz_init(m_n);
#else
m_rsa = RSA_new();
#endif
}

Crypt::~Crypt()
{
#ifdef USE_GMP
mpz_clear(m_p);
mpz_clear(m_q);
mpz_clear(m_n);
mpz_clear(m_d);
mpz_clear(m_e);
#else
RSA_free(m_rsa);
#endif
}

std::string Crypt::base64Encode(const std::string& decoded_string)
Expand Down Expand Up @@ -220,112 +236,12 @@ std::string Crypt::_decrypt(const std::string& encrypted_string, bool useMachine
return std::string();
}

std::string Crypt::md5Encode(const std::string& decoded_string, bool upperCase)
{
MD5_CTX c;
MD5_Init(&c);
MD5_Update(&c, decoded_string.c_str(), decoded_string.length());

uint8_t md[MD5_DIGEST_LENGTH];
MD5_Final(md, &c);

char output[(MD5_DIGEST_LENGTH << 1) + 1];
for(int32_t i = 0; i < (int32_t)sizeof(md); ++i)
sprintf(output + (i << 1), "%.2X", md[i]);

std::string result = output;
if(upperCase)
return result;

std::transform(result.begin(), result.end(), result.begin(), tolower);
return result;
}

std::string Crypt::sha1Encode(const std::string& decoded_string, bool upperCase)
{
SHA_CTX c;
SHA1_Init(&c);
SHA1_Update(&c, decoded_string.c_str(), decoded_string.length());

uint8_t md[SHA_DIGEST_LENGTH];
SHA1_Final(md, &c);

char output[(SHA_DIGEST_LENGTH << 1) + 1];
for(int32_t i = 0; i < (int32_t)sizeof(md); ++i)
sprintf(output + (i << 1), "%.2X", md[i]);

std::string result = output;
if(upperCase)
return result;

std::transform(result.begin(), result.end(), result.begin(), tolower);
return result;
}

std::string Crypt::sha256Encode(const std::string& decoded_string, bool upperCase)
{
SHA256_CTX c;
SHA256_Init(&c);
SHA256_Update(&c, decoded_string.c_str(), decoded_string.length());

uint8_t md[SHA256_DIGEST_LENGTH];
SHA256_Final(md, &c);

char output[(SHA256_DIGEST_LENGTH << 1) + 1];
for(int32_t i = 0; i < (int32_t)sizeof(md); ++i)
sprintf(output + (i << 1), "%.2X", md[i]);

std::string result = output;
if(upperCase)
return result;

std::transform(result.begin(), result.end(), result.begin(), tolower);
return result;
}

std::string Crypt::sha512Encode(const std::string& decoded_string, bool upperCase)
{
SHA512_CTX c;
SHA512_Init(&c);
SHA512_Update(&c, decoded_string.c_str(), decoded_string.length());

uint8_t md[SHA512_DIGEST_LENGTH];
SHA512_Final(md, &c);

char output[(SHA512_DIGEST_LENGTH << 1) + 1];
for(int32_t i = 0; i < (int32_t)sizeof(md); ++i)
sprintf(output + (i << 1), "%.2X", md[i]);

std::string result = output;
if(upperCase)
return result;

std::transform(result.begin(), result.end(), result.begin(), tolower);
return result;
}


void Crypt::rsaGenerateKey(int bits, int e)
{
// disabled because new OpenSSL changes broke
/*
RSA *rsa = RSA_new();
BIGNUM *ebn = BN_new();
BN_set_word(ebn, e);
RSA_generate_key_ex(rsa, bits, ebn, nullptr);
g_logger.info(stdext::format("%d bits (%d bytes) RSA key generated", bits, bits / 8));
g_logger.info(std::string("p = ") + BN_bn2dec(m_rsa->p));
g_logger.info(std::string("q = ") + BN_bn2dec(m_rsa->q));
g_logger.info(std::string("d = ") + BN_bn2dec(m_rsa->d));
g_logger.info(std::string("n = ") + BN_bn2dec(m_rsa->n));
g_logger.info(std::string("e = ") + BN_bn2dec(m_rsa->e));
BN_clear_free(ebn);
RSA_free(rsa);
*/
}

void Crypt::rsaSetPublicKey(const std::string& n, const std::string& e)
{
#ifdef USE_GMP
mpz_set_str(m_n, n.c_str(), 10);
mpz_set_str(m_e, e.c_str(), 10);
#else
#if OPENSSL_VERSION_NUMBER < 0x10100005L
BN_dec2bn(&m_rsa->n, n.c_str());
BN_dec2bn(&m_rsa->e, e.c_str());
Expand All @@ -340,10 +256,19 @@ void Crypt::rsaSetPublicKey(const std::string& n, const std::string& e)
BN_dec2bn(&be, e.c_str());
RSA_set0_key(m_rsa, bn, be, nullptr);
#endif
#endif
}

void Crypt::rsaSetPrivateKey(const std::string& p, const std::string& q, const std::string& d)
{
#ifdef USE_GMP
mpz_set_str(m_p, p.c_str(), 10);
mpz_set_str(m_q, q.c_str(), 10);
mpz_set_str(m_d, d.c_str(), 10);

// n = p * q
mpz_mul(n, p, q);
#else
#if OPENSSL_VERSION_NUMBER < 0x10100005L
BN_dec2bn(&m_rsa->p, p.c_str());
BN_dec2bn(&m_rsa->q, q.c_str());
Expand All @@ -365,58 +290,69 @@ void Crypt::rsaSetPrivateKey(const std::string& p, const std::string& q, const s
RSA_set0_key(m_rsa, nullptr, nullptr, bd);
RSA_set0_factors(m_rsa, bp, bq);
#endif
#endif
}

bool Crypt::rsaCheckKey()
bool Crypt::rsaEncrypt(unsigned char *msg, int size)
{
// only used by server, that sets both public and private
if(RSA_check_key(m_rsa)) {
BN_CTX *ctx = BN_CTX_new();
BN_CTX_start(ctx);
if(size != rsaGetSize())
return false;

BIGNUM *r1 = BN_CTX_get(ctx), *r2 = BN_CTX_get(ctx);
#if OPENSSL_VERSION_NUMBER < 0x10100005L
BN_mod(m_rsa->dmp1, m_rsa->d, r1, ctx);
BN_mod(m_rsa->dmq1, m_rsa->d, r2, ctx);
BN_mod_inverse(m_rsa->iqmp, m_rsa->q, m_rsa->p, ctx);
#else
const BIGNUM *dmp1_c = nullptr, *d = nullptr, *dmq1_c = nullptr, *iqmp_c = nullptr, *q = nullptr, *p = nullptr;
#ifdef USE_GMP
mpz_t c, m;
mpz_init(c);
mpz_init(m);
mpz_import(m, size, 1, 1, 0, 0, msg);

RSA_get0_key(m_rsa, nullptr, nullptr, &d);
RSA_get0_factors(m_rsa, &p, &q);
RSA_get0_crt_params(m_rsa, &dmp1_c, &dmq1_c, &iqmp_c);
// c = m^e mod n
mpz_powm(c, m, m_e, m_n);

BIGNUM *dmp1 = BN_dup(dmp1_c), *dmq1 = BN_dup(dmq1_c), *iqmp = BN_dup(iqmp_c);
size_t count = (mpz_sizeinbase(m, 2) + 7) / 8;
memset((char*)msg, 0, size - count);
mpz_export((char*)msg + (size - count), nullptr, 1, 1, 0, 0, c);

BN_mod(dmp1, d, r1, ctx);
BN_mod(dmq1, d, r2, ctx);
BN_mod_inverse(iqmp, q, p, ctx);
RSA_set0_crt_params(m_rsa, dmp1, dmq1, iqmp);
#endif
return true;
}
else {
ERR_load_crypto_strings();
g_logger.error(stdext::format("RSA check failed - %s", ERR_error_string(ERR_get_error(), nullptr)));
return false;
}
}
mpz_clear(c);
mpz_clear(m);

bool Crypt::rsaEncrypt(unsigned char *msg, int size)
{
if(size != RSA_size(m_rsa))
return false;
return true;
#else
return RSA_public_encrypt(size, msg, msg, m_rsa, RSA_NO_PADDING) != -1;
#endif
}

bool Crypt::rsaDecrypt(unsigned char *msg, int size)
{
if(size != RSA_size(m_rsa))
if(size != rsaGetSize())
return false;

#ifdef USE_GMP
mpz_t c, m;
mpz_init(c);
mpz_init(m);
mpz_import(c, size, 1, 1, 0, 0, msg);

// m = c^d mod n
mpz_powm(m, c, m_d, m_n);

size_t count = (mpz_sizeinbase(m, 2) + 7) / 8;
memset((char*)msg, 0, size - count);
mpz_export((char*)msg + (size - count), nullptr, 1, 1, 0, 0, m);

mpz_clear(c);
mpz_clear(m);

return true;
#else
return RSA_private_decrypt(size, msg, msg, m_rsa, RSA_NO_PADDING) != -1;
#endif
}

int Crypt::rsaGetSize()
{
#ifdef USE_GMP
size_t count = (mpz_sizeinbase(m_n, 2) + 7) / 8;
return ((int)count / 128) * 128;
#else
return RSA_size(m_rsa);
#endif
}
15 changes: 8 additions & 7 deletions src/framework/util/crypt.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
#include <string>

#include <boost/uuid/uuid.hpp>

#ifdef USE_GMP
#include <gmp.h>
#else
typedef struct rsa_st RSA;
#endif

class Crypt
{
Expand All @@ -44,15 +47,9 @@ class Crypt
std::string genUUID();
bool setMachineUUID(std::string uuidstr);
std::string getMachineUUID();
std::string md5Encode(const std::string& decoded_string, bool upperCase);
std::string sha1Encode(const std::string& decoded_string, bool upperCase);
std::string sha256Encode(const std::string& decoded_string, bool upperCase);
std::string sha512Encode(const std::string& decoded_string, bool upperCase);

void rsaGenerateKey(int bits, int e);
void rsaSetPublicKey(const std::string& n, const std::string& e);
void rsaSetPrivateKey(const std::string &p, const std::string &q, const std::string &d);
bool rsaCheckKey();
bool rsaEncrypt(unsigned char *msg, int size);
bool rsaDecrypt(unsigned char *msg, int size);
int rsaGetSize();
Expand All @@ -62,7 +59,11 @@ class Crypt
std::string _decrypt(const std::string& encrypted_string, bool useMachineUUID);
std::string getCryptKey(bool useMachineUUID);
boost::uuids::uuid m_machineUUID;
#ifdef USE_GMP
mpz_t m_p, m_q, m_n, m_e, m_d;
#else
RSA *m_rsa;
#endif
};

extern Crypt g_crypt;
Expand Down

0 comments on commit 63f5351

Please sign in to comment.