Skip to content

Commit

Permalink
ML-KEM encapsulation key modulus check (#1868)
Browse files Browse the repository at this point in the history
ML-KEM encapsulation key modulus check as specified in Section 7.2 
of FIPS 203: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf
  • Loading branch information
dkostic authored Sep 23, 2024
1 parent 4916fe8 commit 2835116
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 20 deletions.
39 changes: 39 additions & 0 deletions crypto/evp_extra/evp_extra_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2879,3 +2879,42 @@ TEST_P(PerKEMTest, EncapsSeedTest) {
ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len, es.data(), &es_len));
EXPECT_EQ(EVP_R_INVALID_PARAMETERS, ERR_GET_REASON(ERR_peek_last_error()));
}

static const struct KnownKEM kMLKEMs[] = {
{"MLKEM512", NID_MLKEM512, 800, 1632, 768, 32, 64, 32, "fipsmodule/ml_kem/kat/mlkem512.txt"},
{"MLKEM768", NID_MLKEM768, 1184, 2400, 1088, 32, 64, 32, "fipsmodule/ml_kem/kat/mlkem768.txt"},
{"MLKEM1024", NID_MLKEM1024, 1568, 3168, 1568, 32, 64, 32, "fipsmodule/ml_kem/kat/mlkem1024.txt"},
};

class PerMLKEMTest : public testing::TestWithParam<KnownKEM> {};

INSTANTIATE_TEST_SUITE_P(All, PerMLKEMTest, testing::ValuesIn(kMLKEMs),
[](const testing::TestParamInfo<KnownKEM> &params)
-> std::string { return params.param.name; });

TEST_P(PerMLKEMTest, InputValidation) {
// ---- 1. Setup phase: generate a context and a key ----
bssl::UniquePtr<EVP_PKEY_CTX> ctx;
ctx = setup_ctx_and_generate_key(GetParam().nid, nullptr, nullptr);
ASSERT_TRUE(ctx);

// ---- 2. Test basic encapsulation flow ----
// Alloc ciphertext and shared secret with the expected lengths.
size_t ct_len = GetParam().ciphertext_len;
size_t ss_len = GetParam().shared_secret_len;
std::vector<uint8_t> ct(ct_len);
std::vector<uint8_t> ss(ss_len);

// Encapsulate.
ASSERT_TRUE(
EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));

// ---- 3. Test invalid public key ----
// FIPS 203 Section 7.2 Encapsulation key check (Modulus check).
// Invalidate the key by forcing a coefficient out of range.
ctx->pkey->pkey.kem_key->public_key[0] = 0xff;
ctx->pkey->pkey.kem_key->public_key[1] = 0xff;

ASSERT_FALSE(
EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));
}
103 changes: 102 additions & 1 deletion crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "kem.h"
#include "indcpa.h"
#include "verify.h"
#include "reduce.h"
#include "symmetric.h"
#include "openssl/rand.h"

Expand Down Expand Up @@ -59,6 +60,102 @@ int crypto_kem_keypair(ml_kem_params *params,
return 0;
}

// FIPS 203. Algorithm 3 BitsToBytes
// Converts a bit array (of a length that is a multiple of eight)
// into an array of bytes.
static void bits_to_bytes(uint8_t *bytes, size_t num_bytes,
const uint8_t *bits, size_t num_bits) {
assert(num_bits == num_bytes * 8);

for (size_t i = 0; i < num_bytes; i++) {
uint8_t byte = 0;
for (size_t j = 0; j < 8; j++) {
byte |= (bits[i * 8 + j] << j);
}
bytes[i] = byte;
}
}

// FIPS 203. Algorithm 4 BytesToBits
// Performs the inverse of BitsToBytes, converting a byte array into a bit array.
static void bytes_to_bits(uint8_t *bits, size_t num_bits,
const uint8_t *bytes, size_t num_bytes) {
assert(num_bits == num_bytes * 8);

for (size_t i = 0; i < num_bytes; i++) {
uint8_t byte = bytes[i];
for (size_t j = 0; j < 8; j++) {
bits[i * 8 + j] = (byte >> j) & 1;
}
}
}

#define BYTE_ENCODE_12_IN_SIZE (256)
#define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
#define BYTE_ENCODE_12_NUM_BITS (256 * 12)

// FIPS 203. Algorithm 5 ByteEncode_12
// Encodes an array of 256 12-bit integers into a byte array.
static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0};
for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) {
int16_t a = in[i];
for (size_t j = 0; j < 12; j++) {
bits[i * 12 + j] = a & 1;
a = a >> 1;
}
}
bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS);
}

// Converts a centered representative |in| which is an integer in
// {-(q-1)/2, ..., (q-1)/2}, to a positive representative in {0, ..., q-1}.
// It implements in constant-time the following operation:
// return (in < 0) ? in + KYBER_Q : in;
static int16_t centered_to_positive_representative(int16_t in) {
// mask = (in < 0) ? b11..11 : b00..00;
crypto_word_t mask = constant_time_is_zero_w(in >> 15);
int16_t in_fixed = in + KYBER_Q;
return constant_time_select_int(mask, in, in_fixed);
}

#define BYTE_DECODE_12_OUT_SIZE (256)
#define BYTE_DECODE_12_IN_SIZE (32 * 12)
#define BYTE_DECODE_12_NUM_BITS (256 * 12)

// FIPS 203. Algorithm 5 ByteDecode_12
// Decodes a byte array into an array of 256 12-bit integers.
static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE],
const uint8_t in[BYTE_DECODE_12_IN_SIZE]) {
uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0};
bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE);
for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) {
int16_t val = 0;
for (size_t j = 0; j < 12; j++) {
val |= bits[i * 12 + j] << j;
}
out[i] = centered_to_positive_representative(barrett_reduce(val));
}
}

#define ENCAPS_KEY_ENCODED_MAX_SIZE (BYTE_ENCODE_12_OUT_SIZE * KYBER_K_MAX)
#define ENCAPS_KEY_DECODED_MAX_SIZE (BYTE_DECODE_12_OUT_SIZE * KYBER_K_MAX)

// FIPS 203. Section 7.2 Encapsulation key check.
static int encapsulation_key_modulus_check(ml_kem_params *params, const uint8_t *ek) {

int16_t ek_decoded[ENCAPS_KEY_DECODED_MAX_SIZE];
uint8_t ek_recoded[ENCAPS_KEY_ENCODED_MAX_SIZE];

for (size_t i = 0; i < params->k; i++) {
byte_decode_12(&ek_decoded[i * BYTE_DECODE_12_OUT_SIZE], &ek[i * BYTE_DECODE_12_IN_SIZE]);
byte_encode_12(&ek_recoded[i * BYTE_ENCODE_12_OUT_SIZE], &ek_decoded[i * BYTE_ENCODE_12_IN_SIZE]);
}

return verify(ek_recoded, ek, params->k * BYTE_ENCODE_12_OUT_SIZE);
}

/*************************************************
* Name: crypto_kem_enc_derand
*
Expand Down Expand Up @@ -112,13 +209,17 @@ int crypto_kem_enc_derand(ml_kem_params *params,
* - const uint8_t *pk: pointer to input public key
* (an already allocated array of KYBER_PUBLICKEYBYTES bytes)
*
* Returns 0 (success)
* Returns 0 (success), or 1 when the encapsulation key check fails.
**************************************************/
int crypto_kem_enc(ml_kem_params *params,
uint8_t *ct,
uint8_t *ss,
const uint8_t *pk)
{
if (encapsulation_key_modulus_check(params, pk) != 0) {
return 1;
}

uint8_t coins[KYBER_SYMBYTES];
RAND_bytes(coins, KYBER_SYMBYTES);
crypto_kem_enc_derand(params, ct, ss, pk, coins);
Expand Down
33 changes: 14 additions & 19 deletions ssl/ssl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11960,14 +11960,9 @@ TEST_P(BadKemKeyShareAcceptTest, BadKemKeyShareAccept) {
}

// |client_public_key| is initialized with key material that is the correct
// length, but is not a valid key. In this case, the basic sanity checks
// will not reject the key because it has been initialized properly with
// the correct amount of data. The KEM encapsulate function is written
// so that it will return success if given an invalid key of the correct
// length. Therefore, the call to server_key_share->Accept() will succeed,
// but ultimately, the ciphertext (server's public key) will be garbage,
// the server and client will end up with different secrets, and the
// overall handshake will eventually fail.
// length, but it doesn't match the corresponding secret key. The exchange
// will succeed, but the client and the server will end up with different
// secrets, and the overall handshake will eventually fail.
{
bssl::UniquePtr<SSLKeyShare> server_key_share = bssl::SSLKeyShare::Create(t.group_id);
bssl::UniquePtr<SSLKeyShare> client_key_share = bssl::SSLKeyShare::Create(t.group_id);
Expand All @@ -11984,19 +11979,19 @@ TEST_P(BadKemKeyShareAcceptTest, BadKemKeyShareAccept) {
EXPECT_TRUE(CBB_init(&client_out_public_key, t.offer_key_share_size));
EXPECT_TRUE(client_key_share->Offer(&client_out_public_key));

// Then invalidate it by negating the bits in the first byte
uint8_t *invalid_client_public_key_buf =
// Modify the public key making it incompatible with the secret key
uint8_t *modified_client_public_key_buf =
(uint8_t *)OPENSSL_malloc(t.offer_key_share_size);
ASSERT_TRUE(invalid_client_public_key_buf);
ASSERT_TRUE(modified_client_public_key_buf);
const uint8_t *client_out_public_key_data = CBB_data(&client_out_public_key);
ASSERT_TRUE(client_out_public_key_data);
OPENSSL_memcpy(invalid_client_public_key_buf, client_out_public_key_data,
OPENSSL_memcpy(modified_client_public_key_buf, client_out_public_key_data,
t.offer_key_share_size);
invalid_client_public_key_buf[0] = ~invalid_client_public_key_buf[0];
modified_client_public_key_buf[0] ^= 1;
Span<const uint8_t> client_public_key =
MakeConstSpan(invalid_client_public_key_buf, t.offer_key_share_size);
MakeConstSpan(modified_client_public_key_buf, t.offer_key_share_size);

// When the server calls Accept() with the invalid public key, it will
// When the server calls Accept() with the modified public key, it will
// return success
EXPECT_TRUE(CBB_init(&server_out_public_key, t.accept_key_share_size));
EXPECT_TRUE(server_key_share->Accept(&server_out_public_key,
Expand All @@ -12020,7 +12015,7 @@ TEST_P(BadKemKeyShareAcceptTest, BadKemKeyShareAccept) {

EXPECT_EQ(server_alert, 0);
EXPECT_EQ(client_alert, 0);
OPENSSL_free(invalid_client_public_key_buf);
OPENSSL_free(modified_client_public_key_buf);
CBB_cleanup(&server_out_public_key);
CBB_cleanup(&client_out_public_key);
}
Expand Down Expand Up @@ -12114,19 +12109,19 @@ TEST_P(BadKemKeyShareFinishTest, BadKemKeyShareFinish) {
client_alert = 0;
}

// |server_public_key| is initialized with an invalid key of the correct
// |server_public_key| is initialized with a modified key of the correct
// length. The decapsulation operations will succeed; however, the resulting
// shared secret will be garbage, and eventually the overall handshake
// would fail because the client secret does not match the server secret.
{
// The server's public key was already correctly generated previously in
// a call to Accept(). Here we invalidate it by negating the first byte.
// a call to Accept(). Here we modify it.
uint8_t *invalid_server_public_key_buf = (uint8_t *) OPENSSL_malloc(t.accept_key_share_size);
ASSERT_TRUE(invalid_server_public_key_buf);
const uint8_t *server_out_public_key_data = CBB_data(&server_out_public_key);
ASSERT_TRUE(server_out_public_key_data);
OPENSSL_memcpy(invalid_server_public_key_buf, server_out_public_key_data, t.accept_key_share_size);
invalid_server_public_key_buf[0] = ~invalid_server_public_key_buf[0];
invalid_server_public_key_buf[0] ^= 1;

// The call to Finish() will return success
server_public_key =
Expand Down

0 comments on commit 2835116

Please sign in to comment.