Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for nested arrays on brillig gen #2029

Merged
merged 6 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.6.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "0"
y = "1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
struct Header {
params: [Field; 3],
}

struct MyNote {
plain: Field,
array: [Field; 2],
header: Header,
}

unconstrained fn access_nested(notes: [MyNote; 2], x: Field, y: Field) -> Field {
notes[x].array[y] + notes[y].array[x] + notes[x].plain + notes[y].header.params[x]
}

unconstrained fn create_inside_brillig(x: Field, y: Field) {
let header = Header { params: [1, 2, 3]};
let note0 = MyNote { array: [1, 2], plain : 3, header };
let note1 = MyNote { array: [4, 5], plain : 6, header };
assert(access_nested([note0, note1], x, y) == (2 + 4 + 3 + 1));
}

fn main(x: Field, y: Field) {
let header = Header { params: [1, 2, 3]};
let note0 = MyNote { array: [1, 2], plain : 3, header };
let note1 = MyNote { array: [4, 5], plain : 6, header };

create_inside_brillig(x, y);
assert(access_nested([note0, note1], x, y) == (2 + 4 + 3 + 1));
}

6 changes: 1 addition & 5 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ pub(crate) fn convert_ssa_function(func: &Function, enable_debug_trace: bool) ->
let mut function_context =
FunctionContext { function_id: func.id(), ssa_value_to_brillig_variable: HashMap::new() };

let mut brillig_context = BrilligContext::new(
FunctionContext::parameters(func),
FunctionContext::return_values(func),
enable_debug_trace,
);
let mut brillig_context = BrilligContext::new(enable_debug_trace);

brillig_context.enter_context(FunctionContext::function_id_to_function_label(func.id()));
for block in reverse_post_order {
Expand Down
124 changes: 68 additions & 56 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use acvm::FieldElement;
use iter_extended::vecmap;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_fn::{compute_size_of_composite_type, FunctionContext};
use super::brillig_fn::FunctionContext;
use super::brillig_slice_ops::{
slice_insert_operation, slice_pop_back_operation, slice_pop_front_operation,
slice_push_front_operation, slice_remove_operation,
Expand Down Expand Up @@ -397,11 +397,8 @@ impl<'block> BrilligBlock<'block> {
}
Instruction::ArrayGet { array, index } => {
let result_ids = dfg.instruction_results(instruction_id);
let destination_register = self.function_context.create_register_variable(
self.brillig_context,
result_ids[0],
dfg,
);
let destination_variable =
self.function_context.create_variable(self.brillig_context, result_ids[0], dfg);

let array_variable = self.convert_ssa_value(*array, dfg);
let array_pointer = match array_variable {
Expand All @@ -411,12 +408,12 @@ impl<'block> BrilligBlock<'block> {
};

let index_register = self.convert_ssa_register_value(*index, dfg);
self.brillig_context.array_get(array_pointer, index_register, destination_register);
self.convert_ssa_array_get(array_pointer, index_register, destination_variable);
}
Instruction::ArraySet { array, index, value } => {
let source_variable = self.convert_ssa_value(*array, dfg);
let index_register = self.convert_ssa_register_value(*index, dfg);
let value_register = self.convert_ssa_register_value(*value, dfg);
let value_variable = self.convert_ssa_value(*value, dfg);

let result_ids = dfg.instruction_results(instruction_id);
let destination_variable =
Expand All @@ -426,7 +423,7 @@ impl<'block> BrilligBlock<'block> {
source_variable,
destination_variable,
index_register,
value_register,
value_variable,
);
}
_ => todo!("ICE: Instruction not supported {instruction:?}"),
Expand Down Expand Up @@ -486,14 +483,31 @@ impl<'block> BrilligBlock<'block> {
.post_call_prep_returns_load_registers(&returned_registers, &saved_registers);
}

fn convert_ssa_array_get(
&mut self,
array_pointer: RegisterIndex,
index_register: RegisterIndex,
destination_variable: RegisterOrMemory,
) {
match destination_variable {
RegisterOrMemory::RegisterIndex(destination_register) => {
self.brillig_context.array_get(array_pointer, index_register, destination_register);
}
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
self.brillig_context.array_get(array_pointer, index_register, pointer);
}
RegisterOrMemory::HeapVector(_) => unimplemented!("ICE: Array get for vector"),
}
}

/// Array set operation in SSA returns a new array or slice that is a copy of the parameter array or slice
/// With a specific value changed.
fn convert_ssa_array_set(
&mut self,
source_variable: RegisterOrMemory,
destination_variable: RegisterOrMemory,
index_register: RegisterIndex,
value_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
let destination_pointer = match destination_variable {
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => pointer,
Expand Down Expand Up @@ -532,11 +546,30 @@ impl<'block> BrilligBlock<'block> {
}

// Then set the value in the newly created array
self.brillig_context.array_set(destination_pointer, index_register, value_register);
self.store_variable_in_array(destination_pointer, index_register, value_variable);

self.brillig_context.deallocate_register(source_size_as_register);
}

fn store_variable_in_array(
&mut self,
destination_pointer: RegisterIndex,
index_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
match value_variable {
RegisterOrMemory::RegisterIndex(value_register) => {
self.brillig_context.array_set(destination_pointer, index_register, value_register);
}
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
self.brillig_context.array_set(destination_pointer, index_register, pointer);
}
RegisterOrMemory::HeapVector(_) => {
unimplemented!("ICE: cannot store a vector in array")
}
}
}

/// Convert the SSA slice operations to brillig slice operations
fn convert_ssa_slice_intrinsic_call(
&mut self,
Expand Down Expand Up @@ -661,47 +694,6 @@ impl<'block> BrilligBlock<'block> {
}
}

/// This function allows storing a Value in memory starting at the address specified by the
/// address_register. The value can be a single value or an array. The function will recursively
/// store the value in memory.
fn store_in_memory(
&mut self,
address_register: RegisterIndex,
value_id: ValueId,
dfg: &DataFlowGraph,
) {
let value = &dfg[value_id];
match value {
Value::Param { .. } | Value::Instruction { .. } | Value::NumericConstant { .. } => {
let value_register = self.convert_ssa_register_value(value_id, dfg);
self.brillig_context.store_instruction(address_register, value_register);
}
Value::Array { array, element_type } => {
// Allocate a register for the iterator
let iterator_register = self.brillig_context.allocate_register();
// Set the iterator to the address of the array
self.brillig_context.mov_instruction(iterator_register, address_register);

let size_of_item_register = self
.brillig_context
.make_constant(compute_size_of_composite_type(element_type).into());

for element_id in array.iter() {
// Store the item in memory
self.store_in_memory(iterator_register, *element_id, dfg);
// Increment the iterator by the size of the items
self.brillig_context.memory_op(
iterator_register,
size_of_item_register,
iterator_register,
BinaryIntOp::Add,
);
}
}
_ => unimplemented!("ICE: Value {:?} not storeable in memory", value),
}
}

/// Converts an SSA cast to a sequence of Brillig opcodes.
/// Casting is only necessary when shrinking the bit size of a numeric value.
fn convert_cast(
Expand Down Expand Up @@ -802,14 +794,34 @@ impl<'block> BrilligBlock<'block> {
self.brillig_context.const_instruction(register_index, (*constant).into());
new_variable
}
Value::Array { .. } => {
let new_variable =
self.function_context.create_variable(self.brillig_context, value_id, dfg);
Value::Array { array, .. } => {
let new_variable = self.function_context.get_or_create_variable(
self.brillig_context,
value_id,
dfg,
);
let heap_array = self.function_context.extract_heap_array(new_variable);

self.brillig_context
.allocate_fixed_length_array(heap_array.pointer, heap_array.size);
self.store_in_memory(heap_array.pointer, value_id, dfg);

// Allocate a register for the iterator
let iterator_register = self.brillig_context.make_constant(0_usize.into());

for element_id in array.iter() {
let element_variable = self.convert_ssa_value(*element_id, dfg);
// Store the item in memory
self.store_variable_in_array(
heap_array.pointer,
iterator_register,
element_variable,
);
// Increment the iterator
self.brillig_context.usize_op_in_place(iterator_register, BinaryIntOp::Add, 1);
}

self.brillig_context.deallocate_register(iterator_register);

new_variable
}
_ => {
Expand Down
53 changes: 26 additions & 27 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use acvm::brillig_vm::brillig::{HeapArray, HeapVector, RegisterIndex, RegisterOrMemory};
use iter_extended::vecmap;

use crate::{
brillig::brillig_ir::{
Expand Down Expand Up @@ -37,9 +38,9 @@ impl FunctionContext {
let register = brillig_context.allocate_register();
RegisterOrMemory::RegisterIndex(register)
}
Type::Array(_, _) => {
Type::Array(item_typ, elem_count) => {
let pointer_register = brillig_context.allocate_register();
let size = compute_size_of_type(&typ);
let size = compute_array_length(&item_typ, elem_count);
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_register, size })
}
Type::Slice(_) => {
Expand Down Expand Up @@ -139,18 +140,31 @@ impl FunctionContext {
function_id.to_string()
}

fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter {
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Simple,
Type::Array(item_type, size) => BrilligParameter::Array(
vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
}),
*size,
),
Type::Slice(item_type) => {
BrilligParameter::Slice(vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
}))
}
_ => unimplemented!("Unsupported function parameter/return type {typ:?}"),
}
}

/// Collects the parameters of a given function
pub(crate) fn parameters(func: &Function) -> Vec<BrilligParameter> {
func.parameters()
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
Type::Slice(_) => BrilligParameter::HeapVector,
_ => unimplemented!("Unsupported function parameter type {typ:?}"),
}
FunctionContext::ssa_type_to_parameter(&typ)
})
.collect()
}
Expand All @@ -161,28 +175,13 @@ impl FunctionContext {
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
Type::Slice(_) => BrilligParameter::HeapVector,
_ => unimplemented!("Unsupported return value type {typ:?}"),
}
FunctionContext::ssa_type_to_parameter(&typ)
})
.collect()
}
}

/// Computes the size of an SSA composite type
pub(crate) fn compute_size_of_composite_type(typ: &CompositeType) -> usize {
typ.iter().map(compute_size_of_type).sum()
}

/// Finds out the size of a given SSA type
/// This is needed to store values in memory
pub(crate) fn compute_size_of_type(typ: &Type) -> usize {
match typ {
Type::Numeric(_) => 1,
Type::Array(types, item_count) => compute_size_of_composite_type(types) * item_count,
_ => todo!("ICE: Type not supported {typ:?}"),
}
/// Computes the length of an array. This will match with the indexes that SSA will issue
pub(crate) fn compute_array_length(item_typ: &CompositeType, elem_count: usize) -> usize {
item_typ.len() * elem_count
}
Loading