From 2ef9d27a8e1e277112b6f04e65e03dab59bbb87c Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Tue, 19 Nov 2024 19:00:31 +0100 Subject: [PATCH] arithmetic: add unit tests (#879) --- ff/src/biginteger/arithmetic.rs | 265 +++++++++++++++--- ff/src/fields/models/fp/montgomery_backend.rs | 6 +- 2 files changed, 224 insertions(+), 47 deletions(-) diff --git a/ff/src/biginteger/arithmetic.rs b/ff/src/biginteger/arithmetic.rs index 90ebe0a6e..227c75593 100644 --- a/ff/src/biginteger/arithmetic.rs +++ b/ff/src/biginteger/arithmetic.rs @@ -41,7 +41,7 @@ pub fn adc_for_add_with_carry(a: &mut u64, b: u64, carry: u8) -> u8 { /// Calculate a + b + carry, returning the sum #[inline(always)] #[doc(hidden)] -pub fn adc_no_carry(a: u64, b: u64, carry: &mut u64) -> u64 { +pub fn adc_no_carry(a: u64, b: u64, carry: &u64) -> u64 { let tmp = a as u128 + b as u128 + *carry as u128; tmp as u64 } @@ -230,51 +230,228 @@ pub fn find_relaxed_naf(num: &[u64]) -> Vec { res } -#[test] -fn test_find_relaxed_naf_usefulness() { - let vec = find_naf(&[12u64]); - assert_eq!(vec.len(), 5); +#[cfg(test)] +mod tests { + use super::*; - let vec = find_relaxed_naf(&[12u64]); - assert_eq!(vec.len(), 4); -} + #[test] + fn test_adc() { + // Test addition without initial carry + let mut a = 5u64; + let carry = adc(&mut a, 10u64, 0); + assert_eq!(a, 15); // 5 + 10 = 15 + assert_eq!(carry, 0); // No carry should be generated -#[test] -fn test_find_relaxed_naf_correctness() { - use ark_std::{One, UniformRand, Zero}; - use num_bigint::BigInt; - - let mut rng = ark_std::test_rng(); - - for _ in 0..10 { - let num = [ - u64::rand(&mut rng), - u64::rand(&mut rng), - u64::rand(&mut rng), - u64::rand(&mut rng), - ]; - let relaxed_naf = find_relaxed_naf(&num); - - let test = { - let mut sum = BigInt::zero(); - let mut cur = BigInt::one(); - for v in relaxed_naf { - sum += cur.clone() * v; - cur *= 2; - } - sum - }; - - let test_expected = { - let mut sum = BigInt::zero(); - let mut cur = BigInt::one(); - for v in num.iter() { - sum += cur.clone() * v; - cur <<= 64; - } - sum - }; + // Test addition with carry when overflowing u64 + let mut a = u64::MAX; + let carry = adc(&mut a, 1u64, 0); + assert_eq!(a, 0); // Overflow resets `a` to 0 + assert_eq!(carry, 1); // Carry is 1 due to overflow + + // Test addition with a non-zero initial carry + let mut a = 5u64; + let carry = adc(&mut a, 10u64, 1); + assert_eq!(a, 16); // 5 + 10 + 1 = 16 + assert_eq!(carry, 0); // No overflow, so carry remains 0 + + // Test addition with carry and a large sum + let mut a = u64::MAX - 5; + let carry = adc(&mut a, 10u64, 1); + assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5 + assert_eq!(carry, 1); // Carry is 1 due to overflow + } + + #[test] + fn test_adc_for_add_with_carry() { + // Test addition without initial carry + let mut a = 5u64; + let carry = adc_for_add_with_carry(&mut a, 10u64, 0); + assert_eq!(a, 15); // Expect a to be 15 + assert_eq!(carry, 0); // No carry should be generated + + // Test addition with carry when overflowing u64 + let mut a = u64::MAX; + let carry = adc_for_add_with_carry(&mut a, 1u64, 0); + assert_eq!(a, 0); // Overflow resets `a` to 0 + assert_eq!(carry, 1); // Carry is 1 due to overflow + + // Test addition with a non-zero initial carry + let mut a = 5u64; + let carry = adc_for_add_with_carry(&mut a, 10u64, 1); + assert_eq!(a, 16); // 5 + 10 + 1 = 16 + assert_eq!(carry, 0); // No overflow, so carry remains 0 + + // Test addition with carry and a large sum + let mut a = u64::MAX - 5; + let carry = adc_for_add_with_carry(&mut a, 10u64, 1); + assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5 + assert_eq!(carry, 1); // Carry is 1 due to overflow + } + + #[test] + fn test_adc_no_carry() { + // Test addition without initial carry + let mut carry = 0; + let result = adc_no_carry(5u64, 10u64, &mut carry); + assert_eq!(result, 15); // 5 + 10 = 15 + assert_eq!(carry, 0); // No carry should be generated + + // Test addition with a non-zero initial carry + let mut carry = 1; + let result = adc_no_carry(5u64, 10u64, &mut carry); + assert_eq!(result, 16); // 5 + 10 + 1 = 16 + assert_eq!(carry, 1); // No overflow, so carry remains 1 + + // Test addition that causes a carry + let mut carry = 1; + let result = adc_no_carry(u64::MAX, 1u64, &mut carry); + assert_eq!(result, 1); // u64::MAX + 1 + 1 -> 1 + assert_eq!(carry, 1); // Carry is 1 due to overflow + } + + #[test] + fn test_sbb() { + // Test subtraction without initial borrow + let mut a = 15u64; + let borrow = sbb(&mut a, 5u64, 0); + assert_eq!(a, 10); // 15 - 5 = 10 + assert_eq!(borrow, 0); // No borrow should be generated + + // Test subtraction that causes a borrow + let mut a = 5u64; + let borrow = sbb(&mut a, 10u64, 0); + assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around + assert_eq!(borrow, 1); // Borrow should be 1 + + // Test subtraction with a non-zero initial borrow + let mut a = 15u64; + let borrow = sbb(&mut a, 5u64, 1); + assert_eq!(a, 9); // 15 - 5 - 1 = 9 + assert_eq!(borrow, 0); // No borrow should be generated + + // Test subtraction with borrow and a large value + let mut a = 0u64; + let borrow = sbb(&mut a, u64::MAX, 1); + assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0 + assert_eq!(borrow, 1); // Borrow should be 1 + } + + #[test] + fn test_sbb_for_sub_with_borrow() { + // Test subtraction without initial borrow + let mut a = 15u64; + let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 0); + assert_eq!(a, 10); // Expect a to be 10 + assert_eq!(borrow, 0); // No borrow should be generated + + // Test subtraction that causes a borrow + let mut a = 5u64; + let borrow = sbb_for_sub_with_borrow(&mut a, 10u64, 0); + assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around + assert_eq!(borrow, 1); // Borrow should be 1 + + // Test subtraction with a non-zero initial borrow + let mut a = 15u64; + let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 1); + assert_eq!(a, 9); // 15 - 5 - 1 = 9 + assert_eq!(borrow, 0); // No borrow should be generated - assert_eq!(test, test_expected); + // Test subtraction with borrow and a large value + let mut a = 0u64; + let borrow = sbb_for_sub_with_borrow(&mut a, u64::MAX, 1); + assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0 + assert_eq!(borrow, 1); // Borrow should be 1 + } + + #[test] + fn test_mac() { + // Basic multiply-accumulate without carry + let mut carry = 0; + let result = mac(1u64, 2u64, 3u64, &mut carry); + assert_eq!(result, 7); // 1 + (2 * 3) = 7 + assert_eq!(carry, 0); // No overflow, carry remains 0 + + // Multiply-accumulate with large values that generate a carry + let mut carry = 0; + let result = mac(u64::MAX, u64::MAX, 1u64, &mut carry); + assert_eq!(result, u64::MAX - 1); // Result wraps around + assert_eq!(carry, 1); // Carry is set due to overflow + } + + #[test] + fn test_mac_discard() { + // Discard lower 64 bits and set carry + let mut carry = 0; + mac_discard(1u64, 2u64, 3u64, &mut carry); + assert_eq!(carry, 0); // No overflow, carry remains 0 + + // Test with values that generate a carry + let mut carry = 0; + mac_discard(u64::MAX, u64::MAX, 1u64, &mut carry); + assert_eq!(carry, 1); // Carry is set due to overflow + } + + #[test] + fn test_mac_with_carry() { + // Basic multiply-accumulate with carry + let mut carry = 1; + let result = mac_with_carry(1u64, 2u64, 3u64, &mut carry); + assert_eq!(result, 8); // 1 + (2 * 3) + 1 = 8 + assert_eq!(carry, 0); // No overflow, carry remains 0 + + // Multiply-accumulate with carry and large values + let mut carry = 1; + let result = mac_with_carry(u64::MAX, u64::MAX, 1u64, &mut carry); + assert_eq!(result, u64::MAX); // Result wraps around + assert_eq!(carry, 1); // Carry is set due to overflow + } + + #[test] + fn test_find_relaxed_naf_usefulness() { + let vec = find_naf(&[12u64]); + assert_eq!(vec.len(), 5); + + let vec = find_relaxed_naf(&[12u64]); + assert_eq!(vec.len(), 4); + } + + #[test] + fn test_find_relaxed_naf_correctness() { + use ark_std::{One, UniformRand, Zero}; + use num_bigint::BigInt; + + let mut rng = ark_std::test_rng(); + + for _ in 0..10 { + let num = [ + u64::rand(&mut rng), + u64::rand(&mut rng), + u64::rand(&mut rng), + u64::rand(&mut rng), + ]; + let relaxed_naf = find_relaxed_naf(&num); + + let test = { + let mut sum = BigInt::zero(); + let mut cur = BigInt::one(); + for v in relaxed_naf { + sum += cur.clone() * v; + cur *= 2; + } + sum + }; + + let test_expected = { + let mut sum = BigInt::zero(); + let mut cur = BigInt::one(); + for v in num.iter() { + sum += cur.clone() * v; + cur <<= 64; + } + sum + }; + + assert_eq!(test, test_expected); + } } } diff --git a/ff/src/fields/models/fp/montgomery_backend.rs b/ff/src/fields/models/fp/montgomery_backend.rs index 04114f482..06fc84a6a 100644 --- a/ff/src/fields/models/fp/montgomery_backend.rs +++ b/ff/src/fields/models/fp/montgomery_backend.rs @@ -437,7 +437,7 @@ pub trait MontConfig: 'static + Sync + Send + Sized { result.0[i - 1] = fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2); } - result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2); + result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &carry2); result }); let mut result = Fp::new_unchecked(result); @@ -465,7 +465,7 @@ pub trait MontConfig: 'static + Sync + Send + Sized { temp.0[k] = fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2); } - carry = fa::adc_no_carry(carry, 0, &mut carry2); + carry = fa::adc_no_carry(carry, 0, &carry2); (temp, carry) }, ); @@ -477,7 +477,7 @@ pub trait MontConfig: 'static + Sync + Send + Sized { result.0[i - 1] = fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2); } - result.0[N - 1] = fa::adc_no_carry(carry, 0, &mut carry2); + result.0[N - 1] = fa::adc_no_carry(carry, 0, &carry2); result }); let mut result = Fp::new_unchecked(result);