Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use Zac's quicksort algorithm in stdlib sorting #5940

Merged
merged 15 commits into from
Sep 11, 2024
116 changes: 116 additions & 0 deletions noir_stdlib/src/array/check_shuffle.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use crate::cmp::Eq;

unconstrained fn __get_shuffle_indices<T, let N: u32>(lhs: [T; N], rhs: [T; N]) -> [Field; N] where T: Eq {
let mut shuffle_indices: [Field;N ] = [0; N];

let mut shuffle_mask: [bool; N] = [false; N];
for i in 0..N {
let mut found = false;
for j in 0..N {
if ((shuffle_mask[j] == false) & (!found)) {
if (lhs[i] == rhs[j]) {
found = true;
shuffle_indices[i] = j as Field;
shuffle_mask[j] = true;
}
}
if (found) {
continue;
}
}
assert(found == true, "check_shuffle, lhs and rhs arrays do not contain equivalent values");
}

shuffle_indices
}

unconstrained fn __get_index<let N: u32>(indices: [Field; N], idx: Field) -> Field {
let mut result = 0;
for i in 0..N {
if (indices[i] == idx) {
result = i as Field;
break;
}
}
result
}

pub(crate) fn check_shuffle<T, let N: u32>(lhs: [T; N], rhs: [T; N]) where T: Eq {
unsafe {
let shuffle_indices = __get_shuffle_indices(lhs, rhs);

for i in 0..N {
let idx = __get_index(shuffle_indices, i as Field);
assert_eq(shuffle_indices[idx], i as Field);
}
for i in 0..N {
let idx = shuffle_indices[i];
let expected = rhs[idx];
let result = lhs[i];
assert_eq(expected, result);
}
}
}

mod test {
use super::check_shuffle;
use crate::cmp::Eq;

struct CompoundStruct {
a: bool,
b: Field,
c: u64
}
impl Eq for CompoundStruct {
fn eq(self, other: Self) -> bool {
(self.a == other.a) & (self.b == other.b) & (self.c == other.c)
}
}

#[test]
fn test_shuffle() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [2, 0, 3, 1, 4];
check_shuffle(lhs, rhs);
}

#[test]
fn test_shuffle_identity() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 4];
check_shuffle(lhs, rhs);
}

#[test(should_fail_with = "check_shuffle, lhs and rhs arrays do not contain equivalent values")]
fn test_shuffle_fail() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 5];
check_shuffle(lhs, rhs);
}

#[test(should_fail_with = "check_shuffle, lhs and rhs arrays do not contain equivalent values")]
fn test_shuffle_duplicates() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 3];
check_shuffle(lhs, rhs);
}

#[test]
fn test_shuffle_compound_struct() {
let lhs: [CompoundStruct; 5] = [
CompoundStruct { a: false, b: 0, c: 12345 },
CompoundStruct { a: false, b: -100, c: 54321 },
CompoundStruct { a: true, b: 5, c: 0xffffffffffffffff },
CompoundStruct { a: true, b: 9814, c: 0xeeffee0011001133 },
CompoundStruct { a: false, b: 0x155, c: 0 }
];
let rhs: [CompoundStruct; 5] = [
CompoundStruct { a: false, b: 0x155, c: 0 },
CompoundStruct { a: false, b: 0, c: 12345 },
CompoundStruct { a: false, b: -100, c: 54321 },
CompoundStruct { a: true, b: 9814, c: 0xeeffee0011001133 },
CompoundStruct { a: true, b: 5, c: 0xffffffffffffffff }
];
check_shuffle(lhs, rhs);
}
}
96 changes: 31 additions & 65 deletions noir_stdlib/src/array.nr → noir_stdlib/src/array/mod.nr
Original file line number Diff line number Diff line change
@@ -1,63 +1,15 @@
use crate::cmp::Ord;
use crate::cmp::{Eq, Ord};
use crate::convert::From;
use crate::runtime::is_unconstrained;

mod check_shuffle;
mod quicksort;

impl<T, let N: u32> [T; N] {
/// Returns the length of the slice.
#[builtin(array_len)]
pub fn len(self) -> u32 {}

pub fn sort(self) -> Self where T: Ord {
self.sort_via(|a: T, b: T| a <= b)
}

pub fn sort_via<Env>(self, ordering: fn[Env](T, T) -> bool) -> Self {
let sorted_index = unsafe {
// Safety: These indices are asserted to be the sorted element indices via `find_index`
let sorted_index: [u32; N] = self.get_sorting_index(ordering);

for i in 0..N {
let pos = find_index(sorted_index, i);
assert(sorted_index[pos] == i);
}

sorted_index
};

// Sort the array using the indexes
let mut result = self;
for i in 0..N {
result[i] = self[sorted_index[i]];
}
// Ensure the array is sorted
for i in 0..N - 1 {
assert(ordering(result[i], result[i + 1]));
}

result
}

/// Returns the index of the elements in the array that would sort it, using the provided custom sorting function.
unconstrained fn get_sorting_index<Env>(self, ordering: fn[Env](T, T) -> bool) -> [u32; N] {
let mut result = [0; N];
let mut a = self;
for i in 0..N {
result[i] = i;
}
for i in 1..N {
for j in 0..i {
if ordering(a[i], a[j]) {
let old_a_j = a[j];
a[j] = a[i];
a[i] = old_a_j;
let old_j = result[j];
result[j] = result[i];
result[i] = old_j;
}
}
}
result
}

#[builtin(as_slice)]
pub fn as_slice(self) -> [T] {}

Expand Down Expand Up @@ -114,25 +66,39 @@ impl<T, let N: u32> [T; N] {
}
}

impl<T, let N: u32> [T; N] where T: Ord + Eq {
pub fn sort(self) -> Self {
self.sort_via(|a: T, b: T| a <= b)
}
}

impl<T, let N: u32> [T; N] where T: Eq {

pub fn sort_via<Env>(self, ordering: fn[Env](T, T) -> bool) -> Self {
unsafe {
// Safety: `sorted` array is checked to be:
// a. a permutation of `input`'s elements
// b. satisfying the predicate `ordering`
let sorted = quicksort::quicksort(self, ordering);

if !is_unconstrained() {
for i in 0..N - 1 {
assert(ordering(sorted[i], sorted[i + 1]));
}
check_shuffle::check_shuffle(self, sorted);
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}
sorted
}
}
}

impl<let N: u32> [u8; N] {
/// Convert a sequence of bytes as-is into a string.
/// This function performs no UTF-8 validation or similar.
#[builtin(array_as_str_unchecked)]
pub fn as_str_unchecked(self) -> str<N> {}
}

// helper function used to look up the position of a value in an array of Field
// Note that function returns 0 if the value is not found
unconstrained fn find_index<let N: u32>(a: [u32; N], find: u32) -> u32 {
let mut result = 0;
for i in 0..a.len() {
if a[i] == find {
result = i;
}
}
result
}

impl<let N: u32> From<str<N>> for [u8; N] {
fn from(s: str<N>) -> Self {
s.as_bytes()
Expand Down
39 changes: 39 additions & 0 deletions noir_stdlib/src/array/quicksort.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
unconstrained fn partition<T, Env, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: fn[Env](T, T) -> bool
) -> u32 {
let pivot = high;
let mut i = low;
for j in low..high {
if (sortfn(arr[j], arr[pivot])) {
let temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i += 1;
}
}
let temp = arr[i];
arr[i] = arr[pivot];
arr[pivot] = temp;
i
}

unconstrained fn quicksort_recursive<T, Env, let N: u32>(arr: &mut [T; N], low: u32, high: u32, sortfn: fn[Env](T, T) -> bool) {
if low < high {
let pivot_index = partition(arr, low, high, sortfn);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1, sortfn);
}
quicksort_recursive(arr, pivot_index + 1, high, sortfn);
}
}

unconstrained pub(crate) fn quicksort<T, Env, let N: u32>(_arr: [T; N], sortfn: fn[Env](T, T) -> bool) -> [T; N] {
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1, sortfn);
}
arr
}
Loading