Skip to content

Commit

Permalink
Add explanation for FIPS 203 encaps and decaps input validation (#1884)
Browse files Browse the repository at this point in the history
Add explanation for FIPS 203 encaps and decaps input validation.
In particular, explain why we do only part of the checks in the ML-KEM
code itself.
  • Loading branch information
dkostic authored Sep 27, 2024
1 parent c2846eb commit 8ed554c
Showing 1 changed file with 24 additions and 74 deletions.
98 changes: 24 additions & 74 deletions crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,78 +60,6 @@ int crypto_kem_keypair(ml_kem_params *params,
return 0;
}

// REFERENCE IMPLEMENTATION OF SEVERAL FIPS 203 FUNCTIONS.
// Further below we implement optimized versions of the functions
// that are actually used. We commented out and kept the reference
// code for posterity.
//
// FIPS 203. Algorithm 3 BitsToBytes
// Converts a bit array (of a length that is a multiple of eight)
// into an array of bytes.
// static void bits_to_bytes(uint8_t *bytes, size_t num_bytes,
// const uint8_t *bits, size_t num_bits) {
// assert(num_bits == num_bytes * 8);
//
// for (size_t i = 0; i < num_bytes; i++) {
// uint8_t byte = 0;
// for (size_t j = 0; j < 8; j++) {
// byte |= (bits[i * 8 + j] << j);
// }
// bytes[i] = byte;
// }
// }
// FIPS 203. Algorithm 4 BytesToBits
// Performs the inverse of BitsToBytes, converting a byte array into a bit array.
// static void bytes_to_bits(uint8_t *bits, size_t num_bits,
// const uint8_t *bytes, size_t num_bytes) {
// assert(num_bits == num_bytes * 8);
//
// for (size_t i = 0; i < num_bytes; i++) {
// uint8_t byte = bytes[i];
// for (size_t j = 0; j < 8; j++) {
// bits[i * 8 + j] = (byte >> j) & 1;
// }
// }
// }
//
// #define BYTE_ENCODE_12_IN_SIZE (256)
// #define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
// #define BYTE_ENCODE_12_NUM_BITS (256 * 12)
//
// // FIPS 203. Algorithm 5 ByteEncode_12
// // Encodes an array of 256 12-bit integers into a byte array.
// static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
// const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
// uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0};
// for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) {
// int16_t a = in[i];
// for (size_t j = 0; j < 12; j++) {
// bits[i * 12 + j] = a & 1;
// a = a >> 1;
// }
// }
// bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS);
// }
//
// #define BYTE_DECODE_12_OUT_SIZE (256)
// #define BYTE_DECODE_12_IN_SIZE (32 * 12)
// #define BYTE_DECODE_12_NUM_BITS (256 * 12)
//
// // FIPS 203. Algorithm 6 ByteDecode_12
// // Decodes a byte array into an array of 256 12-bit integers.
// static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE],
// const uint8_t in[BYTE_DECODE_12_IN_SIZE]) {
// uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0};
// bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE);
// for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) {
// int16_t val = 0;
// for (size_t j = 0; j < 12; j++) {
// val |= bits[i * 12 + j] << j;
// }
// out[i] = centered_to_positive_representative(barrett_reduce(val));
// }
// }

// Converts a centered representative |in| which is an integer in
// {-(q-1)/2, ..., (q-1)/2}, to a positive representative in {0, ..., q-1}.
// It implements in constant-time the following operation:
Expand All @@ -154,7 +82,7 @@ static int16_t centered_to_positive_representative(int16_t in) {
// in: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ...
// out: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ...
// We divide the input in pairs of elements (2 x 12 bits = 24 bits),
// and the output in triples (3 x 8 bits = 24 bits). For each pair/triplet we:
// and the output in triplets (3 x 8 bits = 24 bits). For each pair/triplet we:
// - out0 <-- first eight bits of in0,
// - out1 <-- concatenate last 4 bits of in0 and first 4 bits of in1,
// - out2 <-- last 8 bits of in1.
Expand All @@ -174,7 +102,7 @@ static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
// Intuition for the implementation:
// in: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ...
// out: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ...
// We divide the input in triples of elements (3 x 8 bits = 24 bits),
// We divide the input in triplets of elements (3 x 8 bits = 24 bits),
// and the output in pairs (2 x 12 bits = 24 bits). For each pair/triplet we:
// - out[0] <-- concatenate eight bits of in[0] and first 4 bits of in[1],
// - out[1] <-- concatenate last 4 bits of in[1] and 8 bits of in[2].
Expand Down Expand Up @@ -202,6 +130,16 @@ static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE],
#define ENCAPS_KEY_DECODED_MAX_SIZE (BYTE_DECODE_12_OUT_SIZE * KYBER_K_MAX)

// FIPS 203. Section 7.2 Encapsulation key check.
// This function implements the encapsulation key modulus check. The other
// check specified in Section 7.2 is a type check the key. We can safely omit
// that check here because it is done in higher level functions. The required
// lengths for all variants of ML-KEM are hard-coded in: fipsmodule/kem/kem.c.
// If a key is generated by aws-lc then it satisfies the length requirements.
// If a key is generated outside of aws-lc, it has to be imported into an
// `EVP_PKEY` object to be used within aws-lc. We provide only these three
// functions to do that: `EVP_PKEY_kem_new_raw_key`,
// `EVP_PKEY_kem_new_raw_secret_key`, `EVP_PKEY_kem_new_raw_public_key`.
// The lengths are checked in all three functions.
static int encapsulation_key_modulus_check(ml_kem_params *params, const uint8_t *ek) {

int16_t ek_decoded[ENCAPS_KEY_DECODED_MAX_SIZE];
Expand All @@ -220,6 +158,18 @@ static int encapsulation_key_modulus_check(ml_kem_params *params, const uint8_t
// dk <-- (dk_pke || ek || H(ek) || z).
// This check takes |ek| out of |dk|, computes H(ek), and verifies that it is
// the same as the H(ek) portion stored in |dk|.
//
// This function implements the decapsulation key hash check. The other checks
// specified in Section 7.3 are the ciphertext and the key type check. We can
// safely omit those checks here because they are done in higher level functions.
// The required lengths for all variants of ML-KEM are hard-coded in
// fipsmodule/kem/kem.c. If a key is generated by aws-lc then it satisfies
// the length requirements. If a key is generated outside of aws-lc, it has to
// be imported into an `EVP_PKEY` object to be used within aws-lc. We provide
// only these three functions to do that: `EVP_PKEY_kem_new_raw_key`,
// `EVP_PKEY_kem_new_raw_secret_key`, `EVP_PKEY_kem_new_raw_public_key`.
// The lengths are checked in all three functions. Additionally, the ciphertext
// length is checked in function pkey_kem_decapsulate in fipsmodule/evp/p_kem.c.
static int decapsulation_key_hash_check(ml_kem_params *params, const uint8_t *dk) {
uint8_t dk_pke_hash_computed[KYBER_SYMBYTES] = {0};

Expand Down

0 comments on commit 8ed554c

Please sign in to comment.