Skip to content

Commit

Permalink
arithmetic: Remove PartialEq & Debug for LimbMask.
Browse files Browse the repository at this point in the history
Add `LimbMask::leak()` and change all callers to use it. This
proactively prevents accidental leakage of the `LimbMask` value
and makes it easier to audit the code for places where we
intentionally leak the value of a `LimbMask`.

Within the tests, use a `#[cfg(test)]-only wrapper `leak_in_test` to
make it easier to see that those leaks are uninteresting.
  • Loading branch information
briansmith committed Dec 7, 2024
1 parent cb6d5de commit df14ce9
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 29 deletions.
8 changes: 4 additions & 4 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use crate::{
arithmetic::montgomery::*,
bits::BitLength,
c, error,
limb::{self, Limb, LimbMask, LIMB_BITS},
limb::{self, Limb, LIMB_BITS},
};
use alloc::vec;
use core::{marker::PhantomData, num::NonZeroU64};
Expand Down Expand Up @@ -85,7 +85,7 @@ impl<M, E> Clone for Elem<M, E> {
impl<M, E> Elem<M, E> {
#[inline]
pub fn is_zero(&self) -> bool {
limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
limb::limbs_are_zero_constant_time(&self.limbs).leak()
}
}

Expand Down Expand Up @@ -132,7 +132,7 @@ impl<M> Elem<M, Unencoded> {
}

fn is_one(&self) -> bool {
limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
limb::limbs_equal_limb_constant_time(&self.limbs, 1).leak()
}
}

Expand Down Expand Up @@ -696,7 +696,7 @@ pub fn elem_verify_equal_consttime<M, E>(
a: &Elem<M, E>,
b: &Elem<M, E>,
) -> Result<(), error::Unspecified> {
if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs).leak() {
Ok(())
} else {
Err(error::Unspecified)
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/bigint/boxed_limbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use super::Modulus;
use crate::{
error,
limb::{self, Limb, LimbMask, LIMB_BYTES},
limb::{self, Limb, LIMB_BYTES},
};
use alloc::{boxed::Box, vec};
use core::{
Expand Down Expand Up @@ -88,7 +88,7 @@ impl<M> BoxedLimbs<M> {
) -> Result<Self, error::Unspecified> {
let mut r = Self::zero(m.limbs().len());
limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
if limb::limbs_less_than_limbs_consttime(&r, m.limbs()) != LimbMask::True {
if !limb::limbs_less_than_limbs_consttime(&r, m.limbs()).leak() {
return Err(error::Unspecified);
}
Ok(r)
Expand Down
8 changes: 4 additions & 4 deletions src/arithmetic/bigint/modulusvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::{
use crate::{
bits::BitLength,
error,
limb::{self, Limb, LimbMask},
limb::{self, Limb},
};

/// `OwnedModulus`, without the overhead of Montgomery multiplication support.
Expand Down Expand Up @@ -47,10 +47,10 @@ impl<M> OwnedModulusValue<M> {
if n.len() < MODULUS_MIN_LIMBS {
return Err(error::KeyRejected::unexpected_error());
}
if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
if limb::limbs_are_even_constant_time(&n).leak() {
return Err(error::KeyRejected::invalid_component());
}
if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
if limb::limbs_less_than_limb_constant_time(&n, 3).leak() {
return Err(error::KeyRejected::unexpected_error());
}

Expand All @@ -62,7 +62,7 @@ impl<M> OwnedModulusValue<M> {
pub fn verify_less_than<L>(&self, l: &Modulus<L>) -> Result<(), error::Unspecified> {
if self.len_bits() > l.len_bits()
|| (self.limbs.len() == l.limbs().len()
&& limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()) != LimbMask::True)
&& !limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()).leak())
{
return Err(error::Unspecified);
}
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/bigint/private_exponent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{limb, BoxedLimbs, Limb, LimbMask, Modulus};
use super::{limb, BoxedLimbs, Limb, Modulus};
use crate::error;
use alloc::boxed::Box;

Expand All @@ -36,7 +36,7 @@ impl PrivateExponent {
// `p - 1` and so we know `dP < p - 1`.
//
// Further we know `dP != 0` because `dP` is not even.
if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
if limb::limbs_are_even_constant_time(&dP).leak() {
return Err(error::Unspecified);
}

Expand Down
4 changes: 2 additions & 2 deletions src/ec/suite_b.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! Elliptic curve operations on P-256 & P-384.
use self::ops::*;
use crate::{arithmetic::montgomery::*, cpu, ec, error, io::der, limb::LimbMask, pkcs8};
use crate::{arithmetic::montgomery::*, cpu, ec, error, io::der, pkcs8};

// NIST SP 800-56A Step 3: "If q is an odd prime p, verify that
// yQ**2 = xQ**3 + axQ + b in GF(p), where the arithmetic is performed modulo
Expand Down Expand Up @@ -146,7 +146,7 @@ fn verify_affine_point_is_on_the_curve_scaled(
ops.elem_mul(&mut rhs, x);
ops.elem_add(&mut rhs, b_scaled);

if ops.elems_are_equal(&lhs, &rhs) != LimbMask::True {
if !ops.elems_are_equal(&lhs, &rhs).leak() {
return Err(error::Unspecified);
}

Expand Down
2 changes: 1 addition & 1 deletion src/ec/suite_b/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl CommonOps {

#[inline]
pub fn is_zero<M, E: Encoding>(&self, a: &elem::Elem<M, E>) -> bool {
limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]) == LimbMask::True
limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]).leak()
}

pub fn elem_verify_is_not_zero(&self, a: &Elem<R>) -> Result<(), error::Unspecified> {
Expand Down
39 changes: 25 additions & 14 deletions src/limb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,19 @@ pub const LIMB_BITS: usize = usize_from_u32(Limb::BITS);

#[cfg_attr(target_pointer_width = "64", repr(u64))]
#[cfg_attr(target_pointer_width = "32", repr(u32))]
#[derive(Debug, PartialEq)]
pub enum LimbMask {
#[cfg_attr(not(test), allow(dead_code))] // Only constructed by non-Rust & test code.
True = Limb::MAX,
#[cfg_attr(not(test), allow(dead_code))] // Only constructed by non-Rust & test code.
False = 0,
}

impl LimbMask {
pub fn leak(self) -> bool {
!matches!(self, LimbMask::False)
}
}

pub const LIMB_BYTES: usize = (LIMB_BITS + 7) / 8;

#[inline]
Expand All @@ -58,7 +65,7 @@ pub fn limbs_less_than_limbs_consttime(a: &[Limb], b: &[Limb]) -> LimbMask {

#[inline]
pub fn limbs_less_than_limbs_vartime(a: &[Limb], b: &[Limb]) -> bool {
limbs_less_than_limbs_consttime(a, b) == LimbMask::True
limbs_less_than_limbs_consttime(a, b).leak()
}

#[inline]
Expand Down Expand Up @@ -142,11 +149,11 @@ pub fn parse_big_endian_in_range_and_pad_consttime(
result: &mut [Limb],
) -> Result<(), error::Unspecified> {
parse_big_endian_and_pad_consttime(input, result)?;
if limbs_less_than_limbs_consttime(result, max_exclusive) != LimbMask::True {
if !limbs_less_than_limbs_consttime(result, max_exclusive).leak() {
return Err(error::Unspecified);
}
if allow_zero != AllowZero::Yes {
if limbs_are_zero_constant_time(result) != LimbMask::False {
if limbs_are_zero_constant_time(result).leak() {
return Err(error::Unspecified);
}
}
Expand Down Expand Up @@ -362,6 +369,10 @@ mod tests {

const MAX: Limb = Limb::MAX;

fn leak_in_test(a: LimbMask) -> bool {
a.leak()
}

#[test]
fn test_limbs_are_even() {
static EVENS: &[&[Limb]] = &[
Expand All @@ -376,7 +387,7 @@ mod tests {
&[0, 0, 0, 0, MAX],
];
for even in EVENS {
assert_eq!(limbs_are_even_constant_time(even), LimbMask::True);
assert!(leak_in_test(limbs_are_even_constant_time(even)));
}
static ODDS: &[&[Limb]] = &[
&[1],
Expand All @@ -389,7 +400,7 @@ mod tests {
&[1, 0, 0, 0, MAX],
];
for odd in ODDS {
assert_eq!(limbs_are_even_constant_time(odd), LimbMask::False);
assert!(!leak_in_test(limbs_are_even_constant_time(odd)));
}
}

Expand Down Expand Up @@ -418,20 +429,20 @@ mod tests {
#[test]
fn test_limbs_are_zero() {
for zero in ZEROES {
assert_eq!(limbs_are_zero_constant_time(zero), LimbMask::True);
assert!(leak_in_test(limbs_are_zero_constant_time(zero)));
}
for nonzero in NONZEROES {
assert_eq!(limbs_are_zero_constant_time(nonzero), LimbMask::False);
assert!(!leak_in_test(limbs_are_zero_constant_time(nonzero)));
}
}

#[test]
fn test_limbs_equal_limb() {
for zero in ZEROES {
assert_eq!(limbs_equal_limb_constant_time(zero, 0), LimbMask::True);
assert!(leak_in_test(limbs_equal_limb_constant_time(zero, 0)));
}
for nonzero in NONZEROES {
assert_eq!(limbs_equal_limb_constant_time(nonzero, 0), LimbMask::False);
assert!(!leak_in_test(limbs_equal_limb_constant_time(nonzero, 0)));
}
static EQUAL: &[(&[Limb], Limb)] = &[
(&[1], 1),
Expand All @@ -442,7 +453,7 @@ mod tests {
(&[0b100, 0], 0b100),
];
for &(a, b) in EQUAL {
assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::True);
assert!(leak_in_test(limbs_equal_limb_constant_time(a, b)));
}
static UNEQUAL: &[(&[Limb], Limb)] = &[
(&[0], 1),
Expand All @@ -456,7 +467,7 @@ mod tests {
(&[MAX, 1], MAX),
];
for &(a, b) in UNEQUAL {
assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::False);
assert!(!leak_in_test(limbs_equal_limb_constant_time(a, b)));
}
}

Expand All @@ -473,7 +484,7 @@ mod tests {
(&[MAX - 1, 0], MAX),
];
for &(a, b) in LESSER {
assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::True);
assert!(leak_in_test(limbs_less_than_limb_constant_time(a, b)));
}
static EQUAL: &[(&[Limb], Limb)] = &[
(&[0], 0),
Expand All @@ -492,7 +503,7 @@ mod tests {
(&[MAX], MAX - 1),
];
for &(a, b) in EQUAL.iter().chain(GREATER.iter()) {
assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::False);
assert!(!leak_in_test(limbs_less_than_limb_constant_time(a, b)));
}
}

Expand Down

0 comments on commit df14ce9

Please sign in to comment.