Skip to content

Commit

Permalink
ec: Use LeakyLimb for public values.
Browse files Browse the repository at this point in the history
Take a step toward making `Word` and `Limb` opaque types.

This adds some unnecessary copies but the overhead is
negligible as those copies are outside of loops.
  • Loading branch information
briansmith committed Dec 7, 2024
1 parent 5f3dbbf commit 4f915bc
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 52 deletions.
17 changes: 9 additions & 8 deletions mk/generate_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
p: limbs_from_hex("%(q)x"),
rr: limbs_from_hex(%(q_rr)s),
},
n: Elem::from_hex("%(n)x"),
n: PublicElem::from_hex("%(n)x"),
a: Elem::from_hex(%(a)s),
b: Elem::from_hex(%(b)s),
a: PublicElem::from_hex(%(a)s),
b: PublicElem::from_hex(%(b)s),
elem_mul_mont: p%(bits)s_elem_mul_mont,
elem_sqr_mont: p%(bits)s_elem_sqr_mont,
Expand All @@ -56,8 +56,8 @@
};
pub(super) static GENERATOR: (Elem<R>, Elem<R>) = (
Elem::from_hex(%(Gx)s),
Elem::from_hex(%(Gy)s),
PublicElem::from_hex(%(Gx)s),
PublicElem::from_hex(%(Gy)s),
);
pub static PRIVATE_KEY_OPS: PrivateKeyOps = PrivateKeyOps {
Expand Down Expand Up @@ -93,7 +93,8 @@
fn p%(bits)s_point_mul_base_impl(a: &Scalar) -> Point {
// XXX: Not efficient. TODO: Precompute multiples of the generator.
PRIVATE_KEY_OPS.point_mul(a, &GENERATOR)
let generator = (Elem::from(&GENERATOR.0), Elem::from(&GENERATOR.1));
PRIVATE_KEY_OPS.point_mul(a, &generator)
}
pub static PUBLIC_KEY_OPS: PublicKeyOps = PublicKeyOps {
Expand All @@ -112,7 +113,7 @@
twin_mul_inefficient(&PRIVATE_KEY_OPS, g_scalar, p_scalar, p_xy, cpu)
},
q_minus_n: Elem::from_hex("%(q_minus_n)x"),
q_minus_n: PublicElem::from_hex("%(q_minus_n)x"),
// TODO: Use an optimized variable-time implementation.
scalar_inv_to_mont_vartime: |s| PRIVATE_SCALAR_OPS.scalar_inv_to_mont(s),
Expand All @@ -121,7 +122,7 @@
pub static PRIVATE_SCALAR_OPS: PrivateScalarOps = PrivateScalarOps {
scalar_ops: &SCALAR_OPS,
oneRR_mod_n: Scalar::from_hex(%(oneRR_mod_n)s),
oneRR_mod_n: PublicScalar::from_hex(%(oneRR_mod_n)s),
scalar_inv_to_mont: p%(bits)s_scalar_inv_to_mont,
};
Expand Down
8 changes: 4 additions & 4 deletions src/arithmetic/constant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::limb::Limb;
use crate::limb::LeakyLimb;
use core::mem::size_of;

const fn parse_digit(d: u8) -> u8 {
Expand All @@ -10,16 +10,16 @@ const fn parse_digit(d: u8) -> u8 {
}

// TODO: this would be nicer as a trait, but currently traits don't support const functions
pub const fn limbs_from_hex<const LIMBS: usize>(hex: &str) -> [Limb; LIMBS] {
pub const fn limbs_from_hex<const LIMBS: usize>(hex: &str) -> [LeakyLimb; LIMBS] {

Check warning on line 13 in src/arithmetic/constant.rs

View check run for this annotation

Codecov / codecov/patch

src/arithmetic/constant.rs#L13

Added line #L13 was not covered by tests
let hex = hex.as_bytes();
let mut limbs = [0; LIMBS];
let limb_nibbles = size_of::<Limb>() * 2;
let limb_nibbles = size_of::<LeakyLimb>() * 2;

Check warning on line 16 in src/arithmetic/constant.rs

View check run for this annotation

Codecov / codecov/patch

src/arithmetic/constant.rs#L16

Added line #L16 was not covered by tests
let mut i = 0;

while i < hex.len() {
let char = hex[hex.len() - 1 - i];
let val = parse_digit(char);
limbs[i / limb_nibbles] |= (val as Limb) << ((i % limb_nibbles) * 4);
limbs[i / limb_nibbles] |= (val as LeakyLimb) << ((i % limb_nibbles) * 4);

Check warning on line 22 in src/arithmetic/constant.rs

View check run for this annotation

Codecov / codecov/patch

src/arithmetic/constant.rs#L22

Added line #L22 was not covered by tests
i += 1;
}

Expand Down
11 changes: 8 additions & 3 deletions src/ec/suite_b.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ fn verify_affine_point_is_on_the_curve(
ops: &CommonOps,
(x, y): (&Elem<R>, &Elem<R>),
) -> Result<(), error::Unspecified> {
verify_affine_point_is_on_the_curve_scaled(ops, (x, y), &ops.a, &ops.b)
verify_affine_point_is_on_the_curve_scaled(
ops,
(x, y),
&Elem::from(&ops.a),
&Elem::from(&ops.b),
)
}

// Use `verify_affine_point_is_on_the_curve` instead of this function whenever
Expand Down Expand Up @@ -101,9 +106,9 @@ fn verify_jacobian_point_is_on_the_curve(
//
let z2 = ops.elem_squared(&z);
let z4 = ops.elem_squared(&z2);
let z4_a = ops.elem_product(&z4, &ops.a);
let z4_a = ops.elem_product(&z4, &Elem::from(&ops.a));
let z6 = ops.elem_product(&z4, &z2);
let z6_b = ops.elem_product(&z6, &ops.b);
let z6_b = ops.elem_product(&z6, &Elem::from(&ops.b));
verify_affine_point_is_on_the_curve_scaled(ops, (&x, &y), &z4_a, &z6_b)?;
Ok(z2)
}
Expand Down
3 changes: 2 additions & 1 deletion src/ec/suite_b/ecdsa/verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ impl EcdsaVerificationAlgorithm {
return Ok(());
}
if self.ops.elem_less_than(&r, &self.ops.q_minus_n) {
self.ops.scalar_ops.common.elem_add(&mut r, self.ops.n());
let n = Elem::from(self.ops.n());
self.ops.scalar_ops.common.elem_add(&mut r, &n);
if sig_r_equals_x(self.ops, &r, &x, &z2) {
return Ok(());
}
Expand Down
43 changes: 27 additions & 16 deletions src/ec/suite_b/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub use self::elem::*;
/// A field element, i.e. an element of ℤ/qℤ for the curve's field modulus
/// *q*.
pub type Elem<E> = elem::Elem<Q, E>;
type PublicElem<E> = elem::PublicElem<Q, E>;

/// Represents the (prime) order *q* of the curve's prime field.
#[derive(Clone, Copy)]
Expand All @@ -31,6 +32,7 @@ pub enum Q {}
/// A scalar. Its value is in [0, n). Zero-valued scalars are forbidden in most
/// contexts.
pub type Scalar<E = Unencoded> = elem::Elem<N, E>;
type PublicScalar<E> = elem::PublicElem<N, E>;

/// Represents the prime order *n* of the curve's group.
#[derive(Clone, Copy)]
Expand All @@ -57,10 +59,10 @@ impl Point {
pub struct CommonOps {
num_limbs: usize,
q: Modulus,
n: Elem<Unencoded>,
n: PublicElem<Unencoded>,

pub a: Elem<R>, // Must be -3 mod q
pub b: Elem<R>,
pub a: PublicElem<R>, // Must be -3 mod q
pub b: PublicElem<R>,

// In all cases, `r`, `a`, and `b` may all alias each other.
elem_mul_mont: unsafe extern "C" fn(r: *mut Limb, a: *const Limb, b: *const Limb),
Expand Down Expand Up @@ -98,8 +100,7 @@ impl CommonOps {

#[inline]
pub fn elem_unencoded(&self, a: &Elem<R>) -> Elem<Unencoded> {
const ONE: Elem<Unencoded> = Elem::from_hex("1");
self.elem_product(a, &ONE)
self.elem_product(a, &Elem::one())
}

#[inline]
Expand Down Expand Up @@ -171,8 +172,8 @@ impl CommonOps {
}

struct Modulus {
p: [Limb; MAX_LIMBS],
rr: [Limb; MAX_LIMBS],
p: [LeakyLimb; MAX_LIMBS],
rr: [LeakyLimb; MAX_LIMBS],
}

/// Operations on private keys, for ECDH and ECDSA signing.
Expand Down Expand Up @@ -301,11 +302,11 @@ pub struct PublicScalarOps {
cpu: cpu::Features,
) -> Point,
scalar_inv_to_mont_vartime: fn(s: &Scalar<Unencoded>, cpu: cpu::Features) -> Scalar<R>,
pub(super) q_minus_n: Elem<Unencoded>,
pub(super) q_minus_n: PublicElem<Unencoded>,
}

impl PublicScalarOps {
pub fn n(&self) -> &Elem<Unencoded> {
pub fn n(&self) -> &PublicElem<Unencoded> {
&self.scalar_ops.common.n
}

Expand All @@ -323,7 +324,7 @@ impl PublicScalarOps {
== b.limbs[..self.public_key_ops.common.num_limbs]
}

pub fn elem_less_than(&self, a: &Elem<Unencoded>, b: &Elem<Unencoded>) -> bool {
pub fn elem_less_than(&self, a: &Elem<Unencoded>, b: &PublicElem<Unencoded>) -> bool {
let num_limbs = self.public_key_ops.common.num_limbs;
limbs_less_than_limbs_vartime(&a.limbs[..num_limbs], &b.limbs[..num_limbs])
}
Expand All @@ -341,13 +342,14 @@ impl PublicScalarOps {
pub struct PrivateScalarOps {
pub scalar_ops: &'static ScalarOps,

oneRR_mod_n: Scalar<RR>, // 1 * R**2 (mod n). TOOD: Use One<RR>.
oneRR_mod_n: PublicScalar<RR>, // 1 * R**2 (mod n). TOOD: Use One<RR>.
scalar_inv_to_mont: fn(a: Scalar<R>, cpu: cpu::Features) -> Scalar<R>,
}

impl PrivateScalarOps {
pub(super) fn to_mont(&self, s: &Scalar<Unencoded>, cpu: cpu::Features) -> Scalar<R> {
self.scalar_ops.scalar_product(s, &self.oneRR_mod_n, cpu)
self.scalar_ops
.scalar_product(s, &Scalar::from(&self.oneRR_mod_n), cpu)
}

/// Returns the modular inverse of `a` (mod `n`). Panics if `a` is zero.
Expand Down Expand Up @@ -509,8 +511,8 @@ mod tests {

fn q_minus_n_plus_n_equals_0_test(ops: &PublicScalarOps) {
let cops = ops.scalar_ops.common;
let mut x = ops.q_minus_n;
cops.elem_add(&mut x, &cops.n);
let mut x = Elem::from(&ops.q_minus_n);
cops.elem_add(&mut x, &Elem::from(&cops.n));
assert!(cops.is_zero(&x));
}

Expand Down Expand Up @@ -958,19 +960,28 @@ mod tests {
/// TODO: We should be testing `point_mul` with points other than the generator.
#[test]
fn p256_point_mul_test() {
let generator = (
Elem::from(&p256::GENERATOR.0),
Elem::from(&p256::GENERATOR.1),
);
point_mul_base_tests(
&p256::PRIVATE_KEY_OPS,
|s, cpu| p256::PRIVATE_KEY_OPS.point_mul(s, &p256::GENERATOR, cpu),
|s, cpu| p256::PRIVATE_KEY_OPS.point_mul(s, &generator, cpu),
test_file!("ops/p256_point_mul_base_tests.txt"),
);
}

/// TODO: We should be testing `point_mul` with points other than the generator.
#[test]
fn p384_point_mul_test() {
let generator = (
Elem::from(&p384::GENERATOR.0),
Elem::from(&p384::GENERATOR.1),
);

point_mul_base_tests(
&p384::PRIVATE_KEY_OPS,
|s, cpu| p384::PRIVATE_KEY_OPS.point_mul(s, &p384::GENERATOR, cpu),
|s, cpu| p384::PRIVATE_KEY_OPS.point_mul(s, &generator, cpu),
test_file!("ops/p384_point_mul_base_tests.txt"),
);
}
Expand Down
32 changes: 29 additions & 3 deletions src/ec/suite_b/ops/elem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
use crate::{
arithmetic::{
limbs_from_hex,
montgomery::{Encoding, ProductEncoding},
montgomery::{Encoding, ProductEncoding, Unencoded},
},
limb::{Limb, LIMB_BITS},
limb::{LeakyLimb, Limb, LIMB_BITS},
};
use core::marker::PhantomData;

Expand All @@ -36,6 +36,22 @@ pub struct Elem<M, E: Encoding> {
pub(super) encoding: PhantomData<E>,
}

pub struct PublicElem<M, E: Encoding> {
pub(super) limbs: [LeakyLimb; MAX_LIMBS],
pub(super) m: PhantomData<M>,
pub(super) encoding: PhantomData<E>,
}

impl<M, E: Encoding> From<&PublicElem<M, E>> for Elem<M, E> {
fn from(value: &PublicElem<M, E>) -> Self {
Self {
limbs: core::array::from_fn(|i| Limb::from(value.limbs[i])),
m: value.m,
encoding: value.encoding,
}
}
}

impl<M, E: Encoding> Elem<M, E> {
// There's no need to convert `value` to the Montgomery domain since
// 0 * R**2 (mod m) == 0, so neither the modulus nor the encoding are needed
Expand All @@ -47,9 +63,19 @@ impl<M, E: Encoding> Elem<M, E> {
encoding: PhantomData,
}
}
}

impl<M> Elem<M, Unencoded> {
pub fn one() -> Self {
let mut r = Self::zero();
r.limbs[0] = 1;
r
}
}

impl<M, E: Encoding> PublicElem<M, E> {
pub const fn from_hex(hex: &str) -> Self {
Elem {
Self {

Check warning on line 78 in src/ec/suite_b/ops/elem.rs

View check run for this annotation

Codecov / codecov/patch

src/ec/suite_b/ops/elem.rs#L78

Added line #L78 was not covered by tests
limbs: limbs_from_hex(hex),
m: PhantomData,
encoding: PhantomData,
Expand Down
16 changes: 8 additions & 8 deletions src/ec/suite_b/ops/p256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ pub static COMMON_OPS: CommonOps = CommonOps {
p: limbs_from_hex("ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"),
rr: limbs_from_hex("4fffffffdfffffffffffffffefffffffbffffffff0000000000000003"),
},
n: Elem::from_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"),
n: PublicElem::from_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"),

a: Elem::from_hex("fffffffc00000004000000000000000000000003fffffffffffffffffffffffc"),
b: Elem::from_hex("dc30061d04874834e5a220abf7212ed6acf005cd78843090d89cdf6229c4bddf"),
a: PublicElem::from_hex("fffffffc00000004000000000000000000000003fffffffffffffffffffffffc"),
b: PublicElem::from_hex("dc30061d04874834e5a220abf7212ed6acf005cd78843090d89cdf6229c4bddf"),

elem_mul_mont: p256_mul_mont,
elem_sqr_mont: p256_sqr_mont,
Expand All @@ -36,9 +36,9 @@ pub static COMMON_OPS: CommonOps = CommonOps {
};

#[cfg(test)]
pub(super) static GENERATOR: (Elem<R>, Elem<R>) = (
Elem::from_hex("18905f76a53755c679fb732b7762251075ba95fc5fedb60179e730d418a9143c"),
Elem::from_hex("8571ff1825885d85d2e88688dd21f3258b4ab8e4ba19e45cddf25357ce95560a"),
pub(super) static GENERATOR: (PublicElem<R>, PublicElem<R>) = (
PublicElem::from_hex("18905f76a53755c679fb732b7762251075ba95fc5fedb60179e730d418a9143c"),
PublicElem::from_hex("8571ff1825885d85d2e88688dd21f3258b4ab8e4ba19e45cddf25357ce95560a"),
);

pub static PRIVATE_KEY_OPS: PrivateKeyOps = PrivateKeyOps {
Expand Down Expand Up @@ -129,7 +129,7 @@ pub static PUBLIC_SCALAR_OPS: PublicScalarOps = PublicScalarOps {
twin_mul_inefficient(&PRIVATE_KEY_OPS, g_scalar, p_scalar, p_xy, cpu)
},

q_minus_n: Elem::from_hex("4319055358e8617b0c46353d039cdaae"),
q_minus_n: PublicElem::from_hex("4319055358e8617b0c46353d039cdaae"),

// TODO: Use an optimized variable-time implementation.
scalar_inv_to_mont_vartime: |s, cpu| PRIVATE_SCALAR_OPS.scalar_inv_to_mont(s, cpu),
Expand Down Expand Up @@ -164,7 +164,7 @@ fn point_mul_base_vartime(g_scalar: &Scalar, _cpu: cpu::Features) -> Point {
pub static PRIVATE_SCALAR_OPS: PrivateScalarOps = PrivateScalarOps {
scalar_ops: &SCALAR_OPS,

oneRR_mod_n: Scalar::from_hex(
oneRR_mod_n: PublicScalar::from_hex(
"66e12d94f3d956202845b2392b6bec594699799c49bd6fa683244c95be79eea2",
),
scalar_inv_to_mont: p256_scalar_inv_to_mont,
Expand Down
Loading

0 comments on commit 4f915bc

Please sign in to comment.