Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add norm #266

Merged
merged 4 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Scarb.lock
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ dependencies = [
[[package]]
name = "alexandria_linalg"
version = "0.1.0"
dependencies = [
"alexandria_math",
]

[[package]]
name = "alexandria_math"
Expand Down
4 changes: 4 additions & 0 deletions src/linalg/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Linear Algebra

## [Norm](./src/norm.cairo)

Calculate the norm of an u128 array ([see also](https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html)).

## [Dot product](./src/dot.cairo)

The dot product or scalar product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number. Algebraically, the dot product is the sum of the products of the corresponding entries of the two sequences of numbers ([see also](https://en.wikipedia.org/wiki/Dot_product)).
Expand Down
3 changes: 3 additions & 0 deletions src/linalg/Scarb.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ homepage = "https://github.com/keep-starknet-strange/alexandria/tree/main/src/li

[tool]
fmt.workspace = true

[dependencies]
alexandria_math = { path = "../math" }
1 change: 1 addition & 0 deletions src/linalg/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod dot;
mod kron;
mod norm;

#[cfg(test)]
mod tests;
38 changes: 38 additions & 0 deletions src/linalg/src/norm.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//! Norm of an T array.
use alexandria_math::fast_root::fast_nr_optimize;
use alexandria_math::pow;

/// Compute the norm for an T array.
/// # Arguments
/// * `array` - The inputted array.
/// * `ord` - The order of the norm.
/// * `iter` - The number of iterations to run the algorithm
/// # Returns
/// * `u128` - The norm for the array.
fn norm<T, +Into<T, u128>, +Zeroable<T>, +Copy<T>>(
mut xs: Span<T>, ord: u128, iter: usize
) -> u128 {
let mut norm: u128 = 0;
loop {
match xs.pop_front() {
Option::Some(x_value) => {
if ord == 0 {
if (*x_value).is_non_zero() {
norm += 1;
}
} else {
norm += pow((*x_value).into(), ord);
}
},
Option::None => { break; },
};
};

if ord == 0 {
return norm;
}

norm = fast_nr_optimize(norm, ord, iter);

norm
Soptq marked this conversation as resolved.
Show resolved Hide resolved
}
1 change: 1 addition & 0 deletions src/linalg/src/tests.cairo
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod dot_test;
mod kron_test;
mod norm_test;
29 changes: 29 additions & 0 deletions src/linalg/src/tests/norm_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use alexandria_linalg::norm::norm;

#[test]
#[available_gas(2000000)]
fn norm_test_1() {
let mut array: Array<u128> = array![3, 4];
Soptq marked this conversation as resolved.
Show resolved Hide resolved
assert(norm(array.span(), 2, 10) == 5, 'invalid l2 norm');
}

#[test]
#[available_gas(2000000)]
fn norm_test_2() {
let mut array: Array<u128> = array![3, 4];
assert(norm(array.span(), 1, 10) == 7, 'invalid l1 norm');
}

#[test]
#[available_gas(2000000)]
fn norm_test_3() {
let mut array: Array<u128> = array![3, 4];
assert(norm(array.span(), 0, 10) == 2, 'invalid l1 norm');
}

#[test]
#[available_gas(2000000)]
fn norm_test_into() {
let mut array: Array<u32> = array![3, 4];
assert(norm(array.span(), 2, 10) == 5, 'invalid l2 norm');
}
Loading