[EC] Unify scalar_mul_base point for ec_nistp curves (aws#2003)
Added unified scalar multiplication of the base point for curves
implemented in ec_nistp. This is a refactor of the algorithm in
p384.c and p521.c that makes it generic. The implementation 
in p384.c and p521.c is substituted with this new unified
implementation.
dkostic authored Nov 22, 2024
1 parent c48572a commit 80f984e
Showing 4 changed files with 163 additions and 293 deletions.
140 changes: 139 additions & 1 deletion crypto/fipsmodule/ec/ec_nistp.c
@@ -17,7 +17,7 @@
// |----------------------------|
// | 1. | x | x | x* |
// | 2. | x | x | x* |
// | 3. | | | |
// | 3. | x | x | |
// | 4. | x | x | x* |
// | 5. | | | |
// * For P-256, only the Fiat-crypto implementation in p256.c is replaced.
@@ -498,3 +498,141 @@ void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
cmovznz(y_out, ctx->felem_num_limbs, t, y_tmp, y_res);
cmovznz(z_out, ctx->felem_num_limbs, t, z_tmp, z_res);
}

// Multiplication of the base point G of the curve with the given scalar.
// The product is computed with the Comb method using a precomputed table
// and the regular-wNAF scalar encoding.
//
// While the algorithm is generic and works for different curves, window sizes,
// and scalar sizes, for clarity, we describe it by using the example of P-521.
//
// The precomputed table has 27 sub-tables each holding 16 points:
//
// 0 : [1]G, [3]G, ..., [31]G
// 1 : [1*2^20]G, [3*2^20]G, ..., [31*2^20]G
// ...
// i : [1*2^20i]G, [3*2^20i]G, ..., [31*2^20i]G
// ...
// 26 : [2^520]G, [3*2^520]G, ..., [31*2^520]G
// Computing the negation of a point P = (x, y) is relatively easy:
// -P = (x, -y).
// So we may assume that for each sub-table we have 32 points instead of 16:
// [\pm 1*2^20i]G, [\pm 3*2^20i]G, ..., [\pm 31*2^20i]G.
//
// The 521-bit |scalar| is recoded (regular-wNAF encoding) into 105 signed
// digits, each of length 5 bits, as performed by the |scalar_rwnaf|
// function. Namely,
// scalar' = s_0 + s_1*2^5 + s_2*2^10 + ... + s_104*2^520,
// where digits s_i are in [\pm 1, \pm 3, ..., \pm 31]. Note that for an odd
// scalar we have that scalar = scalar', while in the case of an even
// scalar we have that scalar = scalar' - 1.
//
// To compute the required product, [scalar]G, we may do the following.
// Group the recoded digits of the scalar in 4 groups:
//                                          | corresponding multiples in
//               digits                     | the recoded representation
// -------------------------------------------------------------------------
// (0): {s_0, s_4,  s_8, ..., s_100, s_104} | { 2^0, 2^20, ..., 2^500, 2^520}
// (1): {s_1, s_5,  s_9, ..., s_101}        | { 2^5, 2^25, ..., 2^505}
// (2): {s_2, s_6, s_10, ..., s_102}        | {2^10, 2^30, ..., 2^510}
// (3): {s_3, s_7, s_11, ..., s_103}        | {2^15, 2^35, ..., 2^515}
//      corresponding sub-table lookup      | {  T0,   T1, ...,   T25,   T26}
//
// The group (0) digits correspond precisely to the multiples of G that are
// held in the 27 precomputed sub-tables, so we may simply read the appropriate
// points from the sub-tables and sum them all up (negating if needed, i.e., if
// a digit s_i is negative, we read the point corresponding to the abs(s_i) and
// negate it before adding it to the sum).
// The remaining three groups (1), (2), and (3), correspond to the multiples
// of G from the sub-tables multiplied additionally by 2^5, 2^10, and 2^15,
// respectively. Therefore, for these groups we may read the appropriate points
// from the table, double them 5, 10, or 15 times, respectively, and add them
// to the final result.
//
// To minimize the number of required doubling operations we process the digits
// of the scalar from left to right. In other words, the algorithm is:
//   1. For group (i) in this order (3, 2, 1, 0):
//   2.   Double the accumulator 5 times except in the first iteration.
//   3.   Read the points corresponding to the group (i) digits from the tables
//        and add them to an accumulator.
//   4. If the scalar is even subtract G from the accumulator.
//
// Note: this function is designed to be constant-time.
void ec_nistp_scalar_mul_base(const ec_nistp_meth *ctx,
                              ec_nistp_felem_limb *x_out,
                              ec_nistp_felem_limb *y_out,
                              ec_nistp_felem_limb *z_out,
                              const EC_SCALAR *scalar) {
  // Regular-wNAF encoding of the scalar.
  int16_t rwnaf[SCALAR_MUL_MAX_NUM_WINDOWS];
  scalar_rwnaf(rwnaf, SCALAR_MUL_WINDOW_SIZE, scalar, ctx->felem_num_bits);
  size_t num_windows = DIV_AND_CEIL(ctx->felem_num_bits, SCALAR_MUL_WINDOW_SIZE);

  // We need two point accumulators, so we define them of maximum size
  // to avoid allocation, and just take pointers to individual coordinates.
  // (This cruft will disappear when we refactor point_add/dbl to work with
  // whole points instead of individual coordinates).
  ec_nistp_felem_limb res[3 * FELEM_MAX_NUM_OF_LIMBS] = {0};
  ec_nistp_felem_limb tmp[3 * FELEM_MAX_NUM_OF_LIMBS] = {0};
  ec_nistp_felem_limb *x_res = &res[0];
  ec_nistp_felem_limb *y_res = &res[ctx->felem_num_limbs];
  ec_nistp_felem_limb *z_res = &res[ctx->felem_num_limbs * 2];
  ec_nistp_felem_limb *x_tmp = &tmp[0];
  ec_nistp_felem_limb *y_tmp = &tmp[ctx->felem_num_limbs];
  ec_nistp_felem_limb *z_tmp = &tmp[ctx->felem_num_limbs * 2];

  // Process the 4 groups of digits starting from group (3) down to group (0).
  for (int i = 3; i >= 0; i--) {
    // Double |res| 5 times in each iteration, except in the first one.
    for (size_t j = 0; i != 3 && j < SCALAR_MUL_WINDOW_SIZE; j++) {
      ctx->point_dbl(x_res, y_res, z_res, x_res, y_res, z_res);
    }

    // Process the digits in the current group from the most to the least
    // significant one.
    size_t start_idx = ((num_windows - i - 1) / 4) * 4 + i;

    for (int j = start_idx; j >= 0; j -= 4) {
      // For each digit |d| in the current group read the corresponding point
      // from the table and add it to |res|. If |d| is negative, negate
      // the point before adding it to |res|.
      int16_t d = rwnaf[j];
      int16_t is_neg = (d >> 15) & 1; // is_neg = (d < 0) ? 1 : 0
      d = (d ^ -is_neg) + is_neg; // d = abs(d)

      int16_t idx = d >> 1;

      // Select the point to add, in constant time.
      size_t point_num_limbs = 2 * ctx->felem_num_limbs; // Affine points.
      size_t subtable_num_limbs = SCALAR_MUL_TABLE_NUM_POINTS * point_num_limbs;
      size_t table_idx = (j / 4) * subtable_num_limbs;
      const ec_nistp_felem_limb *table = &ctx->scalar_mul_base_table[table_idx];
      select_point_from_table(ctx, tmp, table, idx, 0);

      // Negate y coordinate of the point tmp = (x, y); ftmp = -y.
      ec_nistp_felem ftmp;
      ctx->felem_neg(ftmp, y_tmp);

      // Conditionally select y or -y depending on the sign of the digit |d|.
      cmovznz(y_tmp, ctx->felem_num_limbs, is_neg, y_tmp, ftmp);

      // Add the point to the accumulator |res|.
      ctx->point_add(x_res, y_res, z_res, x_res, y_res, z_res, 1,
                     x_tmp, y_tmp, ctx->felem_one);
    }
  }

  // Conditionally subtract G if the scalar is even, in constant-time.
  const ec_nistp_felem_limb *x_mp = &ctx->scalar_mul_base_table[0];
  const ec_nistp_felem_limb *y_mp = &ctx->scalar_mul_base_table[ctx->felem_num_limbs];
  ec_nistp_felem ftmp;
  ctx->felem_neg(ftmp, y_mp);

  // Subtract G from the accumulator: |tmp| = |res| + (-G).
  ctx->point_add(x_tmp, y_tmp, z_tmp, x_res, y_res, z_res, 1, x_mp, ftmp, ctx->felem_one);

  // Select |res| or |res - G| based on parity of the scalar.
  ec_nistp_felem_limb t = scalar->words[0] & 1;
  cmovznz(x_out, ctx->felem_num_limbs, t, x_tmp, x_res);
  cmovznz(y_out, ctx->felem_num_limbs, t, y_tmp, y_res);
  cmovznz(z_out, ctx->felem_num_limbs, t, z_tmp, z_res);
}
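
The grouping of digits and the order of doublings described in the comment above can be checked on plain integers. The following is a minimal sketch, not the library's code: it replaces curve points by integers (so G = 1 and [k]G = k), uses a toy 41-bit scalar, and uses branches, so it is not constant-time. All `toy_*` names and the `TOY_*` constants are hypothetical and exist only for this illustration.

```c
#include <assert.h>
#include <stdint.h>

#define TOY_W 5                                             // window size
#define TOY_BITS 41                                         // toy scalar size
#define TOY_NUM_WINDOWS ((TOY_BITS + TOY_W - 1) / TOY_W)    // 9 digits
#define TOY_NUM_SUBTABLES ((TOY_NUM_WINDOWS + 3) / 4)       // 3 sub-tables
#define TOY_TABLE_POINTS (1 << (TOY_W - 1))                 // 16 odd multiples

// Recode (scalar | 1) into TOY_NUM_WINDOWS odd signed digits such that
// (scalar | 1) = sum of out[i] * 2^(5*i). Assumes scalar < 2^TOY_BITS.
static void toy_rwnaf(int16_t out[TOY_NUM_WINDOWS], uint64_t scalar) {
  int64_t k = (int64_t)(scalar | 1);
  for (int i = 0; i < TOY_NUM_WINDOWS - 1; i++) {
    // Take the low w+1 bits (odd) and shift them into [-(2^w - 1), 2^w - 1].
    int64_t d = (k & ((1 << (TOY_W + 1)) - 1)) - (1 << TOY_W);
    out[i] = (int16_t)d;
    k = (k - d) >> TOY_W;  // k - d is a positive multiple of 2^w
  }
  out[TOY_NUM_WINDOWS - 1] = (int16_t)k;
}

// Integer model of the comb method: "points" are integers, G = 1.
static int64_t toy_scalar_mul_base(uint64_t scalar) {
  // Sub-table t holds [1*2^(20t)]G, [3*2^(20t)]G, ..., [31*2^(20t)]G.
  int64_t table[TOY_NUM_SUBTABLES][TOY_TABLE_POINTS];
  for (int t = 0; t < TOY_NUM_SUBTABLES; t++) {
    for (int p = 0; p < TOY_TABLE_POINTS; p++) {
      table[t][p] = (int64_t)(2 * p + 1) << (20 * t);
    }
  }

  int16_t digits[TOY_NUM_WINDOWS];
  toy_rwnaf(digits, scalar);

  int64_t acc = 0;
  for (int i = 3; i >= 0; i--) {
    if (i != 3) {
      acc *= (1 << TOY_W);  // stands in for 5 point doublings
    }
    int start_idx = ((TOY_NUM_WINDOWS - i - 1) / 4) * 4 + i;
    for (int j = start_idx; j >= 0; j -= 4) {
      int d = digits[j];
      int abs_d = d < 0 ? -d : d;
      int64_t pt = table[j / 4][abs_d >> 1];
      acc += d < 0 ? -pt : pt;  // branches here: NOT constant-time
    }
  }
  // The recoding represents (scalar | 1), so correct for even scalars.
  if ((scalar & 1) == 0) {
    acc -= 1;
  }
  return acc;
}
```

In this model, a digit s_j with j = 4q + i is looked up in sub-table q (a multiple of 2^(20q)) and is then multiplied by 2^(5i) through the shared doublings, which gives the required weight 2^(5j).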
8 changes: 8 additions & 0 deletions crypto/fipsmodule/ec/ec_nistp.h
@@ -54,6 +54,7 @@ typedef struct {
void (*felem_sqr)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
void (*felem_neg)(ec_nistp_felem_limb *c, const ec_nistp_felem_limb *a);
ec_nistp_felem_limb (*felem_nz)(const ec_nistp_felem_limb *a);
const ec_nistp_felem_limb *felem_one;

void (*point_dbl)(ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
@@ -72,6 +73,7 @@ typedef struct {
const ec_nistp_felem_limb *y2,
const ec_nistp_felem_limb *z2);

const ec_nistp_felem_limb *scalar_mul_base_table;
} ec_nistp_meth;

const ec_nistp_meth *p256_methods(void);
@@ -106,5 +108,11 @@ void ec_nistp_scalar_mul(const ec_nistp_meth *ctx,
const ec_nistp_felem_limb *y_in,
const ec_nistp_felem_limb *z_in,
const EC_SCALAR *scalar);

void ec_nistp_scalar_mul_base(const ec_nistp_meth *ctx,
ec_nistp_felem_limb *x_out,
ec_nistp_felem_limb *y_out,
ec_nistp_felem_limb *z_out,
const EC_SCALAR *scalar);
#endif // EC_NISTP_H
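
Because `scalar_mul_base_table` is stored in the method struct as a flat limb pointer, the generic code has to compute sub-table and point offsets by hand. The sketch below illustrates that index arithmetic together with a constant-time scan over one sub-table. It is an illustration under stated assumptions, not the library's `select_point_from_table`: limbs are assumed to be `uint64_t`, each sub-table is assumed to hold 16 affine points stored as x followed by y, and the `toy_*` names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define TOY_TABLE_NUM_POINTS 16  // assumed: 2^(w-1) odd multiples, w = 5

// Offset, in limbs, of sub-table |subtable| inside a flat table of affine
// points, where every point occupies 2 * felem_num_limbs limbs (x then y).
static size_t toy_subtable_offset(size_t subtable, size_t felem_num_limbs) {
  size_t point_num_limbs = 2 * felem_num_limbs;
  return subtable * TOY_TABLE_NUM_POINTS * point_num_limbs;
}

// Copy point number |idx| of |subtable| into |out|. Every entry is read and
// masked, so the memory access pattern does not depend on |idx|.
static void toy_select_point(uint64_t *out, const uint64_t *subtable,
                             size_t idx, size_t felem_num_limbs) {
  size_t point_num_limbs = 2 * felem_num_limbs;
  memset(out, 0, point_num_limbs * sizeof(uint64_t));
  for (size_t i = 0; i < TOY_TABLE_NUM_POINTS; i++) {
    uint64_t diff = (uint64_t)(i ^ idx);
    // mask is all ones when i == idx and zero otherwise.
    uint64_t mask = ((diff | (0 - diff)) >> 63) - 1;
    for (size_t l = 0; l < point_num_limbs; l++) {
      out[l] |= subtable[i * point_num_limbs + l] & mask;
    }
  }
}
```

A caller would pass `&table[toy_subtable_offset(j / 4, felem_num_limbs)]` as the sub-table for digit `j`, mirroring the `table_idx` computation in `ec_nistp_scalar_mul_base`.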

155 changes: 8 additions & 147 deletions crypto/fipsmodule/ec/p384.c
@@ -91,16 +91,6 @@ static void p384_felem_copy(p384_limb_t out[P384_NLIMBS],
}
}

static void p384_felem_cmovznz(p384_limb_t out[P384_NLIMBS],
p384_limb_t t,
const p384_limb_t z[P384_NLIMBS],
const p384_limb_t nz[P384_NLIMBS]) {
p384_limb_t mask = constant_time_is_zero_w(t);
for (size_t i = 0; i < P384_NLIMBS; i++) {
out[i] = constant_time_select_w(mask, z[i], nz[i]);
}
}

static void p384_from_generic(p384_felem out, const EC_FELEM *in) {
#ifdef OPENSSL_BIG_ENDIAN
uint8_t tmp[P384_EC_FELEM_BYTES];
@@ -270,6 +260,8 @@ static void p384_point_add(p384_felem x3, p384_felem y3, p384_felem z3,
ec_nistp_point_add(p384_methods(), x3, y3, z3, x1, y1, z1, mixed, x2, y2, z2);
}

#include "p384_table.h"

#if defined(EC_NISTP_USE_S2N_BIGNUM)
DEFINE_METHOD_FUNCTION(ec_nistp_meth, p384_methods) {
out->felem_num_limbs = P384_NLIMBS;
@@ -280,8 +272,10 @@ DEFINE_METHOD_FUNCTION(ec_nistp_meth, p384_methods) {
out->felem_sqr = bignum_montsqr_p384_selector;
out->felem_neg = bignum_neg_p384;
out->felem_nz = p384_felem_nz;
out->felem_one = p384_felem_one;
out->point_dbl = p384_point_double;
out->point_add = p384_point_add;
out->scalar_mul_base_table = (const ec_nistp_felem_limb*) p384_g_pre_comp;
}
#else
DEFINE_METHOD_FUNCTION(ec_nistp_meth, p384_methods) {
@@ -293,8 +287,10 @@ DEFINE_METHOD_FUNCTION(ec_nistp_meth, p384_methods) {
out->felem_sqr = fiat_p384_square;
out->felem_neg = fiat_p384_opp;
out->felem_nz = p384_felem_nz;
out->felem_one = p384_felem_one;
out->point_dbl = p384_point_double;
out->point_add = p384_point_add;
out->scalar_mul_base_table = (const ec_nistp_felem_limb*) p384_g_pre_comp;
}
#endif

@@ -494,20 +490,6 @@ OPENSSL_STATIC_ASSERT(P384_MUL_WSIZE == 5,
#define P384_MUL_TABLE_SIZE (P384_MUL_TWO_TO_WSIZE >> 1)
#define P384_MUL_PUB_TABLE_SIZE (1 << (P384_MUL_PUB_WSIZE - 1))

// p384_select_point_affine selects the |idx|-th affine point from
// the given precomputed table and copies it to |out| in constant-time.
static void p384_select_point_affine(p384_felem out[2],
size_t idx,
const p384_felem table[][2],
size_t table_size) {
OPENSSL_memset(out, 0, sizeof(p384_felem) * 2);
for (size_t i = 0; i < table_size; i++) {
p384_limb_t mismatch = i ^ idx;
p384_felem_cmovznz(out[0], mismatch, table[i][0], out[0]);
p384_felem_cmovznz(out[1], mismatch, table[i][1], out[1]);
}
}

// Multiplication of an arbitrary point by a scalar, r = [scalar]P.
static void ec_GFp_nistp384_point_mul(const EC_GROUP *group, EC_JACOBIAN *r,
const EC_JACOBIAN *p,
@@ -526,135 +508,14 @@ static void ec_GFp_nistp384_point_mul(const EC_GROUP *group, EC_JACOBIAN *r,
p384_to_generic(&r->Z, res[2]);
}

// Include the precomputed table for the based point scalar multiplication.
#include "p384_table.h"

// Multiplication of the base point G of P-384 curve with the given scalar.
// The product is computed with the Comb method using the precomputed table
// |p384_g_pre_comp| from |p384_table.h| file and the regular-wNAF scalar
// encoding.
//
// The |p384_g_pre_comp| table has 20 sub-tables each holding 16 points:
// 0 : [1]G, [3]G, ..., [31]G
// 1 : [1*2^20]G, [3*2^20]G, ..., [31*2^20]G
// ...
// i : [1*2^20i]G, [3*2^20i]G, ..., [31*2^20i]G
// ...
// 19 : [2^380]G, [3*2^380]G, ..., [31*2^380]G.
// Computing the negation of a point P = (x, y) is relatively easy:
// -P = (x, -y).
// So we may assume that for each sub-table we have 32 points instead of 16:
// [\pm 1*2^20i]G, [\pm 3*2^20i]G, ..., [\pm 31*2^20i]G.
//
// The 384-bit |scalar| is recoded (regular-wNAF encoding) into 77 signed
// digits, each of length 5 bits, as explained in the
// |p384_felem_mul_scalar_rwnaf| function. Namely,
// scalar' = s_0 + s_1*2^5 + s_2*2^10 + ... + s_76*2^380,
// where digits s_i are in [\pm 1, \pm 3, ..., \pm 31]. Note that for an odd
// scalar we have that scalar = scalar', while in the case of an even
// scalar we have that scalar = scalar' - 1.
//
// To compute the required product, [scalar]G, we may do the following.
// Group the recoded digits of the scalar in 4 groups:
// | corresponding multiples in
// digits | the recoded representation
// -------------------------------------------------------------------------
// (0): {s_0, s_4, s_8, ..., s_72, s_76} | { 2^0, 2^20, ..., 2^360, 2^380}
// (1): {s_1, s_5, s_9, ..., s_73} | { 2^5, 2^25, ..., 2^365}
// (2): {s_2, s_6, s_10, ..., s_74} | {2^10, 2^30, ..., 2^370}
// (3): {s_3, s_7, s_11, ..., s_75} | {2^15, 2^35, ..., 2^375}
// corresponding sub-table lookup | { T0, T1, ..., T18, T19}
//
// The group (0) digits correspond precisely to the multiples of G that are
// held in the 20 precomputed sub-tables, so we may simply read the appropriate
// points from the sub-tables and sum them all up (negating if needed, i.e., if
// a digit s_i is negative, we read the point corresponding to the abs(s_i) and
// negate it before adding it to the sum).
// The remaining three groups (1), (2), and (3), correspond to the multiples
// of G from the sub-tables multiplied additionally by 2^5, 2^10, and 2^15,
// respectively. Therefore, for these groups we may read the appropriate points
// from the table, double them 5, 10, or 15 times, respectively, and add them
// to the final result.
//
// To minimize the number of required doubling operations we process the digits
// of the scalar from left to right. In other words, the algorithm is:
// 1. Read the points corresponding to the group (3) digits from the table
// and add them to an accumulator.
// 2. Double the accumulator 5 times.
// 3. Repeat steps 1. and 2. for groups (2) and (1),
// and perform step 1. for group (0).
// 4. If the scalar is even subtract G from the accumulator.
//
// Note: this function is constant-time.
static void ec_GFp_nistp384_point_mul_base(const EC_GROUP *group,
EC_JACOBIAN *r,
const EC_SCALAR *scalar) {
p384_felem res[3] = {{0}, {0}, {0}};

p384_felem res[3] = {{0}, {0}, {0}}, tmp[3] = {{0}, {0}, {0}}, ftmp;
int16_t rnaf[P384_MUL_NWINDOWS] = {0};

// Recode the scalar.
scalar_rwnaf(rnaf, P384_MUL_WSIZE, scalar, 384);

// Process the 4 groups of digits starting from group (3) down to group (0).
for (int i = 3; i >= 0; i--) {
// Double |res| 5 times in each iteration, except in the first one.
for (int j = 0; i != 3 && j < P384_MUL_WSIZE; j++) {
p384_point_double(res[0], res[1], res[2], res[0], res[1], res[2]);
}
ec_nistp_scalar_mul_base(p384_methods(), res[0], res[1], res[2], scalar);

// Process the digits in the current group from the most to the least
// significant one (this is a requirement to ensure that the case of point
// doubling can't happen).
// For group (3) we process digits s_75 to s_3, for group (2) s_74 to s_2,
// group (1) s_73 to s_1, and for group (0) s_76 to s_0.
const size_t start_idx = ((P384_MUL_NWINDOWS - i - 1)/4)*4 + i;

for (int j = start_idx; j >= 0; j -= 4) {
// For each digit |d| in the current group read the corresponding point
// from the table and add it to |res|. If |d| is negative, negate
// the point before adding it to |res|.
int16_t d = rnaf[j];
// is_neg = (d < 0) ? 1 : 0
int16_t is_neg = (d >> 15) & 1;
// d = abs(d)
d = (d ^ -is_neg) + is_neg;

int16_t idx = d >> 1;

// Select the point to add, in constant time.
p384_select_point_affine(tmp, idx, p384_g_pre_comp[j / 4],
P384_MUL_TABLE_SIZE);

// Negate y coordinate of the point tmp = (x, y); ftmp = -y.
p384_felem_opp(ftmp, tmp[1]);
// Conditionally select y or -y depending on the sign of the digit |d|.
p384_felem_cmovznz(tmp[1], is_neg, tmp[1], ftmp);

// Add the point to the accumulator |res|.
// Note that the points in the pre-computed table are given with affine
// coordinates. The point addition function computes a sum of two points,
// either both given in projective, or one in projective and the other one
// in affine coordinates. The |mixed| flag indicates the latter option,
// in which case we set the third coordinate of the second point to one.
p384_point_add(res[0], res[1], res[2], res[0], res[1], res[2],
1 /* mixed */, tmp[0], tmp[1], p384_felem_one);
}
}

// Conditionally subtract G if the scalar is even, in constant-time.
// First, compute |tmp| = |res| + (-G).
p384_felem_copy(tmp[0], p384_g_pre_comp[0][0][0]);
p384_felem_opp(tmp[1], p384_g_pre_comp[0][0][1]);
p384_point_add(tmp[0], tmp[1], tmp[2], res[0], res[1], res[2],
1 /* mixed */, tmp[0], tmp[1], p384_felem_one);

// Select |res| or |tmp| based on the |scalar| parity.
p384_felem_cmovznz(res[0], scalar->words[0] & 1, tmp[0], res[0]);
p384_felem_cmovznz(res[1], scalar->words[0] & 1, tmp[1], res[1]);
p384_felem_cmovznz(res[2], scalar->words[0] & 1, tmp[2], res[2]);

// Copy the result to the output.
p384_to_generic(&r->X, res[0]);
p384_to_generic(&r->Y, res[1]);
p384_to_generic(&r->Z, res[2]);
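
The per-digit handling that moved from p384.c into the generic function (sign extraction, absolute value, table index, and conditional move) can be exercised in isolation. The following is a small sketch under assumptions: limbs are taken to be `uint64_t`, and `toy_split_digit` and `toy_cmovznz` are hypothetical stand-ins written for this example, not the library's helpers. The conditional move follows the convention visible in the diff: the first input is selected when `t` is zero, the second otherwise.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Split an odd signed digit |d| into its sign and its table index without
// branching on the digit. For odd |d|, the index is (abs(d) - 1) / 2.
static void toy_split_digit(int16_t d, int16_t *is_neg, int16_t *idx) {
  int16_t neg = (d >> 15) & 1;                   // 1 if d < 0, else 0
  int16_t abs_d = (int16_t)((d ^ -neg) + neg);   // two's complement negate
  *is_neg = neg;
  *idx = abs_d >> 1;
}

// out = (t == 0) ? z : nz, limb by limb, using a mask instead of a branch.
static void toy_cmovznz(uint64_t *out, size_t num_limbs, uint64_t t,
                        const uint64_t *z, const uint64_t *nz) {
  uint64_t mask = ((t | (0 - t)) >> 63) - 1;  // all ones iff t == 0
  for (size_t i = 0; i < num_limbs; i++) {
    out[i] = (z[i] & mask) | (nz[i] & ~mask);
  }
}
```

With these two pieces, the sign of a digit decides between y and -y via `toy_cmovznz(y, n, is_neg, y, minus_y)`, which is the same shape as the `cmovznz` call in the unified function.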