diff --git a/crates/noirc_frontend/src/hir/def_map/aztec_library.rs b/crates/noirc_frontend/src/hir/def_map/aztec_library.rs index 12aa60276ce..38a45bc16c4 100644 --- a/crates/noirc_frontend/src/hir/def_map/aztec_library.rs +++ b/crates/noirc_frontend/src/hir/def_map/aztec_library.rs @@ -5,8 +5,9 @@ use crate::graph::CrateId; use crate::{ hir::Context, token::Attribute, BlockExpression, CallExpression, CastExpression, Distinctness, Expression, ExpressionKind, ForExpression, FunctionReturnType, Ident, ImportStatement, - IndexExpression, LetStatement, Literal, MethodCallExpression, NoirFunction, ParsedModule, Path, - PathKind, Pattern, Statement, UnresolvedType, UnresolvedTypeData, Visibility, + IndexExpression, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, + NoirFunction, ParsedModule, Path, PathKind, Pattern, Statement, UnresolvedType, + UnresolvedTypeData, Visibility, }; use noirc_errors::FileDiagnostic; @@ -33,6 +34,10 @@ fn variable(name: &str) -> Expression { expression(ExpressionKind::Variable(ident_path(name))) } +fn variable_ident(identifier: Ident) -> Expression { + expression(ExpressionKind::Variable(path(identifier))) +} + fn variable_path(path: Path) -> Expression { expression(ExpressionKind::Variable(path)) } @@ -61,6 +66,13 @@ fn mutable_assignment(name: &str, assigned_to: Expression) -> Statement { }) } +fn member_access(lhs: &str, rhs: &str) -> Expression { + expression(ExpressionKind::MemberAccess(Box::new(MemberAccessExpression { + lhs: variable(lhs), + rhs: ident(rhs), + }))) +} + macro_rules! chained_path { ( $base:expr $(, $tail:expr)* ) => { { @@ -101,6 +113,13 @@ fn index_array(array: Ident, index: &str) -> Expression { }))) } +fn index_array_variable(array: Expression, index: &str) -> Expression { + expression(ExpressionKind::Index(Box::new(IndexExpression { + collection: array, + index: variable(index), + }))) +} + fn import(path: Path) -> ImportStatement { ImportStatement { path, alias: None } } @@ -203,6 +222,11 @@ fn transform_function(ty: &str, func: &mut NoirFunction) { let input = create_inputs(&inputs_name); func.def.parameters.insert(0, input); + // Abstract return types such that they get added to the kernel's return_values + if let Some(return_values) = abstract_return_values(func) { + func.def.body.0.push(return_values); + } + // Push the finish method call to the end of the function let finish_def = create_context_finish(); func.def.body.0.push(finish_def); @@ -332,6 +356,124 @@ fn create_context(ty: &str, params: &[(Pattern, UnresolvedType, Visibility)]) -> injected_expressions } +/// Abstract Return Type +/// +/// This function intercepts the function's current return type and replaces it with pushes +/// To the kernel +/// +/// The replaced code: +/// ```noir +/// /// Before +/// #[aztec(private)] +/// fn foo() -> abi::PrivateCircuitPublicInputs { +/// // ... +/// let my_return_value: Field = 10; +/// context.return_values.push(my_return_value); +/// } +/// +/// /// After +/// #[aztec(private)] +/// fn foo() -> Field { +/// // ... +/// let my_return_value: Field = 10; +/// my_return_value +/// } +/// ``` +/// Similarly; Structs will be pushed to the context, after serialize() is called on them. +/// Arrays will be iterated over and each element will be pushed to the context. +/// Any primitive type that can be cast will be casted to a field and pushed to the context. +fn abstract_return_values(func: &mut NoirFunction) -> Option { + let current_return_type = func.return_type().typ; + let len = func.def.body.len(); + let last_statement = &func.def.body.0[len - 1]; + + // TODO: (length, type) => We can limit the size of the array returned to be limited by kernel size + // Doesnt need done until we have settled on a kernel size + // TODO: support tuples here and in inputs -> convert into an issue + + // Check if the return type is an expression, if it is, we can handle it + match last_statement { + Statement::Expression(expression) => match current_return_type { + // Call serialize on structs, push the whole array, calling push_array + UnresolvedTypeData::Named(..) => Some(make_struct_return_type(expression.clone())), + UnresolvedTypeData::Array(..) => Some(make_array_return_type(expression.clone())), + // Cast these types to a field before pushing + UnresolvedTypeData::Bool | UnresolvedTypeData::Integer(..) => { + Some(make_castable_return_type(expression.clone())) + } + UnresolvedTypeData::FieldElement => Some(make_return_push(expression.clone())), + _ => None, + }, + _ => None, + } +} + +/// Context Return Values +/// +/// Creates an instance to the context return values +/// ```noir +/// `context.return_values` +/// ``` +fn context_return_values() -> Expression { + member_access("context", "return_values") +} + +/// Make return Push +/// +/// Translates to: +/// `context.return_values.push({push_value})` +fn make_return_push(push_value: Expression) -> Statement { + Statement::Semi(method_call(context_return_values(), "push", vec![push_value])) +} + +/// Make Return push array +/// +/// Translates to: +/// `context.return_values.push_array({push_value})` +fn make_return_push_array(push_value: Expression) -> Statement { + Statement::Semi(method_call(context_return_values(), "push_array", vec![push_value])) +} + +/// Make struct return type +/// +/// Translates to: +/// ```noir +/// `context.return_values.push_array({push_value}.serialize())` +fn make_struct_return_type(expression: Expression) -> Statement { + let serialised_call = method_call( + expression.clone(), // variable + "serialize", // method name + vec![], // args + ); + make_return_push_array(serialised_call) +} + +/// Make array return type +/// +/// Translates to: +/// ```noir +/// for i in 0..{ident}.len() { +/// context.return_values.push({ident}[i] as Field) +/// } +/// ``` +fn make_array_return_type(expression: Expression) -> Statement { + let inner_cast_expression = + cast(index_array_variable(expression.clone(), "i"), UnresolvedTypeData::FieldElement); + create_loop_over(expression.clone(), vec![inner_cast_expression]) +} + +/// Castable return type +/// +/// Translates to: +/// ```noir +/// context.return_values.push({ident} as Field) +/// ``` +fn make_castable_return_type(expression: Expression) -> Statement { + // Cast these types to a field before pushing + let cast_expression = cast(expression.clone(), UnresolvedTypeData::FieldElement); + make_return_push(cast_expression) +} + /// Create Return Type /// /// Public functions return abi::PublicCircuitPublicInputs while @@ -407,30 +549,24 @@ fn add_struct_to_hasher(identifier: &Ident) -> Statement { )) } -fn add_array_to_hasher(identifier: &Ident) -> Statement { +fn create_loop_over(var: Expression, loop_body: Vec) -> Statement { // If this is an array of primitive types (integers / fields) we can add them each to the hasher // casted to a field // `array.len()` let end_range_expression = method_call( - variable_path(path(identifier.clone())), // variable - "len", // method name - vec![], // args + var.clone(), // variable + "len", // method name + vec![], // args ); - // Wrap in the semi thing - does that mean ended with semi colon? - // `hasher.add({ident}[i] as Field)` - let cast_expression = cast( - index_array(identifier.clone(), "i"), // lhs - `ident[i]` - UnresolvedTypeData::FieldElement, // cast to - `as Field` - ); // What will be looped over // - `hasher.add({ident}[i] as Field)` let for_loop_block = expression(ExpressionKind::Block(BlockExpression(vec![Statement::Semi(method_call( variable("hasher"), // variable "add", // method name - vec![cast_expression], + loop_body, ))]))); // `for i in 0..{ident}.len()` @@ -444,6 +580,19 @@ fn add_array_to_hasher(identifier: &Ident) -> Statement { })))) } +fn add_array_to_hasher(identifier: &Ident) -> Statement { + // If this is an array of primitive types (integers / fields) we can add them each to the hasher + // casted to a field + + // Wrap in the semi thing - does that mean ended with semi colon? + // `hasher.add({ident}[i] as Field)` + let cast_expression = cast( + index_array(identifier.clone(), "i"), // lhs - `ident[i]` + UnresolvedTypeData::FieldElement, // cast to - `as Field` + ); + create_loop_over(variable_ident(identifier.clone()), vec![cast_expression]) +} + fn add_field_to_hasher(identifier: &Ident) -> Statement { // `hasher.add({ident})` let iden = variable_path(path(identifier.clone()));