Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: pull out foreign call nested array changes #7216

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 112 additions & 43 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,60 +482,69 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
destinations.iter().zip(destination_value_types).zip(&values)
{
match (destination, value_type) {
(ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => {
match output {
ForeignCallParam::Single(value) => {
self.write_value_to_memory(*value_index, value, *bit_size)?;
}
_ => return Err(format!(
"Function result size does not match brillig bytecode. Expected 1 result but got {output:?}")
),
(ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => {
match output {
ForeignCallParam::Single(value) => {
self.write_value_to_memory(*value_index, value, *bit_size)?;
}
_ => return Err(format!(
"Function result size does not match brillig bytecode. Expected 1 result but got {output:?}")
),
}
(
ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }),
HeapValueType::Array { value_types, size: type_size },
) if size == type_size => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
return Err("Foreign call result array doesn't match expected size".to_string());
}
}
(
ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }),
HeapValueType::Array { value_types, size: type_size },
) if size == type_size => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
// foreign call returning flattened values into a nested type, so the sizes do not match
let destination = self.memory.read_ref(*pointer_index);
let return_type = value_type;
let mut flatten_values_idx = 0; //index of values read from flatten_values
self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?;
} else {
self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
} else {
unimplemented!("deflattening heap arrays from foreign calls");
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
}
(
ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }),
HeapValueType::Vector { value_types },
) => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size address
self.memory.write(*size_index, values.len().into());
} else {
// foreign call returning flattened values into a nested type, so the sizes do not match
let destination = self.memory.read_ref(*pointer_index);
let return_type = value_type;
let mut flatten_values_idx = 0; //index of values read from flatten_values
self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?;
}
}
(
ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }),
HeapValueType::Vector { value_types },
) => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size address
self.memory.write(*size_index, values.len().into());
self.write_values_to_memory_slice(*pointer_index, values, value_types)?;

self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
} else {
unimplemented!("deflattening heap vectors from foreign calls");
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
}
_ => {
return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}"));
} else {
unimplemented!("deflattening heap vectors from foreign calls");
}
}
_ => {
return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}"));
}
}
}

let _ =
Expand Down Expand Up @@ -596,6 +605,66 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
Ok(())
}

/// Writes flattened values to memory, using the provided type
/// Function calls itself recursively in order to work with recursive types (nested arrays)
/// values_idx is the current index in the values vector and is incremented every time
/// a value is written to memory
/// The function returns the address of the next value to be written
fn write_slice_of_values_to_memory(
&mut self,
destination: MemoryAddress,
values: &Vec<F>,
values_idx: &mut usize,
value_type: &HeapValueType,
) -> Result<MemoryAddress, String> {
let mut current_pointer = destination;
match value_type {
HeapValueType::Simple(bit_size) => {
self.write_value_to_memory(destination, &values[*values_idx], *bit_size)?;
*values_idx += 1;
Ok(MemoryAddress(destination.to_usize() + 1))
}
HeapValueType::Array { value_types, size } => {
for _ in 0..*size {
for typ in value_types {
match typ {
HeapValueType::Simple(len) => {
self.write_value_to_memory(
current_pointer,
&values[*values_idx],
*len,
)?;
*values_idx += 1;
current_pointer = MemoryAddress(current_pointer.to_usize() + 1);
}
HeapValueType::Array { .. } => {
let destination = self.memory.read_ref(current_pointer);
let destination = self.memory.read_ref(destination);
self.write_slice_of_values_to_memory(
destination,
values,
values_idx,
typ,
)?;
current_pointer = MemoryAddress(current_pointer.to_usize() + 1);
}
HeapValueType::Vector { .. } => {
return Err(format!(
"Unsupported returned type in foreign calls {:?}",
typ
));
}
}
}
}
Ok(current_pointer)
}
HeapValueType::Vector { .. } => {
Err(format!("Unsupported returned type in foreign calls {:?}", value_type))
}
}
}

/// Process a binary operation.
/// This method will not modify the program counter.
fn process_binary_field_op(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1737,8 +1737,7 @@ impl<'block> BrilligBlock<'block> {
dfg,
);
let array = variable.extract_array();
self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());
self.allocate_nested_array(typ, Some(array));

variable
}
Expand All @@ -1765,6 +1764,43 @@ impl<'block> BrilligBlock<'block> {
}
}

fn allocate_nested_array(
&mut self,
typ: &Type,
array: Option<BrilligArray>,
) -> BrilligVariable {
match typ {
Type::Array(types, size) => {
let array = array.unwrap_or(BrilligArray {
pointer: self.brillig_context.allocate_register(),
size: *size,
rc: self.brillig_context.allocate_register(),
});
self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());

let mut index = 0_usize;
for _ in 0..*size {
for element_type in types.iter() {
match element_type {
Type::Array(_, _) => {
let inner_array = self.allocate_nested_array(element_type, None);
let idx =
self.brillig_context.make_usize_constant_instruction(index.into());
self.store_variable_in_array(array.pointer, idx, inner_array);
}
Type::Slice(_) => unreachable!("ICE: unsupported slice type in allocate_nested_array(), expects an array or a numeric type"),
_ => (),
}
index += 1;
}
}
BrilligVariable::BrilligArray(array)
}
_ => unreachable!("ICE: allocate_nested_array() expects an array, got {typ:?}"),
}
}

/// Gets the "user-facing" length of an array.
/// An array of structs with two fields would be stored as an 2 * array.len() array/vector.
/// So we divide the length by the number of subitems in an item to get the user-facing length.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "regression_4561"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Regression test for issue #4561
use std::test::OracleMock;

type TReturnElem = [Field; 3];
type TReturn = [TReturnElem; 2];

#[oracle(simple_nested_return)]
unconstrained fn simple_nested_return_oracle() -> TReturn {}

unconstrained fn simple_nested_return_unconstrained() -> TReturn {
simple_nested_return_oracle()
}

#[test]
fn test_simple_nested_return() {
OracleMock::mock("simple_nested_return").returns([1, 2, 3, 4, 5, 6]);
assert_eq(simple_nested_return_unconstrained(), [[1, 2, 3], [4, 5, 6]]);
}

#[oracle(nested_with_fields_return)]
unconstrained fn nested_with_fields_return_oracle() -> (Field, TReturn, Field) {}

unconstrained fn nested_with_fields_return_unconstrained() -> (Field, TReturn, Field) {
nested_with_fields_return_oracle()
}

#[test]
fn test_nested_with_fields_return() {
OracleMock::mock("nested_with_fields_return").returns((0, [1, 2, 3, 4, 5, 6], 7));
assert_eq(nested_with_fields_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7));
}

#[oracle(two_nested_return)]
unconstrained fn two_nested_return_oracle() -> (Field, TReturn, Field, TReturn) {}

unconstrained fn two_nested_return_unconstrained() -> (Field, TReturn, Field, TReturn) {
two_nested_return_oracle()
}

#[test]
fn two_nested_return() {
OracleMock::mock("two_nested_return").returns((0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6]));
assert_eq(two_nested_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7, [[1, 2, 3], [4, 5, 6]]));
}

#[oracle(foo_return)]
unconstrained fn foo_return() -> (Field, TReturn, TestTypeFoo) {}
unconstrained fn foo_return_unconstrained() -> (Field, TReturn, TestTypeFoo) {
foo_return()
}

struct TestTypeFoo {
a: Field,
b: [[[Field; 3]; 4]; 2],
c: [TReturnElem; 2],
d: TReturnElem,
}

#[test]
fn complexe_struct_return() {
OracleMock::mock("foo_return").returns(
(
0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6]
)
);
let foo_x = foo_return_unconstrained();
assert_eq((foo_x.0, foo_x.1), (0, [[1, 2, 3], [4, 5, 6]]));
assert_eq(foo_x.2.a, 7);
assert_eq(
foo_x.2.b, [
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]
]
);
let a: TReturnElem = [1, 2, 3];
let b: TReturnElem = [4, 5, 6];
assert_eq(foo_x.2.c, [a, b]);
assert_eq(foo_x.2.d, a);
}
Loading