From 586b58b4bab477a2b5a46faa74045c222b367ca7 Mon Sep 17 00:00:00 2001 From: Soptq <32592090+Soptq@users.noreply.github.com> Date: Thu, 25 Jan 2024 23:19:42 +0800 Subject: [PATCH] feat: allow generic norm --- src/linalg/src/norm.cairo | 14 ++++++++------ src/linalg/src/tests/norm_test.cairo | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/linalg/src/norm.cairo b/src/linalg/src/norm.cairo index e8e83025..c0a81de0 100644 --- a/src/linalg/src/norm.cairo +++ b/src/linalg/src/norm.cairo @@ -1,25 +1,27 @@ -//! Norm of an u128 array. +//! Norm of an T array. use alexandria_math::fast_root::fast_nr_optimize; use alexandria_math::pow; -/// Compute the norm for an u128 array. +/// Compute the norm for an T array. /// # Arguments /// * `array` - The inputted array. /// * `ord` - The order of the norm. /// * `iter` - The number of iterations to run the algorithm /// # Returns /// * `u128` - The norm for the array. -fn norm(mut xs: Span, ord: u128, iter: usize) -> u128 { - let mut norm = 0; +fn norm, +Into, +Zeroable, +Copy>( + mut xs: Span, ord: u128, iter: usize +) -> u128 { + let mut norm: u128 = 0; loop { match xs.pop_front() { Option::Some(x_value) => { if ord == 0 { - if *x_value != 0 { + if (*x_value).is_non_zero() { norm += 1; } } else { - norm += pow(*x_value, ord); + norm += pow((*x_value).into(), ord); } }, Option::None => { break; }, diff --git a/src/linalg/src/tests/norm_test.cairo b/src/linalg/src/tests/norm_test.cairo index 92af342b..82fc6ee2 100644 --- a/src/linalg/src/tests/norm_test.cairo +++ b/src/linalg/src/tests/norm_test.cairo @@ -1,5 +1,12 @@ use alexandria_linalg::norm::norm; +// the following trait is not safe, it is only used for testing. +impl u128_to_u32 of Into { + fn into(self: u128) -> u32 { + self.try_into().unwrap() + } +} + #[test] #[available_gas(2000000)] fn norm_test_1() { @@ -20,3 +27,10 @@ fn norm_test_3() { let mut array: Array = array![3, 4]; assert(norm(array.span(), 0, 10) == 2, 'invalid l1 norm'); } + +#[test] +#[available_gas(2000000)] +fn norm_test_into() { + let mut array: Array = array![3, 4]; + assert(norm(array.span(), 2, 10) == 5, 'invalid l2 norm'); +}