Skip to content

Commit

Permalink
arithmetic: add unit tests (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoratger authored Nov 19, 2024
1 parent 3af3fc1 commit 2ef9d27
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 47 deletions.
265 changes: 221 additions & 44 deletions ff/src/biginteger/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn adc_for_add_with_carry(a: &mut u64, b: u64, carry: u8) -> u8 {
/// Calculate a + b + carry, returning the sum
#[inline(always)]
#[doc(hidden)]
pub fn adc_no_carry(a: u64, b: u64, carry: &mut u64) -> u64 {
pub fn adc_no_carry(a: u64, b: u64, carry: &u64) -> u64 {
let tmp = a as u128 + b as u128 + *carry as u128;
tmp as u64
}
Expand Down Expand Up @@ -230,51 +230,228 @@ pub fn find_relaxed_naf(num: &[u64]) -> Vec<i8> {
res
}

#[test]
fn test_find_relaxed_naf_usefulness() {
let vec = find_naf(&[12u64]);
assert_eq!(vec.len(), 5);
#[cfg(test)]
mod tests {
use super::*;

let vec = find_relaxed_naf(&[12u64]);
assert_eq!(vec.len(), 4);
}
#[test]
fn test_adc() {
// Test addition without initial carry
let mut a = 5u64;
let carry = adc(&mut a, 10u64, 0);
assert_eq!(a, 15); // 5 + 10 = 15
assert_eq!(carry, 0); // No carry should be generated

#[test]
fn test_find_relaxed_naf_correctness() {
use ark_std::{One, UniformRand, Zero};
use num_bigint::BigInt;

let mut rng = ark_std::test_rng();

for _ in 0..10 {
let num = [
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
];
let relaxed_naf = find_relaxed_naf(&num);

let test = {
let mut sum = BigInt::zero();
let mut cur = BigInt::one();
for v in relaxed_naf {
sum += cur.clone() * v;
cur *= 2;
}
sum
};

let test_expected = {
let mut sum = BigInt::zero();
let mut cur = BigInt::one();
for v in num.iter() {
sum += cur.clone() * v;
cur <<= 64;
}
sum
};
// Test addition with carry when overflowing u64
let mut a = u64::MAX;
let carry = adc(&mut a, 1u64, 0);
assert_eq!(a, 0); // Overflow resets `a` to 0
assert_eq!(carry, 1); // Carry is 1 due to overflow

// Test addition with a non-zero initial carry
let mut a = 5u64;
let carry = adc(&mut a, 10u64, 1);
assert_eq!(a, 16); // 5 + 10 + 1 = 16
assert_eq!(carry, 0); // No overflow, so carry remains 0

// Test addition with carry and a large sum
let mut a = u64::MAX - 5;
let carry = adc(&mut a, 10u64, 1);
assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5
assert_eq!(carry, 1); // Carry is 1 due to overflow
}

#[test]
fn test_adc_for_add_with_carry() {
// Test addition without initial carry
let mut a = 5u64;
let carry = adc_for_add_with_carry(&mut a, 10u64, 0);
assert_eq!(a, 15); // Expect a to be 15
assert_eq!(carry, 0); // No carry should be generated

// Test addition with carry when overflowing u64
let mut a = u64::MAX;
let carry = adc_for_add_with_carry(&mut a, 1u64, 0);
assert_eq!(a, 0); // Overflow resets `a` to 0
assert_eq!(carry, 1); // Carry is 1 due to overflow

// Test addition with a non-zero initial carry
let mut a = 5u64;
let carry = adc_for_add_with_carry(&mut a, 10u64, 1);
assert_eq!(a, 16); // 5 + 10 + 1 = 16
assert_eq!(carry, 0); // No overflow, so carry remains 0

// Test addition with carry and a large sum
let mut a = u64::MAX - 5;
let carry = adc_for_add_with_carry(&mut a, 10u64, 1);
assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5
assert_eq!(carry, 1); // Carry is 1 due to overflow
}

#[test]
fn test_adc_no_carry() {
// Test addition without initial carry
let mut carry = 0;
let result = adc_no_carry(5u64, 10u64, &mut carry);
assert_eq!(result, 15); // 5 + 10 = 15
assert_eq!(carry, 0); // No carry should be generated

// Test addition with a non-zero initial carry
let mut carry = 1;
let result = adc_no_carry(5u64, 10u64, &mut carry);
assert_eq!(result, 16); // 5 + 10 + 1 = 16
assert_eq!(carry, 1); // No overflow, so carry remains 1

// Test addition that causes a carry
let mut carry = 1;
let result = adc_no_carry(u64::MAX, 1u64, &mut carry);
assert_eq!(result, 1); // u64::MAX + 1 + 1 -> 1
assert_eq!(carry, 1); // Carry is 1 due to overflow
}

#[test]
fn test_sbb() {
// Test subtraction without initial borrow
let mut a = 15u64;
let borrow = sbb(&mut a, 5u64, 0);
assert_eq!(a, 10); // 15 - 5 = 10
assert_eq!(borrow, 0); // No borrow should be generated

// Test subtraction that causes a borrow
let mut a = 5u64;
let borrow = sbb(&mut a, 10u64, 0);
assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around
assert_eq!(borrow, 1); // Borrow should be 1

// Test subtraction with a non-zero initial borrow
let mut a = 15u64;
let borrow = sbb(&mut a, 5u64, 1);
assert_eq!(a, 9); // 15 - 5 - 1 = 9
assert_eq!(borrow, 0); // No borrow should be generated

// Test subtraction with borrow and a large value
let mut a = 0u64;
let borrow = sbb(&mut a, u64::MAX, 1);
assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0
assert_eq!(borrow, 1); // Borrow should be 1
}

#[test]
fn test_sbb_for_sub_with_borrow() {
// Test subtraction without initial borrow
let mut a = 15u64;
let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 0);
assert_eq!(a, 10); // Expect a to be 10
assert_eq!(borrow, 0); // No borrow should be generated

// Test subtraction that causes a borrow
let mut a = 5u64;
let borrow = sbb_for_sub_with_borrow(&mut a, 10u64, 0);
assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around
assert_eq!(borrow, 1); // Borrow should be 1

// Test subtraction with a non-zero initial borrow
let mut a = 15u64;
let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 1);
assert_eq!(a, 9); // 15 - 5 - 1 = 9
assert_eq!(borrow, 0); // No borrow should be generated

assert_eq!(test, test_expected);
// Test subtraction with borrow and a large value
let mut a = 0u64;
let borrow = sbb_for_sub_with_borrow(&mut a, u64::MAX, 1);
assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0
assert_eq!(borrow, 1); // Borrow should be 1
}

#[test]
fn test_mac() {
// Basic multiply-accumulate without carry
let mut carry = 0;
let result = mac(1u64, 2u64, 3u64, &mut carry);
assert_eq!(result, 7); // 1 + (2 * 3) = 7
assert_eq!(carry, 0); // No overflow, carry remains 0

// Multiply-accumulate with large values that generate a carry
let mut carry = 0;
let result = mac(u64::MAX, u64::MAX, 1u64, &mut carry);
assert_eq!(result, u64::MAX - 1); // Result wraps around
assert_eq!(carry, 1); // Carry is set due to overflow
}

#[test]
fn test_mac_discard() {
// Discard lower 64 bits and set carry
let mut carry = 0;
mac_discard(1u64, 2u64, 3u64, &mut carry);
assert_eq!(carry, 0); // No overflow, carry remains 0

// Test with values that generate a carry
let mut carry = 0;
mac_discard(u64::MAX, u64::MAX, 1u64, &mut carry);
assert_eq!(carry, 1); // Carry is set due to overflow
}

#[test]
fn test_mac_with_carry() {
// Basic multiply-accumulate with carry
let mut carry = 1;
let result = mac_with_carry(1u64, 2u64, 3u64, &mut carry);
assert_eq!(result, 8); // 1 + (2 * 3) + 1 = 8
assert_eq!(carry, 0); // No overflow, carry remains 0

// Multiply-accumulate with carry and large values
let mut carry = 1;
let result = mac_with_carry(u64::MAX, u64::MAX, 1u64, &mut carry);
assert_eq!(result, u64::MAX); // Result wraps around
assert_eq!(carry, 1); // Carry is set due to overflow
}

#[test]
fn test_find_relaxed_naf_usefulness() {
let vec = find_naf(&[12u64]);
assert_eq!(vec.len(), 5);

let vec = find_relaxed_naf(&[12u64]);
assert_eq!(vec.len(), 4);
}

#[test]
fn test_find_relaxed_naf_correctness() {
use ark_std::{One, UniformRand, Zero};
use num_bigint::BigInt;

let mut rng = ark_std::test_rng();

for _ in 0..10 {
let num = [
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
];
let relaxed_naf = find_relaxed_naf(&num);

let test = {
let mut sum = BigInt::zero();
let mut cur = BigInt::one();
for v in relaxed_naf {
sum += cur.clone() * v;
cur *= 2;
}
sum
};

let test_expected = {
let mut sum = BigInt::zero();
let mut cur = BigInt::one();
for v in num.iter() {
sum += cur.clone() * v;
cur <<= 64;
}
sum
};

assert_eq!(test, test_expected);
}
}
}
6 changes: 3 additions & 3 deletions ff/src/fields/models/fp/montgomery_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
result.0[i - 1] =
fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2);
}
result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2);
result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &carry2);
result
});
let mut result = Fp::new_unchecked(result);
Expand Down Expand Up @@ -465,7 +465,7 @@ pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
temp.0[k] =
fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2);
}
carry = fa::adc_no_carry(carry, 0, &mut carry2);
carry = fa::adc_no_carry(carry, 0, &carry2);
(temp, carry)
},
);
Expand All @@ -477,7 +477,7 @@ pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
result.0[i - 1] =
fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2);
}
result.0[N - 1] = fa::adc_no_carry(carry, 0, &mut carry2);
result.0[N - 1] = fa::adc_no_carry(carry, 0, &carry2);
result
});
let mut result = Fp::new_unchecked(result);
Expand Down

0 comments on commit 2ef9d27

Please sign in to comment.