Skip to content

Commit

Permalink
Move KEM API and ML-KEM definitions to FIPS module (aws#1828)
Browse files Browse the repository at this point in the history
In this PR, we first move the ML-KEM definitions under the FIPS module
and the set of KEM API. There are two objects to keep track of:

`KEM` is the "context" object of a kem method storing e.g. size
information, unique identifiers, and a reference to a method table (of
type `KEM_METHOD`).

`KEM_METHOD` is the method table that stores references to
implementations for each abstract operation.

Both `KEM` and `KEM_METHOD` has been defined outside the FIPS module.
Move this under the module now. Previously, also the KEM method function
tables were defined in their own source code file. Move these together
with the definitions of the KEMs. They will later go to their own file.

Note, that `KEM` was defined with only `const` fields. I think this was
my invention but it was not an optimal choice. It only allows static
initialization. This is not good with the current static FIPS build, so
remove the `const` qualifier from all non-pointer fields in `KEM`.

Finally, drop part of the table search method. KEMs are now directly
referenced instead of searching through a table. Retain the table search
for the legacy Kyber kems for now though to minimize code changes in
each PR.
  • Loading branch information
torben-hansen authored Sep 6, 2024
1 parent b15bf7e commit 16ca6e7
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 330 deletions.
2 changes: 1 addition & 1 deletion crypto/evp_extra/p_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include "../fipsmodule/evp/internal.h"
#include "../fipsmodule/delocate.h"
#include "../kem/internal.h"
#include "../fipsmodule/kem/internal.h"
#include "../internal.h"
#include "internal.h"

Expand Down
2 changes: 1 addition & 1 deletion crypto/evp_extra/p_kem_asn1.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <openssl/mem.h>

#include "../fipsmodule/evp/internal.h"
#include "../kem/internal.h"
#include "../fipsmodule/kem/internal.h"
#include "../internal.h"
#include "internal.h"

Expand Down
1 change: 1 addition & 0 deletions crypto/fipsmodule/bcm.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
#include "hmac/hmac.c"
#include "kdf/kbkdf.c"
#include "kdf/sskdf.c"
#include "kem/kem.c"
#include "md4/md4.c"
#include "md5/md5.c"
#include "ml_kem/ml_kem.c"
Expand Down
94 changes: 94 additions & 0 deletions crypto/fipsmodule/kem/internal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

#ifndef AWSLC_HEADER_KEM_INTERNAL_H
#define AWSLC_HEADER_KEM_INTERNAL_H

#include <openssl/base.h>


#if defined(__cplusplus)
extern "C" {
#endif

// KEM_METHOD structure and helper functions.
typedef struct {
int (*keygen_deterministic)(uint8_t *ctx,
uint8_t *pkey,
const uint8_t *seed);

int (*keygen)(uint8_t *public_key,
uint8_t *secret_key);

int (*encaps_deterministic)(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key,
const uint8_t *seed);

int (*encaps)(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key);

int (*decaps)(uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key);
} KEM_METHOD;

// KEM structure and helper functions.
typedef struct {
int nid;
const uint8_t *oid;
uint8_t oid_len;
const char *comment;
size_t public_key_len;
size_t secret_key_len;
size_t ciphertext_len;
size_t shared_secret_len;
size_t keygen_seed_len;
size_t encaps_seed_len;
const KEM_METHOD *method;
} KEM;

// KEM_KEY structure and helper functions.
struct kem_key_st {
const KEM *kem;
uint8_t *public_key;
uint8_t *secret_key;
};

const KEM *KEM_find_kem_by_nid(int nid);

KEM_KEY *KEM_KEY_new(void);
int KEM_KEY_init(KEM_KEY *key, const KEM *kem);
void KEM_KEY_free(KEM_KEY *key);
const KEM *KEM_KEY_get0_kem(KEM_KEY* key);

// KEM_KEY_set_raw_public_key function allocates the public key buffer
// within the given |key| and copies the contents of |in| to it.
//
// NOTE: No checks are done in this function, the caller has to ensure
// that the pointers are valid and |in| has the correct size.
int KEM_KEY_set_raw_public_key(KEM_KEY *key, const uint8_t *in);

// KEM_KEY_set_raw_secret_key function allocates the secret key buffer
// within the given |key| and copies the contents of |in| to it.
//
// NOTE: No checks are done in this function, the caller has to ensure
// that the pointers are valid and |in| has the correct size.
int KEM_KEY_set_raw_secret_key(KEM_KEY *key, const uint8_t *in);

// KEM_KEY_set_raw_key function allocates the public and secret key buffers
// within the given |key| and copies the contents of |in_public| and
// |in_secret| to them.
//
// NOTE: No checks are done in this function, the caller has to ensure
// that the pointers are valid and |in_public| and |in_secret|
// have the correct size.
int KEM_KEY_set_raw_key(KEM_KEY *key, const uint8_t *in_public,
const uint8_t *in_secret);

#if defined(__cplusplus)
} // extern C
#endif

#endif // AWSLC_HEADER_KEM_TEST_INTERNAL_H
270 changes: 270 additions & 0 deletions crypto/fipsmodule/kem/kem.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

#include <openssl/base.h>

#include "../../kem/internal.h"
#include "../delocate.h"
#include "../ml_kem/ml_kem.h"
#include "internal.h"

static const uint8_t kOIDMLKEM512[] = {0xff, 0xff, 0xff, 0xff};
static const uint8_t kOIDMLKEM768[] = {0xff, 0xff, 0xff, 0xff};
static const uint8_t kOIDMLKEM1024[] = {0xff, 0xff, 0xff, 0xff};

static int ml_kem_1024_keygen_deterministic(uint8_t *public_key,
uint8_t *secret_key,
const uint8_t *seed) {
return ml_kem_1024_keypair_deterministic(public_key, secret_key, seed) == 0;
}

static int ml_kem_1024_keygen(uint8_t *public_key,
uint8_t *secret_key) {
return ml_kem_1024_keypair(public_key, secret_key) == 0;
}

static int ml_kem_1024_encaps_deterministic(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key,
const uint8_t *seed) {
return ml_kem_1024_encapsulate_deterministic(ciphertext, shared_secret, public_key, seed) == 0;
}

static int ml_kem_1024_encaps(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key) {
return ml_kem_1024_encapsulate(ciphertext, shared_secret, public_key) == 0;
}

static int ml_kem_1024_decaps(uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key) {
return ml_kem_1024_decapsulate(shared_secret, ciphertext, secret_key) == 0;
}

DEFINE_LOCAL_DATA(KEM_METHOD, kem_ml_kem_1024_method) {
out->keygen_deterministic = ml_kem_1024_keygen_deterministic;
out->keygen = ml_kem_1024_keygen;
out->encaps_deterministic = ml_kem_1024_encaps_deterministic;
out->encaps = ml_kem_1024_encaps;
out->decaps = ml_kem_1024_decaps;
}

static int ml_kem_768_keygen_deterministic(uint8_t *public_key,
uint8_t *secret_key,
const uint8_t *seed) {
return ml_kem_768_keypair_deterministic(public_key, secret_key, seed) == 0;
}

static int ml_kem_768_keygen(uint8_t *public_key,
uint8_t *secret_key) {
return ml_kem_768_keypair(public_key, secret_key) == 0;
}

static int ml_kem_768_encaps_deterministic(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key,
const uint8_t *seed) {
return ml_kem_768_encapsulate_deterministic(ciphertext, shared_secret, public_key, seed) == 0;
}

static int ml_kem_768_encaps(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key) {
return ml_kem_768_encapsulate(ciphertext, shared_secret, public_key) == 0;
}

static int ml_kem_768_decaps(uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key) {
return ml_kem_768_decapsulate(shared_secret, ciphertext, secret_key) == 0;
}

DEFINE_LOCAL_DATA(KEM_METHOD, kem_ml_kem_768_method) {
out->keygen_deterministic = ml_kem_768_keygen_deterministic;
out->keygen = ml_kem_768_keygen;
out->encaps_deterministic = ml_kem_768_encaps_deterministic;
out->encaps = ml_kem_768_encaps;
out->decaps = ml_kem_768_decaps;
}

static int ml_kem_512_keygen_deterministic(uint8_t *public_key,
uint8_t *secret_key,
const uint8_t *seed) {
return ml_kem_512_keypair_deterministic(public_key, secret_key, seed) == 0;
}

static int ml_kem_512_keygen(uint8_t *public_key,
uint8_t *secret_key) {
return ml_kem_512_keypair(public_key, secret_key) == 0;
}

static int ml_kem_512_encaps_deterministic(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key,
const uint8_t *seed) {
return ml_kem_512_encapsulate_deterministic(ciphertext, shared_secret, public_key, seed) == 0;
}

static int ml_kem_512_encaps(uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key) {
return ml_kem_512_encapsulate(ciphertext, shared_secret, public_key) == 0;
}

static int ml_kem_512_decaps(uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key) {
return ml_kem_512_decapsulate(shared_secret, ciphertext, secret_key) == 0;
}

DEFINE_LOCAL_DATA(KEM_METHOD, kem_ml_kem_512_method) {
out->keygen_deterministic = ml_kem_512_keygen_deterministic;
out->keygen = ml_kem_512_keygen;
out->encaps_deterministic = ml_kem_512_encaps_deterministic;
out->encaps = ml_kem_512_encaps;
out->decaps = ml_kem_512_decaps;
}

DEFINE_LOCAL_DATA(KEM, KEM_ml_kem_512) {
out->nid = NID_MLKEM512;
out->oid = kOIDMLKEM512;
out->oid_len = sizeof(kOIDMLKEM512);
out->comment = "MLKEM512 ";
out->public_key_len = MLKEM512_PUBLIC_KEY_BYTES;
out->secret_key_len = MLKEM512_SECRET_KEY_BYTES;
out->ciphertext_len = MLKEM512_CIPHERTEXT_BYTES;
out->shared_secret_len = MLKEM512_SHARED_SECRET_LEN;
out->keygen_seed_len = MLKEM512_KEYGEN_SEED_LEN;
out->encaps_seed_len = MLKEM512_ENCAPS_SEED_LEN;
out->method = kem_ml_kem_512_method();
}

DEFINE_LOCAL_DATA(KEM, KEM_ml_kem_768) {
out->nid = NID_MLKEM768;
out->oid = kOIDMLKEM768;
out->oid_len = sizeof(kOIDMLKEM768);
out->comment = "MLKEM768 ";
out->public_key_len = MLKEM768_PUBLIC_KEY_BYTES;
out->secret_key_len = MLKEM768_SECRET_KEY_BYTES;
out->ciphertext_len = MLKEM768_CIPHERTEXT_BYTES;
out->shared_secret_len = MLKEM768_SHARED_SECRET_LEN;
out->keygen_seed_len = MLKEM768_KEYGEN_SEED_LEN;
out->encaps_seed_len = MLKEM768_ENCAPS_SEED_LEN;
out->method = kem_ml_kem_768_method();
}

DEFINE_LOCAL_DATA(KEM, KEM_ml_kem_1024) {
out->nid = NID_MLKEM1024;
out->oid = kOIDMLKEM1024;
out->oid_len = sizeof(kOIDMLKEM1024);
out->comment = "MLKEM1024 ";
out->public_key_len = MLKEM1024_PUBLIC_KEY_BYTES;
out->secret_key_len = MLKEM1024_SECRET_KEY_BYTES;
out->ciphertext_len = MLKEM1024_CIPHERTEXT_BYTES;
out->shared_secret_len = MLKEM1024_SHARED_SECRET_LEN;
out->keygen_seed_len = MLKEM1024_KEYGEN_SEED_LEN;
out->encaps_seed_len = MLKEM1024_ENCAPS_SEED_LEN;
out->method = kem_ml_kem_1024_method();
}

const KEM *KEM_find_kem_by_nid(int nid) {

switch (nid) {
case NID_MLKEM512:
return KEM_ml_kem_512();
case NID_MLKEM768:
return KEM_ml_kem_768();
case NID_MLKEM1024:
return KEM_ml_kem_1024();
default:
break;
}

// We couldn't match a known KEM. Try legacy KEMs.
const KEM *legacy_kems = get_legacy_kems();
for (size_t i = 0; i < AWSLC_NUM_LEGACY_KEMS; i++) {
if (legacy_kems[i].nid == nid) {
return &legacy_kems[i];
}
}

return NULL;
}

KEM_KEY *KEM_KEY_new(void) {
KEM_KEY *ret = OPENSSL_zalloc(sizeof(KEM_KEY));
if (ret == NULL) {
return NULL;
}

return ret;
}

static void KEM_KEY_clear(KEM_KEY *key) {
key->kem = NULL;
OPENSSL_free(key->public_key);
OPENSSL_free(key->secret_key);
key->public_key = NULL;
key->secret_key = NULL;
}

int KEM_KEY_init(KEM_KEY *key, const KEM *kem) {
if (key == NULL || kem == NULL) {
return 0;
}
// If the key is already initialized clear it.
KEM_KEY_clear(key);

key->kem = kem;
key->public_key = OPENSSL_malloc(kem->public_key_len);
key->secret_key = OPENSSL_malloc(kem->secret_key_len);
if (key->public_key == NULL || key->secret_key == NULL) {
KEM_KEY_clear(key);
return 0;
}

return 1;
}

void KEM_KEY_free(KEM_KEY *key) {
if (key == NULL) {
return;
}
KEM_KEY_clear(key);
OPENSSL_free(key);
}

const KEM *KEM_KEY_get0_kem(KEM_KEY* key) {
return key->kem;
}

int KEM_KEY_set_raw_public_key(KEM_KEY *key, const uint8_t *in) {
key->public_key = OPENSSL_memdup(in, key->kem->public_key_len);
if (key->public_key == NULL) {
return 0;
}

return 1;
}

int KEM_KEY_set_raw_secret_key(KEM_KEY *key, const uint8_t *in) {
key->secret_key = OPENSSL_memdup(in, key->kem->secret_key_len);
if (key->secret_key == NULL) {
return 0;
}

return 1;
}

int KEM_KEY_set_raw_key(KEM_KEY *key, const uint8_t *in_public,
const uint8_t *in_secret) {
key->public_key = OPENSSL_memdup(in_public, key->kem->public_key_len);
key->secret_key = OPENSSL_memdup(in_secret, key->kem->secret_key_len);
if (key->public_key == NULL || key->secret_key == NULL) {
KEM_KEY_clear(key);
return 0;
}

return 1;
}
Loading

0 comments on commit 16ca6e7

Please sign in to comment.