diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs index da9a34f1044e..01f45bf653c4 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs @@ -482,60 +482,69 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { destinations.iter().zip(destination_value_types).zip(&values) { match (destination, value_type) { - (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => { - match output { - ForeignCallParam::Single(value) => { - self.write_value_to_memory(*value_index, value, *bit_size)?; - } - _ => return Err(format!( - "Function result size does not match brillig bytecode. Expected 1 result but got {output:?}") - ), + (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => { + match output { + ForeignCallParam::Single(value) => { + self.write_value_to_memory(*value_index, value, *bit_size)?; } + _ => return Err(format!( + "Function result size does not match brillig bytecode. Expected 1 result but got {output:?}") + ), } - ( - ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), - HeapValueType::Array { value_types, size: type_size }, - ) if size == type_size => { - if HeapValueType::all_simple(value_types) { - match output { - ForeignCallParam::Array(values) => { - if values.len() != *size { - return Err("Foreign call result array doesn't match expected size".to_string()); - } + } + ( + ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), + HeapValueType::Array { value_types, size: type_size }, + ) if size == type_size => { + if HeapValueType::all_simple(value_types) { + match output { + ForeignCallParam::Array(values) => { + if values.len() != *size { + // foreign call returning flattened values into a nested type, so the sizes do not match + let destination = self.memory.read_ref(*pointer_index); + let return_type = value_type; + let mut flatten_values_idx = 0; //index of values read from flatten_values + self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?; + } else { self.write_values_to_memory_slice(*pointer_index, values, value_types)?; } - _ => { - return Err("Function result size does not match brillig bytecode size".to_string()); - } } - } else { - unimplemented!("deflattening heap arrays from foreign calls"); + _ => { + return Err("Function result size does not match brillig bytecode size".to_string()); + } } - } - ( - ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }), - HeapValueType::Vector { value_types }, - ) => { - if HeapValueType::all_simple(value_types) { - match output { - ForeignCallParam::Array(values) => { - // Set our size in the size address - self.memory.write(*size_index, values.len().into()); + } else { + // foreign call returning flattened values into a nested type, so the sizes do not match + let destination = self.memory.read_ref(*pointer_index); + let return_type = value_type; + let mut flatten_values_idx = 0; //index of values read from flatten_values + self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?; + } + } + ( + ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }), + HeapValueType::Vector { value_types }, + ) => { + if HeapValueType::all_simple(value_types) { + match output { + ForeignCallParam::Array(values) => { + // Set our size in the size address + self.memory.write(*size_index, values.len().into()); + self.write_values_to_memory_slice(*pointer_index, values, value_types)?; - self.write_values_to_memory_slice(*pointer_index, values, value_types)?; - } - _ => { - return Err("Function result size does not match brillig bytecode size".to_string()); - } } - } else { - unimplemented!("deflattening heap vectors from foreign calls"); + _ => { + return Err("Function result size does not match brillig bytecode size".to_string()); + } } - } - _ => { - return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}")); + } else { + unimplemented!("deflattening heap vectors from foreign calls"); } } + _ => { + return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}")); + } + } } let _ = @@ -596,6 +605,66 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { Ok(()) } + /// Writes flattened values to memory, using the provided type + /// Function calls itself recursively in order to work with recursive types (nested arrays) + /// values_idx is the current index in the values vector and is incremented every time + /// a value is written to memory + /// The function returns the address of the next value to be written + fn write_slice_of_values_to_memory( + &mut self, + destination: MemoryAddress, + values: &Vec, + values_idx: &mut usize, + value_type: &HeapValueType, + ) -> Result { + let mut current_pointer = destination; + match value_type { + HeapValueType::Simple(bit_size) => { + self.write_value_to_memory(destination, &values[*values_idx], *bit_size)?; + *values_idx += 1; + Ok(MemoryAddress(destination.to_usize() + 1)) + } + HeapValueType::Array { value_types, size } => { + for _ in 0..*size { + for typ in value_types { + match typ { + HeapValueType::Simple(len) => { + self.write_value_to_memory( + current_pointer, + &values[*values_idx], + *len, + )?; + *values_idx += 1; + current_pointer = MemoryAddress(current_pointer.to_usize() + 1); + } + HeapValueType::Array { .. } => { + let destination = self.memory.read_ref(current_pointer); + let destination = self.memory.read_ref(destination); + self.write_slice_of_values_to_memory( + destination, + values, + values_idx, + typ, + )?; + current_pointer = MemoryAddress(current_pointer.to_usize() + 1); + } + HeapValueType::Vector { .. } => { + return Err(format!( + "Unsupported returned type in foreign calls {:?}", + typ + )); + } + } + } + } + Ok(current_pointer) + } + HeapValueType::Vector { .. } => { + Err(format!("Unsupported returned type in foreign calls {:?}", value_type)) + } + } + } + /// Process a binary operation. /// This method will not modify the program counter. fn process_binary_field_op( diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index df7112d437d8..b441e8be3eb7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -1737,8 +1737,7 @@ impl<'block> BrilligBlock<'block> { dfg, ); let array = variable.extract_array(); - self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size); - self.brillig_context.usize_const_instruction(array.rc, 1_usize.into()); + self.allocate_nested_array(typ, Some(array)); variable } @@ -1765,6 +1764,43 @@ impl<'block> BrilligBlock<'block> { } } + fn allocate_nested_array( + &mut self, + typ: &Type, + array: Option, + ) -> BrilligVariable { + match typ { + Type::Array(types, size) => { + let array = array.unwrap_or(BrilligArray { + pointer: self.brillig_context.allocate_register(), + size: *size, + rc: self.brillig_context.allocate_register(), + }); + self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size); + self.brillig_context.usize_const_instruction(array.rc, 1_usize.into()); + + let mut index = 0_usize; + for _ in 0..*size { + for element_type in types.iter() { + match element_type { + Type::Array(_, _) => { + let inner_array = self.allocate_nested_array(element_type, None); + let idx = + self.brillig_context.make_usize_constant_instruction(index.into()); + self.store_variable_in_array(array.pointer, idx, inner_array); + } + Type::Slice(_) => unreachable!("ICE: unsupported slice type in allocate_nested_array(), expects an array or a numeric type"), + _ => (), + } + index += 1; + } + } + BrilligVariable::BrilligArray(array) + } + _ => unreachable!("ICE: allocate_nested_array() expects an array, got {typ:?}"), + } + } + /// Gets the "user-facing" length of an array. /// An array of structs with two fields would be stored as an 2 * array.len() array/vector. /// So we divide the length by the number of subitems in an item to get the user-facing length. diff --git a/noir/noir-repo/test_programs/noir_test_success/regression_4561/Nargo.toml b/noir/noir-repo/test_programs/noir_test_success/regression_4561/Nargo.toml new file mode 100644 index 000000000000..90deee74640c --- /dev/null +++ b/noir/noir-repo/test_programs/noir_test_success/regression_4561/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "regression_4561" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/noir_test_success/regression_4561/src/main.nr b/noir/noir-repo/test_programs/noir_test_success/regression_4561/src/main.nr new file mode 100644 index 000000000000..ad40941ff51c --- /dev/null +++ b/noir/noir-repo/test_programs/noir_test_success/regression_4561/src/main.nr @@ -0,0 +1,78 @@ +// Regression test for issue #4561 +use std::test::OracleMock; + +type TReturnElem = [Field; 3]; +type TReturn = [TReturnElem; 2]; + +#[oracle(simple_nested_return)] +unconstrained fn simple_nested_return_oracle() -> TReturn {} + +unconstrained fn simple_nested_return_unconstrained() -> TReturn { + simple_nested_return_oracle() +} + +#[test] +fn test_simple_nested_return() { + OracleMock::mock("simple_nested_return").returns([1, 2, 3, 4, 5, 6]); + assert_eq(simple_nested_return_unconstrained(), [[1, 2, 3], [4, 5, 6]]); +} + +#[oracle(nested_with_fields_return)] +unconstrained fn nested_with_fields_return_oracle() -> (Field, TReturn, Field) {} + +unconstrained fn nested_with_fields_return_unconstrained() -> (Field, TReturn, Field) { + nested_with_fields_return_oracle() +} + +#[test] +fn test_nested_with_fields_return() { + OracleMock::mock("nested_with_fields_return").returns((0, [1, 2, 3, 4, 5, 6], 7)); + assert_eq(nested_with_fields_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7)); +} + +#[oracle(two_nested_return)] +unconstrained fn two_nested_return_oracle() -> (Field, TReturn, Field, TReturn) {} + +unconstrained fn two_nested_return_unconstrained() -> (Field, TReturn, Field, TReturn) { + two_nested_return_oracle() +} + +#[test] +fn two_nested_return() { + OracleMock::mock("two_nested_return").returns((0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6])); + assert_eq(two_nested_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7, [[1, 2, 3], [4, 5, 6]])); +} + +#[oracle(foo_return)] +unconstrained fn foo_return() -> (Field, TReturn, TestTypeFoo) {} +unconstrained fn foo_return_unconstrained() -> (Field, TReturn, TestTypeFoo) { + foo_return() +} + +struct TestTypeFoo { + a: Field, + b: [[[Field; 3]; 4]; 2], + c: [TReturnElem; 2], + d: TReturnElem, +} + +#[test] +fn complexe_struct_return() { + OracleMock::mock("foo_return").returns( + ( + 0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6] + ) + ); + let foo_x = foo_return_unconstrained(); + assert_eq((foo_x.0, foo_x.1), (0, [[1, 2, 3], [4, 5, 6]])); + assert_eq(foo_x.2.a, 7); + assert_eq( + foo_x.2.b, [ + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]] + ] + ); + let a: TReturnElem = [1, 2, 3]; + let b: TReturnElem = [4, 5, 6]; + assert_eq(foo_x.2.c, [a, b]); + assert_eq(foo_x.2.d, a); +}