From 47b372c1762ed1184bf2ed9b90d7dc3e2c161880 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 2 Aug 2023 17:17:29 +0100 Subject: [PATCH 1/9] feat: Optimize away constant calls to black box functions (#1981) * feat: optimize away constant calls to black box functions * chore: remove `use SimplifyResult::*` * chore: remove unnecessary match arms * Update crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs * Update crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs --------- Co-authored-by: jfecher --- .../src/ssa_refactor/ir/instruction.rs | 156 +------- .../src/ssa_refactor/ir/instruction/call.rs | 334 ++++++++++++++++++ 2 files changed, 338 insertions(+), 152 deletions(-) create mode 100644 crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index afb47d423e2..7edb74f4206 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use acvm::{acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; use num_bigint::BigUint; @@ -14,6 +12,10 @@ use super::{ value::{Value, ValueId}, }; +mod call; + +use call::simplify_call; + /// Reference to an instruction /// /// Note that InstructionIds are not unique. That is, two InstructionIds @@ -385,156 +387,6 @@ fn simplify_cast(value: ValueId, dst_typ: &Type, dfg: &mut DataFlowGraph) -> Sim } } -/// Try to simplify this call instruction. If the instruction can be simplified to a known value, -/// that value is returned. Otherwise None is returned. -fn simplify_call(func: ValueId, arguments: &[ValueId], dfg: &mut DataFlowGraph) -> SimplifyResult { - use SimplifyResult::*; - let intrinsic = match &dfg[func] { - Value::Intrinsic(intrinsic) => *intrinsic, - _ => return None, - }; - - let constant_args: Option> = - arguments.iter().map(|value_id| dfg.get_numeric_constant(*value_id)).collect(); - - match intrinsic { - Intrinsic::ToBits(endian) => { - if let Some(constant_args) = constant_args { - let field = constant_args[0]; - let limb_count = constant_args[1].to_u128() as u32; - SimplifiedTo(constant_to_radix(endian, field, 2, limb_count, dfg)) - } else { - None - } - } - Intrinsic::ToRadix(endian) => { - if let Some(constant_args) = constant_args { - let field = constant_args[0]; - let radix = constant_args[1].to_u128() as u32; - let limb_count = constant_args[2].to_u128() as u32; - SimplifiedTo(constant_to_radix(endian, field, radix, limb_count, dfg)) - } else { - None - } - } - Intrinsic::ArrayLen => { - let slice = dfg.get_array_constant(arguments[0]); - if let Some((slice, _)) = slice { - SimplifiedTo(dfg.make_constant((slice.len() as u128).into(), Type::field())) - } else if let Some(length) = dfg.try_get_array_length(arguments[0]) { - SimplifiedTo(dfg.make_constant((length as u128).into(), Type::field())) - } else { - None - } - } - Intrinsic::SlicePushBack => { - let slice = dfg.get_array_constant(arguments[0]); - if let (Some((mut slice, element_type)), elem) = (slice, arguments[1]) { - slice.push_back(elem); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedTo(new_slice) - } else { - None - } - } - Intrinsic::SlicePushFront => { - let slice = dfg.get_array_constant(arguments[0]); - if let (Some((mut slice, element_type)), elem) = (slice, arguments[1]) { - slice.push_front(elem); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedTo(new_slice) - } else { - None - } - } - Intrinsic::SlicePopBack => { - let slice = dfg.get_array_constant(arguments[0]); - if let Some((mut slice, element_type)) = slice { - let elem = - slice.pop_back().expect("There are no elements in this slice to be removed"); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedToMultiple(vec![new_slice, elem]) - } else { - None - } - } - Intrinsic::SlicePopFront => { - let slice = dfg.get_array_constant(arguments[0]); - if let Some((mut slice, element_type)) = slice { - let elem = - slice.pop_front().expect("There are no elements in this slice to be removed"); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedToMultiple(vec![elem, new_slice]) - } else { - None - } - } - Intrinsic::SliceInsert => { - let slice = dfg.get_array_constant(arguments[0]); - let index = dfg.get_numeric_constant(arguments[1]); - if let (Some((mut slice, element_type)), Some(index), value) = - (slice, index, arguments[2]) - { - slice.insert(index.to_u128() as usize, value); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedTo(new_slice) - } else { - None - } - } - Intrinsic::SliceRemove => { - let slice = dfg.get_array_constant(arguments[0]); - let index = dfg.get_numeric_constant(arguments[1]); - if let (Some((mut slice, element_type)), Some(index)) = (slice, index) { - let removed_elem = slice.remove(index.to_u128() as usize); - let new_slice = dfg.make_array(slice, element_type); - SimplifiedToMultiple(vec![new_slice, removed_elem]) - } else { - None - } - } - Intrinsic::BlackBox(_) | Intrinsic::Println | Intrinsic::Sort => None, - } -} - -/// Returns a Value::Array of constants corresponding to the limbs of the radix decomposition. -fn constant_to_radix( - endian: Endian, - field: FieldElement, - radix: u32, - limb_count: u32, - dfg: &mut DataFlowGraph, -) -> ValueId { - let bit_size = u32::BITS - (radix - 1).leading_zeros(); - let radix_big = BigUint::from(radix); - assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2"); - let big_integer = BigUint::from_bytes_be(&field.to_be_bytes()); - - // Decompose the integer into its radix digits in little endian form. - let decomposed_integer = big_integer.to_radix_le(radix); - let mut limbs = vecmap(0..limb_count, |i| match decomposed_integer.get(i as usize) { - Some(digit) => FieldElement::from_be_bytes_reduce(&[*digit]), - None => FieldElement::zero(), - }); - if endian == Endian::Big { - limbs.reverse(); - } - - // For legacy reasons (see #617) the to_radix interface supports 256 bits even though - // FieldElement::max_num_bits() is only 254 bits. Any limbs beyond the specified count - // become zero padding. - let max_decomposable_bits: u32 = 256; - let limb_count_with_padding = max_decomposable_bits / bit_size; - while limbs.len() < limb_count_with_padding as usize { - limbs.push(FieldElement::zero()); - } - let result_constants: im::Vector = - limbs.into_iter().map(|limb| dfg.make_constant(limb, Type::unsigned(bit_size))).collect(); - - let typ = Type::Array(Rc::new(vec![Type::unsigned(bit_size)]), result_constants.len()); - dfg.make_array(result_constants, typ) -} - /// The possible return values for Instruction::return_types pub(crate) enum InstructionResultType { /// The result type of this instruction matches that of this operand diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs new file mode 100644 index 00000000000..96998d92fcf --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs @@ -0,0 +1,334 @@ +use std::rc::Rc; + +use acvm::{acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; +use iter_extended::vecmap; +use num_bigint::BigUint; + +use crate::ssa_refactor::ir::{ + dfg::DataFlowGraph, + instruction::Intrinsic, + map::Id, + types::Type, + value::{Value, ValueId}, +}; + +use super::{Endian, SimplifyResult}; + +/// Try to simplify this call instruction. If the instruction can be simplified to a known value, +/// that value is returned. Otherwise None is returned. +pub(super) fn simplify_call( + func: ValueId, + arguments: &[ValueId], + dfg: &mut DataFlowGraph, +) -> SimplifyResult { + let intrinsic = match &dfg[func] { + Value::Intrinsic(intrinsic) => *intrinsic, + _ => return SimplifyResult::None, + }; + + let constant_args: Option> = + arguments.iter().map(|value_id| dfg.get_numeric_constant(*value_id)).collect(); + + match intrinsic { + Intrinsic::ToBits(endian) => { + if let Some(constant_args) = constant_args { + let field = constant_args[0]; + let limb_count = constant_args[1].to_u128() as u32; + SimplifyResult::SimplifiedTo(constant_to_radix(endian, field, 2, limb_count, dfg)) + } else { + SimplifyResult::None + } + } + Intrinsic::ToRadix(endian) => { + if let Some(constant_args) = constant_args { + let field = constant_args[0]; + let radix = constant_args[1].to_u128() as u32; + let limb_count = constant_args[2].to_u128() as u32; + SimplifyResult::SimplifiedTo(constant_to_radix( + endian, field, radix, limb_count, dfg, + )) + } else { + SimplifyResult::None + } + } + Intrinsic::ArrayLen => { + let slice = dfg.get_array_constant(arguments[0]); + if let Some((slice, _)) = slice { + SimplifyResult::SimplifiedTo( + dfg.make_constant((slice.len() as u128).into(), Type::field()), + ) + } else if let Some(length) = dfg.try_get_array_length(arguments[0]) { + SimplifyResult::SimplifiedTo( + dfg.make_constant((length as u128).into(), Type::field()), + ) + } else { + SimplifyResult::None + } + } + Intrinsic::SlicePushBack => { + let slice = dfg.get_array_constant(arguments[0]); + if let (Some((mut slice, element_type)), elem) = (slice, arguments[1]) { + slice.push_back(elem); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedTo(new_slice) + } else { + SimplifyResult::None + } + } + Intrinsic::SlicePushFront => { + let slice = dfg.get_array_constant(arguments[0]); + if let (Some((mut slice, element_type)), elem) = (slice, arguments[1]) { + slice.push_front(elem); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedTo(new_slice) + } else { + SimplifyResult::None + } + } + Intrinsic::SlicePopBack => { + let slice = dfg.get_array_constant(arguments[0]); + if let Some((mut slice, element_type)) = slice { + let elem = + slice.pop_back().expect("There are no elements in this slice to be removed"); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedToMultiple(vec![new_slice, elem]) + } else { + SimplifyResult::None + } + } + Intrinsic::SlicePopFront => { + let slice = dfg.get_array_constant(arguments[0]); + if let Some((mut slice, element_type)) = slice { + let elem = + slice.pop_front().expect("There are no elements in this slice to be removed"); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedToMultiple(vec![elem, new_slice]) + } else { + SimplifyResult::None + } + } + Intrinsic::SliceInsert => { + let slice = dfg.get_array_constant(arguments[0]); + let index = dfg.get_numeric_constant(arguments[1]); + if let (Some((mut slice, element_type)), Some(index), value) = + (slice, index, arguments[2]) + { + slice.insert(index.to_u128() as usize, value); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedTo(new_slice) + } else { + SimplifyResult::None + } + } + Intrinsic::SliceRemove => { + let slice = dfg.get_array_constant(arguments[0]); + let index = dfg.get_numeric_constant(arguments[1]); + if let (Some((mut slice, element_type)), Some(index)) = (slice, index) { + let removed_elem = slice.remove(index.to_u128() as usize); + let new_slice = dfg.make_array(slice, element_type); + SimplifyResult::SimplifiedToMultiple(vec![new_slice, removed_elem]) + } else { + SimplifyResult::None + } + } + Intrinsic::BlackBox(bb_func) => simplify_black_box_func(bb_func, arguments, dfg), + Intrinsic::Println | Intrinsic::Sort => SimplifyResult::None, + } +} + +/// Try to simplify this black box call. If the call can be simplified to a known value, +/// that value is returned. Otherwise [`SimplifyResult::None`] is returned. +fn simplify_black_box_func( + bb_func: BlackBoxFunc, + arguments: &[ValueId], + dfg: &mut DataFlowGraph, +) -> SimplifyResult { + match bb_func { + BlackBoxFunc::SHA256 => simplify_hash(dfg, arguments, acvm::blackbox_solver::sha256), + BlackBoxFunc::Blake2s => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake2s), + BlackBoxFunc::Keccak256 => { + match (dfg.get_array_constant(arguments[0]), dfg.get_numeric_constant(arguments[1])) { + (Some((input, _)), Some(num_bytes)) if array_is_constant(dfg, &input) => { + let input_bytes: Vec = to_u8_vec(dfg, input); + + let num_bytes = num_bytes.to_u128() as usize; + let truncated_input_bytes = &input_bytes[0..num_bytes]; + let hash = acvm::blackbox_solver::keccak256(truncated_input_bytes) + .expect("Rust solvable black box function should not fail"); + + let hash_values = + vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + + let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } + } + BlackBoxFunc::HashToField128Security => match dfg.get_array_constant(arguments[0]) { + Some((input, _)) if array_is_constant(dfg, &input) => { + let input_bytes: Vec = to_u8_vec(dfg, input); + + let field = acvm::blackbox_solver::hash_to_field_128_security(&input_bytes) + .expect("Rust solvable black box function should not fail"); + + let field_constant = dfg.make_constant(field, Type::field()); + SimplifyResult::SimplifiedTo(field_constant) + } + _ => SimplifyResult::None, + }, + + BlackBoxFunc::EcdsaSecp256k1 => { + simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256k1_verify) + } + BlackBoxFunc::EcdsaSecp256r1 => { + simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256r1_verify) + } + + BlackBoxFunc::FixedBaseScalarMul | BlackBoxFunc::SchnorrVerify | BlackBoxFunc::Pedersen => { + // Currently unsolvable here as we rely on an implementation in the backend. + SimplifyResult::None + } + + BlackBoxFunc::RecursiveAggregation => SimplifyResult::None, + + BlackBoxFunc::AND => { + unreachable!("ICE: `BlackBoxFunc::AND` calls should be transformed into a `BinaryOp`") + } + BlackBoxFunc::XOR => { + unreachable!("ICE: `BlackBoxFunc::XOR` calls should be transformed into a `BinaryOp`") + } + BlackBoxFunc::RANGE => { + unreachable!( + "ICE: `BlackBoxFunc::RANGE` calls should be transformed into a `Instruction::Cast`" + ) + } + } +} + +fn make_constant_array(dfg: &mut DataFlowGraph, results: Vec, typ: Type) -> ValueId { + let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); + + let typ = Type::Array(Rc::new(vec![typ]), result_constants.len()); + dfg.make_array(result_constants.into(), typ) +} + +/// Returns a Value::Array of constants corresponding to the limbs of the radix decomposition. +fn constant_to_radix( + endian: Endian, + field: FieldElement, + radix: u32, + limb_count: u32, + dfg: &mut DataFlowGraph, +) -> ValueId { + let bit_size = u32::BITS - (radix - 1).leading_zeros(); + let radix_big = BigUint::from(radix); + assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2"); + let big_integer = BigUint::from_bytes_be(&field.to_be_bytes()); + + // Decompose the integer into its radix digits in little endian form. + let decomposed_integer = big_integer.to_radix_le(radix); + let mut limbs = vecmap(0..limb_count, |i| match decomposed_integer.get(i as usize) { + Some(digit) => FieldElement::from_be_bytes_reduce(&[*digit]), + None => FieldElement::zero(), + }); + if endian == Endian::Big { + limbs.reverse(); + } + + // For legacy reasons (see #617) the to_radix interface supports 256 bits even though + // FieldElement::max_num_bits() is only 254 bits. Any limbs beyond the specified count + // become zero padding. + let max_decomposable_bits: u32 = 256; + let limb_count_with_padding = max_decomposable_bits / bit_size; + while limbs.len() < limb_count_with_padding as usize { + limbs.push(FieldElement::zero()); + } + + make_constant_array(dfg, limbs, Type::unsigned(bit_size)) +} + +fn to_u8_vec(dfg: &DataFlowGraph, values: im::Vector>) -> Vec { + values + .iter() + .map(|id| { + let field = dfg + .get_numeric_constant(*id) + .expect("value id from array should point at constant"); + *field.to_be_bytes().last().unwrap() + }) + .collect() +} + +fn array_is_constant(dfg: &DataFlowGraph, values: &im::Vector>) -> bool { + values.iter().all(|value| dfg.get_numeric_constant(*value).is_some()) +} + +fn simplify_hash( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + hash_function: fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, +) -> SimplifyResult { + match dfg.get_array_constant(arguments[0]) { + Some((input, _)) if array_is_constant(dfg, &input) => { + let input_bytes: Vec = to_u8_vec(dfg, input); + + let hash = hash_function(&input_bytes) + .expect("Rust solvable black box function should not fail"); + + let hash_values = vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + + let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +type ECDSASignatureVerifier = fn( + hashed_msg: &[u8], + public_key_x: &[u8; 32], + public_key_y: &[u8; 32], + signature: &[u8; 64], +) -> Result; +fn simplify_signature( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + signature_verifier: ECDSASignatureVerifier, +) -> SimplifyResult { + match ( + dfg.get_array_constant(arguments[0]), + dfg.get_array_constant(arguments[1]), + dfg.get_array_constant(arguments[2]), + dfg.get_array_constant(arguments[3]), + ) { + ( + Some((public_key_x, _)), + Some((public_key_y, _)), + Some((signature, _)), + Some((hashed_message, _)), + ) if array_is_constant(dfg, &public_key_x) + && array_is_constant(dfg, &public_key_y) + && array_is_constant(dfg, &signature) + && array_is_constant(dfg, &hashed_message) => + { + let public_key_x: [u8; 32] = to_u8_vec(dfg, public_key_x) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let public_key_y: [u8; 32] = to_u8_vec(dfg, public_key_y) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let signature: [u8; 64] = + to_u8_vec(dfg, signature).try_into().expect("ECDSA signatures are 64 bytes"); + let hashed_message: Vec = to_u8_vec(dfg, hashed_message); + + let valid_signature = + signature_verifier(&hashed_message, &public_key_x, &public_key_y, &signature) + .expect("Rust solvable black box function should not fail"); + + let valid_signature = dfg.make_constant(valid_signature.into(), Type::bool()); + SimplifyResult::SimplifiedTo(valid_signature) + } + _ => SimplifyResult::None, + } +} From 1c21d0caf1e3b3a92266b4b8238f3e6e6c394d05 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 2 Aug 2023 17:21:35 +0100 Subject: [PATCH 2/9] fix(globals): Accurately filter literals for resolving globals (#2126) accurately filter literals for resolving globals --- .../tests/test_data/global_consts/src/main.nr | 7 +++++++ .../tests/test_data/strings/src/main.nr | 6 +++++- .../src/hir/def_collector/dc_crate.rs | 20 ++++++++++--------- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/crates/nargo_cli/tests/test_data/global_consts/src/main.nr b/crates/nargo_cli/tests/test_data/global_consts/src/main.nr index 9bcca2b8071..2ed6e4593dd 100644 --- a/crates/nargo_cli/tests/test_data/global_consts/src/main.nr +++ b/crates/nargo_cli/tests/test_data/global_consts/src/main.nr @@ -12,12 +12,19 @@ struct Dummy { y: [Field; foo::MAGIC_NUMBER] } +struct Test { + v: Field, +} +global VALS: [Test; 1] = [Test { v: 100 }]; +global NESTED = [VALS, VALS]; + fn main(a: [Field; M + N - N], b: [Field; 30 + N / 2], c : pub [Field; foo::MAGIC_NUMBER], d: [Field; foo::bar::N]) { let test_struct = Dummy { x: d, y: c }; for i in 0..foo::MAGIC_NUMBER { assert(c[i] == foo::MAGIC_NUMBER); assert(test_struct.y[i] == foo::MAGIC_NUMBER); + assert(test_struct.y[i] != NESTED[1][0].v); } assert(N != M); diff --git a/crates/nargo_cli/tests/test_data/strings/src/main.nr b/crates/nargo_cli/tests/test_data/strings/src/main.nr index bee2370201c..edf5fff55b4 100644 --- a/crates/nargo_cli/tests/test_data/strings/src/main.nr +++ b/crates/nargo_cli/tests/test_data/strings/src/main.nr @@ -1,10 +1,13 @@ use dep::std; +// Test global string literals +global HELLO_WORLD = "hello world"; + fn main(message : pub str<11>, y : Field, hex_as_string : str<4>, hex_as_field : Field) { let mut bad_message = "hello world"; assert(message == "hello world"); - bad_message = "helld world"; + assert(message == HELLO_WORLD); let x = 10; let z = x * 5; std::println(10); @@ -16,6 +19,7 @@ fn main(message : pub str<11>, y : Field, hex_as_string : str<4>, hex_as_field : assert(y == 5); // Change to y != 5 to see how the later print statements are not called std::println(array); + bad_message = "helld world"; std::println(bad_message); assert(message != bad_message); diff --git a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs index e974961a405..76fbea289be 100644 --- a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -13,7 +13,7 @@ use crate::hir::Context; use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TypeAliasId}; use crate::{ ExpressionKind, Generics, Ident, LetStatement, NoirFunction, NoirStruct, NoirTypeAlias, - ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, + ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, Literal, }; use fm::FileId; use iter_extended::vecmap; @@ -161,10 +161,10 @@ impl DefCollector { // // Additionally, we must resolve integer globals before structs since structs may refer to // the values of integer globals as numeric generics. - let (integer_globals, other_globals) = - filter_integer_globals(def_collector.collected_globals); + let (literal_globals, other_globals) = + filter_literal_globals(def_collector.collected_globals); - let mut file_global_ids = resolve_globals(context, integer_globals, crate_id, errors); + let mut file_global_ids = resolve_globals(context, literal_globals, crate_id, errors); resolve_type_aliases(context, def_collector.collected_type_aliases, crate_id, errors); @@ -274,13 +274,15 @@ where } /// Separate the globals Vec into two. The first element in the tuple will be the -/// integer literal globals, and the second will be all other globals. -fn filter_integer_globals( +/// literal globals, except for arrays, and the second will be all other globals. +/// We exclude array literals as they can contain complex types +fn filter_literal_globals( globals: Vec, ) -> (Vec, Vec) { - globals - .into_iter() - .partition(|global| matches!(&global.stmt_def.expression.kind, ExpressionKind::Literal(_))) + globals.into_iter().partition(|global| match &global.stmt_def.expression.kind { + ExpressionKind::Literal(literal) => !matches!(literal, Literal::Array(_)), + _ => false, + }) } fn resolve_globals( From 27ab78f3e298e94202b8dcc9ea44075a185a78e7 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 2 Aug 2023 19:15:45 +0100 Subject: [PATCH 3/9] chore: Use `--show-output` flag on execution rather than compilation (#2116) * move show-output to occur on execute rather than compilation * remove assert(false) from test * fix compile err * report compile errors in tests * aupdate failing constraint test * change comment and link issue --- crates/nargo/src/ops/execute.rs | 3 +- crates/nargo/src/ops/foreign_calls.rs | 5 ++- crates/nargo_cli/src/cli/execute_cmd.rs | 2 +- crates/nargo_cli/src/cli/test_cmd.rs | 9 +++-- .../tests/test_data/strings/src/main.nr | 20 ++++++++-- crates/noirc_driver/src/lib.rs | 8 ++-- crates/noirc_evaluator/src/ssa_refactor.rs | 6 +-- .../acir_gen/acir_ir/acir_variable.rs | 13 ------- .../src/ssa_refactor/acir_gen/mod.rs | 37 +++++-------------- crates/wasm/src/compile.rs | 4 +- 10 files changed, 46 insertions(+), 61 deletions(-) diff --git a/crates/nargo/src/ops/execute.rs b/crates/nargo/src/ops/execute.rs index 13ea64ed261..2a126443468 100644 --- a/crates/nargo/src/ops/execute.rs +++ b/crates/nargo/src/ops/execute.rs @@ -10,6 +10,7 @@ pub fn execute_circuit( _backend: &B, circuit: Circuit, initial_witness: WitnessMap, + show_output: bool, ) -> Result { let mut acvm = ACVM::new(B::default(), circuit.opcodes, initial_witness); @@ -23,7 +24,7 @@ pub fn execute_circuit( } ACVMStatus::Failure(error) => return Err(error.into()), ACVMStatus::RequiresForeignCall(foreign_call) => { - let foreign_call_result = ForeignCall::execute(&foreign_call)?; + let foreign_call_result = ForeignCall::execute(&foreign_call, show_output)?; acvm.resolve_pending_foreign_call(foreign_call_result); } } diff --git a/crates/nargo/src/ops/foreign_calls.rs b/crates/nargo/src/ops/foreign_calls.rs index 2abc62b1032..4d2f5988e38 100644 --- a/crates/nargo/src/ops/foreign_calls.rs +++ b/crates/nargo/src/ops/foreign_calls.rs @@ -42,11 +42,14 @@ impl ForeignCall { pub(crate) fn execute( foreign_call: &ForeignCallWaitInfo, + show_output: bool, ) -> Result { let foreign_call_name = foreign_call.function.as_str(); match Self::lookup(foreign_call_name) { Some(ForeignCall::Println) => { - Self::execute_println(&foreign_call.inputs)?; + if show_output { + Self::execute_println(&foreign_call.inputs)?; + } Ok(ForeignCallResult { values: vec![] }) } Some(ForeignCall::Sequence) => { diff --git a/crates/nargo_cli/src/cli/execute_cmd.rs b/crates/nargo_cli/src/cli/execute_cmd.rs index ca5c18585ab..a2700caee0f 100644 --- a/crates/nargo_cli/src/cli/execute_cmd.rs +++ b/crates/nargo_cli/src/cli/execute_cmd.rs @@ -132,7 +132,7 @@ pub(crate) fn execute_program( debug_data: Option<(DebugInfo, Context)>, ) -> Result> { let initial_witness = abi.encode(inputs_map, None)?; - let solved_witness_err = nargo::ops::execute_circuit(backend, circuit, initial_witness); + let solved_witness_err = nargo::ops::execute_circuit(backend, circuit, initial_witness, true); match solved_witness_err { Ok(solved_witness) => Ok(solved_witness), Err(err) => { diff --git a/crates/nargo_cli/src/cli/test_cmd.rs b/crates/nargo_cli/src/cli/test_cmd.rs index 7eb1c9bff74..e52e3e5aa8d 100644 --- a/crates/nargo_cli/src/cli/test_cmd.rs +++ b/crates/nargo_cli/src/cli/test_cmd.rs @@ -106,14 +106,17 @@ fn run_test( show_output: bool, config: &CompileOptions, ) -> Result<(), CliError> { - let mut program = compile_no_check(context, show_output, config, main) - .map_err(|_| CliError::Generic(format!("Test '{test_name}' failed to compile")))?; + let mut program = compile_no_check(context, config, main).map_err(|err| { + noirc_errors::reporter::report_all(&context.file_manager, &[err], config.deny_warnings); + CliError::Generic(format!("Test '{test_name}' failed to compile")) + })?; + // Note: We could perform this test using the unoptimized ACIR as generated by `compile_no_check`. program.circuit = optimize_circuit(backend, program.circuit).unwrap().0; // Run the backend to ensure the PWG evaluates functions like std::hash::pedersen, // otherwise constraints involving these expressions will not error. - match execute_circuit(backend, program.circuit, WitnessMap::new()) { + match execute_circuit(backend, program.circuit, WitnessMap::new(), show_output) { Ok(_) => Ok(()), Err(error) => { let writer = StandardStream::stderr(ColorChoice::Always); diff --git a/crates/nargo_cli/tests/test_data/strings/src/main.nr b/crates/nargo_cli/tests/test_data/strings/src/main.nr index edf5fff55b4..9f122c3a137 100644 --- a/crates/nargo_cli/tests/test_data/strings/src/main.nr +++ b/crates/nargo_cli/tests/test_data/strings/src/main.nr @@ -43,9 +43,8 @@ fn test_prints_strings() { fn test_prints_array() { let array = [1, 2, 3, 5, 8]; - // TODO: Printing structs currently not supported - // let s = Test { a: 1, b: 2, c: [3, 4] }; - // std::println(s); + let s = Test { a: 1, b: 2, c: [3, 4] }; + std::println(s); std::println(array); @@ -53,6 +52,21 @@ fn test_prints_array() { std::println(hash); } +fn failed_constraint(hex_as_field: Field) { + // TODO(#2116): Note that `println` will not work if a failed constraint can be + // evaluated at compile time. + // When this method is called from a test method or with constant values + // a `Failed constraint` compile error will be caught before this `println` + // is executed as the input will be a constant. + std::println(hex_as_field); + assert(hex_as_field != 0x41); +} + +#[test] +fn test_failed_constraint() { + failed_constraint(0x41); +} + struct Test { a: Field, b: Field, diff --git a/crates/noirc_driver/src/lib.rs b/crates/noirc_driver/src/lib.rs index 4d1b7fe2675..27109af6a2f 100644 --- a/crates/noirc_driver/src/lib.rs +++ b/crates/noirc_driver/src/lib.rs @@ -163,7 +163,7 @@ pub fn compile_main( } }; - let compiled_program = compile_no_check(context, true, options, main)?; + let compiled_program = compile_no_check(context, options, main)?; if options.print_acir { println!("Compiled ACIR for main (unoptimized):"); @@ -230,7 +230,7 @@ fn compile_contract( let mut errs = Vec::new(); for function_id in &contract.functions { let name = context.function_name(function_id).to_owned(); - let function = match compile_no_check(context, true, options, *function_id) { + let function = match compile_no_check(context, options, *function_id) { Ok(function) => function, Err(err) => { errs.push(err); @@ -267,14 +267,12 @@ fn compile_contract( #[allow(deprecated)] pub fn compile_no_check( context: &Context, - show_output: bool, options: &CompileOptions, main_function: FuncId, ) -> Result { let program = monomorphize(main_function, &context.def_interner); - let (circuit, debug, abi) = - create_circuit(program, options.show_ssa, options.show_brillig, show_output)?; + let (circuit, debug, abi) = create_circuit(program, options.show_ssa, options.show_brillig)?; Ok(CompiledProgram { circuit, debug, abi }) } diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 6326b45554d..c57bb330b09 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -35,7 +35,6 @@ pub mod ssa_gen; /// convert the final SSA into ACIR and return it. pub(crate) fn optimize_into_acir( program: Program, - allow_log_ops: bool, print_ssa_passes: bool, print_brillig_trace: bool, ) -> Result { @@ -63,7 +62,7 @@ pub(crate) fn optimize_into_acir( .dead_instruction_elimination() .print(print_ssa_passes, "After Dead Instruction Elimination:"); } - ssa.into_acir(brillig, abi_distinctness, allow_log_ops) + ssa.into_acir(brillig, abi_distinctness) } /// Compiles the Program into ACIR and applies optimizations to the arithmetic gates @@ -74,7 +73,6 @@ pub fn create_circuit( program: Program, enable_ssa_logging: bool, enable_brillig_logging: bool, - show_output: bool, ) -> Result<(Circuit, DebugInfo, Abi), RuntimeError> { let func_sig = program.main_function_signature.clone(); let GeneratedAcir { @@ -84,7 +82,7 @@ pub fn create_circuit( locations, input_witnesses, .. - } = optimize_into_acir(program, show_output, enable_ssa_logging, enable_brillig_logging)?; + } = optimize_into_acir(program, enable_ssa_logging, enable_brillig_logging)?; let abi = gen_abi(func_sig, &input_witnesses, return_witnesses.clone()); let public_abi = abi.clone().public_abi(); diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs index 9177dc9ae6c..d1479ef1f1b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs @@ -827,19 +827,6 @@ impl AcirContext { self.radix_decompose(endian, input_var, two_var, limb_count_var, result_element_type) } - /// Prints the given `AcirVar`s as witnesses. - pub(crate) fn print(&mut self, input: Vec) -> Result<(), RuntimeError> { - let input = Self::flatten_values(input); - - let witnesses = vecmap(input, |acir_var| { - let var_data = &self.vars[&acir_var]; - let expr = var_data.to_expression(); - self.acir_ir.get_or_create_witness(&expr) - }); - self.acir_ir.call_print(witnesses); - Ok(()) - } - /// Flatten the given Vector of AcirValues into a single vector of only variables. /// Each AcirValue::Array in the vector is recursively flattened, so each element /// will flattened into the resulting Vec. E.g. flatten_values([1, [2, 3]) == [1, 2, 3]. diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index f00f15d8f05..62a9dd5969d 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -103,10 +103,9 @@ impl Ssa { self, brillig: Brillig, abi_distinctness: AbiDistinctness, - allow_log_ops: bool, ) -> Result { let context = Context::new(); - let mut generated_acir = context.convert_ssa(self, brillig, allow_log_ops)?; + let mut generated_acir = context.convert_ssa(self, brillig)?; match abi_distinctness { AbiDistinctness::Distinct => { @@ -144,15 +143,10 @@ impl Context { } /// Converts SSA into ACIR - fn convert_ssa( - self, - ssa: Ssa, - brillig: Brillig, - allow_log_ops: bool, - ) -> Result { + fn convert_ssa(self, ssa: Ssa, brillig: Brillig) -> Result { let main_func = ssa.main(); match main_func.runtime() { - RuntimeType::Acir => self.convert_acir_main(main_func, &ssa, brillig, allow_log_ops), + RuntimeType::Acir => self.convert_acir_main(main_func, &ssa, brillig), RuntimeType::Brillig => self.convert_brillig_main(main_func, brillig), } } @@ -162,14 +156,13 @@ impl Context { main_func: &Function, ssa: &Ssa, brillig: Brillig, - allow_log_ops: bool, ) -> Result { let dfg = &main_func.dfg; let entry_block = &dfg[main_func.entry_block()]; let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?; for instruction_id in entry_block.instructions() { - self.convert_ssa_instruction(*instruction_id, dfg, ssa, &brillig, allow_log_ops)?; + self.convert_ssa_instruction(*instruction_id, dfg, ssa, &brillig)?; } self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?; @@ -294,7 +287,6 @@ impl Context { dfg: &DataFlowGraph, ssa: &Ssa, brillig: &Brillig, - allow_log_ops: bool, ) -> Result<(), RuntimeError> { let instruction = &dfg[instruction_id]; self.acir_context.set_location(dfg.get_location(&instruction_id)); @@ -339,13 +331,8 @@ impl Context { } } Value::Intrinsic(intrinsic) => { - let outputs = self.convert_ssa_intrinsic_call( - *intrinsic, - arguments, - dfg, - allow_log_ops, - result_ids, - )?; + let outputs = self + .convert_ssa_intrinsic_call(*intrinsic, arguments, dfg, result_ids)?; // Issue #1438 causes this check to fail with intrinsics that return 0 // results but the ssa form instead creates 1 unit result value. @@ -929,7 +916,6 @@ impl Context { intrinsic: Intrinsic, arguments: &[ValueId], dfg: &DataFlowGraph, - allow_log_ops: bool, result_ids: &[ValueId], ) -> Result, RuntimeError> { match intrinsic { @@ -959,13 +945,8 @@ impl Context { self.acir_context.bit_decompose(endian, field, bit_size, result_type) } - Intrinsic::Println => { - let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg)); - if allow_log_ops { - self.acir_context.print(inputs)?; - } - Ok(Vec::new()) - } + // TODO(#2115): Remove the println intrinsic as the oracle println is now used instead + Intrinsic::Println => Ok(Vec::new()), Intrinsic::Sort => { let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg)); // We flatten the inputs and retrieve the bit_size of the elements @@ -1133,7 +1114,7 @@ mod tests { let ssa = builder.finish(); let context = Context::new(); - let acir = context.convert_ssa(ssa, Brillig::default(), false).unwrap(); + let acir = context.convert_ssa(ssa, Brillig::default()).unwrap(); let expected_opcodes = vec![Opcode::Arithmetic(&Expression::one() - &Expression::from(Witness(1)))]; diff --git a/crates/wasm/src/compile.rs b/crates/wasm/src/compile.rs index 15d8d5107ea..4254110b849 100644 --- a/crates/wasm/src/compile.rs +++ b/crates/wasm/src/compile.rs @@ -107,8 +107,8 @@ pub fn compile(args: JsValue) -> JsValue { ::from_serde(&optimized_contracts).unwrap() } else { let main = context.get_main_function(&crate_id).expect("Could not find main function!"); - let mut compiled_program = compile_no_check(&context, true, &options.compile_options, main) - .expect("Compilation failed"); + let mut compiled_program = + compile_no_check(&context, &options.compile_options, main).expect("Compilation failed"); compiled_program.circuit = optimize_circuit(compiled_program.circuit); From a07b8a48924865d8425d35e40c75f48a13a81935 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 2 Aug 2023 20:00:23 +0100 Subject: [PATCH 4/9] chore: rename `ssa_refactor` module to `ssa` (#2129) --- .gitignore | 2 -- .../noirc_evaluator/src/brillig/brillig_gen.rs | 2 +- .../src/brillig/brillig_gen/brillig_block.rs | 9 +++++---- .../src/brillig/brillig_gen/brillig_fn.rs | 2 +- crates/noirc_evaluator/src/brillig/mod.rs | 2 +- crates/noirc_evaluator/src/lib.rs | 4 ++-- .../src/{ssa_refactor.rs => ssa.rs} | 0 .../src/{ssa_refactor => ssa}/abi_gen/mod.rs | 0 .../{ssa_refactor => ssa}/acir_gen/acir_ir.rs | 0 .../acir_gen/acir_ir/acir_variable.rs | 6 +++--- .../acir_gen/acir_ir/generated_acir.rs | 0 .../acir_gen/acir_ir/sort.rs | 0 .../src/{ssa_refactor => ssa}/acir_gen/mod.rs | 2 +- .../src/{ssa_refactor => ssa}/ir.rs | 0 .../src/{ssa_refactor => ssa}/ir/basic_block.rs | 0 .../src/{ssa_refactor => ssa}/ir/cfg.rs | 2 +- .../src/{ssa_refactor => ssa}/ir/dfg.rs | 4 ++-- .../src/{ssa_refactor => ssa}/ir/dom.rs | 2 +- .../src/{ssa_refactor => ssa}/ir/function.rs | 0 .../ir/function_inserter.rs | 0 .../src/{ssa_refactor => ssa}/ir/instruction.rs | 4 +--- .../{ssa_refactor => ssa}/ir/instruction/call.rs | 2 +- .../src/{ssa_refactor => ssa}/ir/map.rs | 0 .../src/{ssa_refactor => ssa}/ir/post_order.rs | 4 ++-- .../src/{ssa_refactor => ssa}/ir/printer.rs | 0 .../src/{ssa_refactor => ssa}/ir/types.rs | 0 .../src/{ssa_refactor => ssa}/ir/value.rs | 2 +- .../opt/constant_folding.rs | 4 ++-- .../{ssa_refactor => ssa}/opt/defunctionalize.rs | 2 +- .../src/{ssa_refactor => ssa}/opt/die.rs | 4 ++-- .../src/{ssa_refactor => ssa}/opt/flatten_cfg.rs | 6 +++--- .../opt/flatten_cfg/branch_analysis.rs | 6 ++---- .../src/{ssa_refactor => ssa}/opt/inlining.rs | 4 ++-- .../src/{ssa_refactor => ssa}/opt/mem2reg.rs | 4 ++-- .../src/{ssa_refactor => ssa}/opt/mod.rs | 0 .../{ssa_refactor => ssa}/opt/simplify_cfg.rs | 4 ++-- .../src/{ssa_refactor => ssa}/opt/unrolling.rs | 4 ++-- .../src/{ssa_refactor => ssa}/ssa_builder/mod.rs | 4 ++-- .../src/{ssa_refactor => ssa}/ssa_gen/context.rs | 16 ++++++++-------- .../src/{ssa_refactor => ssa}/ssa_gen/mod.rs | 0 .../src/{ssa_refactor => ssa}/ssa_gen/program.rs | 2 +- .../src/{ssa_refactor => ssa}/ssa_gen/value.rs | 4 ++-- 42 files changed, 54 insertions(+), 59 deletions(-) rename crates/noirc_evaluator/src/{ssa_refactor.rs => ssa.rs} (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/abi_gen/mod.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/acir_gen/acir_ir.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/acir_gen/acir_ir/acir_variable.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/acir_gen/acir_ir/generated_acir.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/acir_gen/acir_ir/sort.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/acir_gen/mod.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/basic_block.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/cfg.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/dfg.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/dom.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/function.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/function_inserter.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/instruction.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/instruction/call.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/map.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/post_order.rs (97%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/printer.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/types.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ir/value.rs (98%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/constant_folding.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/defunctionalize.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/die.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/flatten_cfg.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/flatten_cfg/branch_analysis.rs (98%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/inlining.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/mem2reg.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/mod.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/simplify_cfg.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/opt/unrolling.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ssa_builder/mod.rs (99%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ssa_gen/context.rs (98%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ssa_gen/mod.rs (100%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ssa_gen/program.rs (98%) rename crates/noirc_evaluator/src/{ssa_refactor => ssa}/ssa_gen/value.rs (98%) diff --git a/.gitignore b/.gitignore index af3a8e8beb2..8aec0edeadc 100644 --- a/.gitignore +++ b/.gitignore @@ -22,5 +22,3 @@ result **/target !crates/nargo_cli/tests/test_data/*/target !crates/nargo_cli/tests/test_data/*/target/witness.tr -!crates/nargo_cli/tests/test_data_ssa_refactor/*/target -!crates/nargo_cli/tests/test_data_ssa_refactor/*/target/witness.tr \ No newline at end of file diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen.rs b/crates/noirc_evaluator/src/brillig/brillig_gen.rs index 3ba04ed1afb..a1e82bbf443 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen.rs @@ -4,7 +4,7 @@ pub(crate) mod brillig_directive; pub(crate) mod brillig_fn; pub(crate) mod brillig_slice_ops; -use crate::ssa_refactor::ir::{function::Function, post_order::PostOrder}; +use crate::ssa::ir::{function::Function, post_order::PostOrder}; use std::collections::HashMap; diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index a9bbe189e57..ded6be71bd5 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -4,12 +4,13 @@ use crate::brillig::brillig_gen::brillig_slice_ops::{ use crate::brillig::brillig_ir::{ BrilligBinaryOp, BrilligContext, BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE, }; -use crate::ssa_refactor::ir::function::FunctionId; -use crate::ssa_refactor::ir::instruction::{Endian, Intrinsic}; -use crate::ssa_refactor::ir::{ +use crate::ssa::ir::{ basic_block::{BasicBlock, BasicBlockId}, dfg::DataFlowGraph, - instruction::{Binary, BinaryOp, Instruction, InstructionId, TerminatorInstruction}, + function::FunctionId, + instruction::{ + Binary, BinaryOp, Endian, Instruction, InstructionId, Intrinsic, TerminatorInstruction, + }, types::{NumericType, Type}, value::{Value, ValueId}, }; diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 210d6da7be6..7c4cb5e2ced 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -8,7 +8,7 @@ use crate::{ artifact::{BrilligParameter, Label}, BrilligContext, }, - ssa_refactor::ir::{ + ssa::ir::{ dfg::DataFlowGraph, function::{Function, FunctionId}, types::{CompositeType, Type}, diff --git a/crates/noirc_evaluator/src/brillig/mod.rs b/crates/noirc_evaluator/src/brillig/mod.rs index 105475323a7..0c6ddd53a4e 100644 --- a/crates/noirc_evaluator/src/brillig/mod.rs +++ b/crates/noirc_evaluator/src/brillig/mod.rs @@ -5,7 +5,7 @@ use self::{ brillig_gen::{brillig_fn::FunctionContext, convert_ssa_function}, brillig_ir::artifact::{BrilligArtifact, Label}, }; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ function::{Function, FunctionId, RuntimeType}, value::Value, diff --git a/crates/noirc_evaluator/src/lib.rs b/crates/noirc_evaluator/src/lib.rs index c7d4f5baed6..f5403e1cf49 100644 --- a/crates/noirc_evaluator/src/lib.rs +++ b/crates/noirc_evaluator/src/lib.rs @@ -7,8 +7,8 @@ mod errors; // SSA code to create the SSA based IR // for functions and execute different optimizations. -pub mod ssa_refactor; +pub mod ssa; pub mod brillig; -pub use ssa_refactor::create_circuit; +pub use ssa::create_circuit; diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor.rs rename to crates/noirc_evaluator/src/ssa.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/abi_gen/mod.rs b/crates/noirc_evaluator/src/ssa/abi_gen/mod.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/abi_gen/mod.rs rename to crates/noirc_evaluator/src/ssa/abi_gen/mod.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs rename to crates/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs rename to crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index d1479ef1f1b..779aaa559ed 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1,9 +1,9 @@ use super::generated_acir::GeneratedAcir; use crate::brillig::brillig_gen::brillig_directive; use crate::errors::{InternalError, RuntimeError}; -use crate::ssa_refactor::acir_gen::{AcirDynamicArray, AcirValue}; -use crate::ssa_refactor::ir::types::Type as SsaType; -use crate::ssa_refactor::ir::{instruction::Endian, types::NumericType}; +use crate::ssa::acir_gen::{AcirDynamicArray, AcirValue}; +use crate::ssa::ir::types::Type as SsaType; +use crate::ssa::ir::{instruction::Endian, types::NumericType}; use acvm::acir::circuit::opcodes::{BlockId, MemOp}; use acvm::acir::circuit::Opcode; use acvm::acir::{ diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs rename to crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/sort.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs rename to crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/sort.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs rename to crates/noirc_evaluator/src/ssa/acir_gen/mod.rs index 62a9dd5969d..331c56f59d7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -1086,7 +1086,7 @@ mod tests { use crate::{ brillig::Brillig, - ssa_refactor::{ + ssa::{ ir::{function::RuntimeType, map::Id, types::Type}, ssa_builder::FunctionBuilder, }, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir.rs b/crates/noirc_evaluator/src/ssa/ir.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir.rs rename to crates/noirc_evaluator/src/ssa/ir.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs b/crates/noirc_evaluator/src/ssa/ir/basic_block.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/basic_block.rs rename to crates/noirc_evaluator/src/ssa/ir/basic_block.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs b/crates/noirc_evaluator/src/ssa/ir/cfg.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs rename to crates/noirc_evaluator/src/ssa/ir/cfg.rs index f08b477696a..a91123438fa 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/cfg.rs +++ b/crates/noirc_evaluator/src/ssa/ir/cfg.rs @@ -128,7 +128,7 @@ impl ControlFlowGraph { #[cfg(test)] mod tests { - use crate::ssa_refactor::ir::{instruction::TerminatorInstruction, map::Id, types::Type}; + use crate::ssa::ir::{instruction::TerminatorInstruction, map::Id, types::Type}; use super::{super::function::Function, ControlFlowGraph}; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa/ir/dfg.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs rename to crates/noirc_evaluator/src/ssa/ir/dfg.rs index 6d74e49b03b..29f5156a88c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa/ir/dfg.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, collections::HashMap}; -use crate::ssa_refactor::ir::instruction::SimplifyResult; +use crate::ssa::ir::instruction::SimplifyResult; use super::{ basic_block::{BasicBlock, BasicBlockId}, @@ -503,7 +503,7 @@ impl<'dfg> InsertInstructionResult<'dfg> { #[cfg(test)] mod tests { use super::DataFlowGraph; - use crate::ssa_refactor::ir::instruction::Instruction; + use crate::ssa::ir::instruction::Instruction; #[test] fn make_instruction() { diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs b/crates/noirc_evaluator/src/ssa/ir/dom.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs rename to crates/noirc_evaluator/src/ssa/ir/dom.rs index 4763ffffbd1..b7b1728d035 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs +++ b/crates/noirc_evaluator/src/ssa/ir/dom.rs @@ -245,7 +245,7 @@ impl DominatorTree { mod tests { use std::cmp::Ordering; - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ basic_block::BasicBlockId, dom::DominatorTree, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa/ir/function.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/function.rs rename to crates/noirc_evaluator/src/ssa/ir/function.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs b/crates/noirc_evaluator/src/ssa/ir/function_inserter.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs rename to crates/noirc_evaluator/src/ssa/ir/function_inserter.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa/ir/instruction.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs rename to crates/noirc_evaluator/src/ssa/ir/instruction.rs index 7edb74f4206..680715fb0ec 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa/ir/instruction.rs @@ -2,13 +2,11 @@ use acvm::{acir::BlackBoxFunc, FieldElement}; use iter_extended::vecmap; use num_bigint::BigUint; -use crate::ssa_refactor::ir::types::NumericType; - use super::{ basic_block::BasicBlockId, dfg::DataFlowGraph, map::Id, - types::Type, + types::{NumericType, Type}, value::{Value, ValueId}, }; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs b/crates/noirc_evaluator/src/ssa/ir/instruction/call.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs rename to crates/noirc_evaluator/src/ssa/ir/instruction/call.rs index 96998d92fcf..2f0c077a1a7 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction/call.rs +++ b/crates/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -4,7 +4,7 @@ use acvm::{acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; use iter_extended::vecmap; use num_bigint::BigUint; -use crate::ssa_refactor::ir::{ +use crate::ssa::ir::{ dfg::DataFlowGraph, instruction::Intrinsic, map::Id, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs b/crates/noirc_evaluator/src/ssa/ir/map.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/map.rs rename to crates/noirc_evaluator/src/ssa/ir/map.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/post_order.rs b/crates/noirc_evaluator/src/ssa/ir/post_order.rs similarity index 97% rename from crates/noirc_evaluator/src/ssa_refactor/ir/post_order.rs rename to crates/noirc_evaluator/src/ssa/ir/post_order.rs index 2f7b5edebe6..202f5cff716 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/post_order.rs +++ b/crates/noirc_evaluator/src/ssa/ir/post_order.rs @@ -5,7 +5,7 @@ use std::collections::HashSet; -use crate::ssa_refactor::ir::{basic_block::BasicBlockId, function::Function}; +use crate::ssa::ir::{basic_block::BasicBlockId, function::Function}; /// Depth-first traversal stack state marker for computing the cfg post-order. enum Visit { @@ -67,7 +67,7 @@ impl PostOrder { #[cfg(test)] mod tests { - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ function::{Function, RuntimeType}, map::Id, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa/ir/printer.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs rename to crates/noirc_evaluator/src/ssa/ir/printer.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/types.rs b/crates/noirc_evaluator/src/ssa/ir/types.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ir/types.rs rename to crates/noirc_evaluator/src/ssa/ir/types.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/value.rs b/crates/noirc_evaluator/src/ssa/ir/value.rs similarity index 98% rename from crates/noirc_evaluator/src/ssa_refactor/ir/value.rs rename to crates/noirc_evaluator/src/ssa/ir/value.rs index cea526058b4..54831eb4a07 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/value.rs +++ b/crates/noirc_evaluator/src/ssa/ir/value.rs @@ -1,6 +1,6 @@ use acvm::FieldElement; -use crate::ssa_refactor::ir::basic_block::BasicBlockId; +use crate::ssa::ir::basic_block::BasicBlockId; use super::{ function::FunctionId, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/constant_folding.rs b/crates/noirc_evaluator/src/ssa/opt/constant_folding.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/constant_folding.rs rename to crates/noirc_evaluator/src/ssa/opt/constant_folding.rs index acf048595d7..ea46ddf1d4f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/constant_folding.rs +++ b/crates/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use iter_extended::vecmap; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::InsertInstructionResult, function::Function, instruction::InstructionId, @@ -94,7 +94,7 @@ impl Context { mod test { use std::rc::Rc; - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ function::RuntimeType, instruction::{BinaryOp, TerminatorInstruction}, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa/opt/defunctionalize.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs rename to crates/noirc_evaluator/src/ssa/opt/defunctionalize.rs index fc3bc5d9aa6..10561bf731f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs +++ b/crates/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -9,7 +9,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use acvm::FieldElement; use iter_extended::vecmap; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, function::{Function, FunctionId, RuntimeType, Signature}, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/die.rs b/crates/noirc_evaluator/src/ssa/opt/die.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/die.rs rename to crates/noirc_evaluator/src/ssa/opt/die.rs index ef73938cc37..935568af2db 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/die.rs +++ b/crates/noirc_evaluator/src/ssa/opt/die.rs @@ -2,7 +2,7 @@ //! which the results are unused. use std::collections::HashSet; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::{BasicBlock, BasicBlockId}, dfg::DataFlowGraph, @@ -133,7 +133,7 @@ impl Context { #[cfg(test)] mod test { - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{function::RuntimeType, instruction::BinaryOp, map::Id, types::Type}, ssa_builder::FunctionBuilder, }; diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs b/crates/noirc_evaluator/src/ssa/opt/flatten_cfg.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs rename to crates/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index fdc4be085d7..1bcdf433d79 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs +++ b/crates/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -137,7 +137,7 @@ use acvm::FieldElement; use iter_extended::vecmap; use noirc_errors::Location; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, @@ -213,7 +213,7 @@ fn flatten_function_cfg(function: &mut Function) { // TODO This loops forever, if the predecessors are not then processed // TODO Because it will visit the same block again, pop it out of the queue // TODO then back into the queue again. - if let crate::ssa_refactor::ir::function::RuntimeType::Brillig = function.runtime() { + if let crate::ssa::ir::function::RuntimeType::Brillig = function.runtime() { return; } let cfg = ControlFlowGraph::with_function(function); @@ -739,7 +739,7 @@ impl<'f> Context<'f> { mod test { use std::rc::Rc; - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ dfg::DataFlowGraph, function::{Function, RuntimeType}, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs b/crates/noirc_evaluator/src/ssa/opt/flatten_cfg/branch_analysis.rs similarity index 98% rename from crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs rename to crates/noirc_evaluator/src/ssa/opt/flatten_cfg/branch_analysis.rs index bed0686e45b..1203d03f562 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs +++ b/crates/noirc_evaluator/src/ssa/opt/flatten_cfg/branch_analysis.rs @@ -21,9 +21,7 @@ //! the resulting map from each split block to each join block is returned. use std::collections::HashMap; -use crate::ssa_refactor::ir::{ - basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function, -}; +use crate::ssa::ir::{basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function}; /// Returns a `HashMap` mapping blocks that start a branch (i.e. blocks terminated with jmpif) to /// their corresponding blocks that end the branch. @@ -114,7 +112,7 @@ impl<'cfg> Context<'cfg> { #[cfg(test)] mod test { - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{cfg::ControlFlowGraph, function::RuntimeType, map::Id, types::Type}, opt::flatten_cfg::branch_analysis::find_branch_ends, ssa_builder::FunctionBuilder, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa/opt/inlining.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs rename to crates/noirc_evaluator/src/ssa/opt/inlining.rs index 7aa2f9d176a..d4c118fd3f4 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa/opt/inlining.rs @@ -6,7 +6,7 @@ use std::collections::{HashMap, HashSet}; use iter_extended::vecmap; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::InsertInstructionResult, @@ -482,7 +482,7 @@ impl<'function> PerFunctionContext<'function> { mod test { use acvm::FieldElement; - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ basic_block::BasicBlockId, function::RuntimeType, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs b/crates/noirc_evaluator/src/ssa/opt/mem2reg.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs rename to crates/noirc_evaluator/src/ssa/opt/mem2reg.rs index 15108abc490..b9e849bb77c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs +++ b/crates/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -5,7 +5,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use iter_extended::vecmap; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::DataFlowGraph, @@ -182,7 +182,7 @@ mod tests { use acvm::FieldElement; use im::vector; - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::DataFlowGraph, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa/opt/mod.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs rename to crates/noirc_evaluator/src/ssa/opt/mod.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa/opt/simplify_cfg.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs rename to crates/noirc_evaluator/src/ssa/opt/simplify_cfg.rs index 22991e38b94..58259cec90c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa/opt/simplify_cfg.rs @@ -11,7 +11,7 @@ //! Currently, 1 and 4 are unimplemented. use std::collections::HashSet; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function, instruction::TerminatorInstruction, @@ -148,7 +148,7 @@ fn try_inline_into_predecessor( #[cfg(test)] mod test { - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{ function::RuntimeType, instruction::{BinaryOp, TerminatorInstruction}, diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa/opt/unrolling.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs rename to crates/noirc_evaluator/src/ssa/opt/unrolling.rs index e5d7d6f0d5c..f6d7c952277 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -14,7 +14,7 @@ //! program that will need to be removed by a later simplify cfg pass. use std::collections::{HashMap, HashSet}; -use crate::ssa_refactor::{ +use crate::ssa::{ ir::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, dfg::DataFlowGraph, dom::DominatorTree, function::Function, function_inserter::FunctionInserter, @@ -424,7 +424,7 @@ impl<'f> LoopIteration<'f> { #[cfg(test)] mod tests { - use crate::ssa_refactor::{ + use crate::ssa::{ ir::{function::RuntimeType, instruction::BinaryOp, map::Id, types::Type}, ssa_builder::FunctionBuilder, }; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa/ssa_builder/mod.rs similarity index 99% rename from crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs rename to crates/noirc_evaluator/src/ssa/ssa_builder/mod.rs index 02350d9ed17..066b5b51199 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_builder/mod.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use acvm::FieldElement; use noirc_errors::Location; -use crate::ssa_refactor::ir::{ +use crate::ssa::ir::{ basic_block::BasicBlockId, function::{Function, FunctionId}, instruction::{Binary, BinaryOp, Instruction, TerminatorInstruction}, @@ -363,7 +363,7 @@ mod tests { use acvm::FieldElement; - use crate::ssa_refactor::ir::{ + use crate::ssa::ir::{ function::RuntimeType, instruction::{Endian, Intrinsic}, map::Id, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs similarity index 98% rename from crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs rename to crates/noirc_evaluator/src/ssa/ssa_gen/context.rs index a526d93f85b..3e0bbff2a83 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -9,14 +9,14 @@ use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters}; use noirc_frontend::monomorphization::ast::{FuncId, Program}; use noirc_frontend::{BinaryOpKind, Signedness}; -use crate::ssa_refactor::ir::dfg::DataFlowGraph; -use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId; -use crate::ssa_refactor::ir::function::{Function, RuntimeType}; -use crate::ssa_refactor::ir::instruction::{BinaryOp, Endian, Intrinsic}; -use crate::ssa_refactor::ir::map::AtomicCounter; -use crate::ssa_refactor::ir::types::{NumericType, Type}; -use crate::ssa_refactor::ir::value::ValueId; -use crate::ssa_refactor::ssa_builder::FunctionBuilder; +use crate::ssa::ir::dfg::DataFlowGraph; +use crate::ssa::ir::function::FunctionId as IrFunctionId; +use crate::ssa::ir::function::{Function, RuntimeType}; +use crate::ssa::ir::instruction::{BinaryOp, Endian, Intrinsic}; +use crate::ssa::ir::map::AtomicCounter; +use crate::ssa::ir::types::{NumericType, Type}; +use crate::ssa::ir::value::ValueId; +use crate::ssa::ssa_builder::FunctionBuilder; use super::value::{Tree, Value, Values}; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs similarity index 100% rename from crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs rename to crates/noirc_evaluator/src/ssa/ssa_gen/mod.rs diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/program.rs similarity index 98% rename from crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs rename to crates/noirc_evaluator/src/ssa/ssa_gen/program.rs index aec0e4262c8..509f778f3b0 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, fmt::Display}; use iter_extended::btree_map; -use crate::ssa_refactor::ir::{ +use crate::ssa::ir::{ function::{Function, FunctionId}, map::AtomicCounter, }; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/value.rs similarity index 98% rename from crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs rename to crates/noirc_evaluator/src/ssa/ssa_gen/value.rs index 2d209635610..e7bb515465b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/value.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/value.rs @@ -1,7 +1,7 @@ use iter_extended::vecmap; -use crate::ssa_refactor::ir::types::Type; -use crate::ssa_refactor::ir::value::ValueId as IrValueId; +use crate::ssa::ir::types::Type; +use crate::ssa::ir::value::ValueId as IrValueId; use super::context::FunctionContext; From ed67b10f0180aa93b04bcbd7a65864f2d898dadf Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Wed, 2 Aug 2023 21:37:22 +0200 Subject: [PATCH 5/9] chore: Initialize copy array from previous values in `array_set` (#2106) * Initialize copy array from previous values in array_set * chore: use `try_vecmap` in place of for-loop * Update crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs * Update crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs --------- Co-authored-by: Tom French Co-authored-by: jfecher --- crates/noirc_evaluator/src/ssa/acir_gen/mod.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs index 331c56f59d7..f1e71922b0b 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -589,19 +589,17 @@ impl Context { let result_array_id = result_id.to_usize() as u32; let result_block_id = BlockId(result_array_id); - // Initialize the new array with zero values - self.initialize_array(result_block_id, len, None)?; - - // Copy the values from the old array into the newly created zeroed array - for i in 0..len { + // Initialize the new array with the values from the old array + let init_values = try_vecmap(0..len, |i| { let index = AcirValue::Var( self.acir_context.add_constant(FieldElement::from(i as u128)), AcirType::NumericType(NumericType::NativeField), ); let var = index.into_var()?; let read = self.acir_context.read_from_memory(block_id, &var)?; - self.acir_context.write_to_memory(result_block_id, &var, &read)?; - } + Ok(AcirValue::Var(read, AcirType::NumericType(NumericType::NativeField))) + })?; + self.initialize_array(result_block_id, len, Some(&init_values))?; // Write the new value into the new array at the specified index let index_var = self.convert_value(index, dfg).into_var()?; From f3f6fbe45254ea206b778d191861498eef880064 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Wed, 2 Aug 2023 21:38:28 +0200 Subject: [PATCH 6/9] chore: Decouple acir blockid from ssa valueid (#2103) Decouple acir blokid from ssa valueid Co-authored-by: Tom French --- .../noirc_evaluator/src/ssa/acir_gen/mod.rs | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs index f1e71922b0b..25a0c2ee2e8 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -55,6 +55,15 @@ struct Context { /// This set is used to ensure that a MemoryOp opcode is only pushed to the circuit /// if there is already a MemoryInit opcode. initialized_arrays: HashSet, + + /// Maps SSA values to BlockId + /// A BlockId is an ACIR structure which identifies a memory block + /// Each acir memory block corresponds to a different SSA array. + memory_blocks: HashMap, BlockId>, + + /// Number of the next BlockId, it is used to construct + /// a new BlockId + max_block_id: u32, } #[derive(Clone)] @@ -139,6 +148,8 @@ impl Context { current_side_effects_enabled_var, acir_context, initialized_arrays: HashSet::new(), + memory_blocks: HashMap::new(), + max_block_id: 0, } } @@ -221,7 +232,7 @@ impl Context { match &value { AcirValue::Var(_, _) => (), AcirValue::Array(values) => { - let block_id = BlockId(param_id.to_usize() as u32); + let block_id = self.block_id(param_id); let v = vecmap(values, |v| v.clone()); self.initialize_array(block_id, values.len(), Some(&v))?; } @@ -264,6 +275,18 @@ impl Context { } } + /// Get the BlockId corresponding to the ValueId + /// If there is no matching BlockId, we create a new one. + fn block_id(&mut self, value: &ValueId) -> BlockId { + if let Some(block_id) = self.memory_blocks.get(value) { + return *block_id; + } + let block_id = BlockId(self.max_block_id); + self.max_block_id += 1; + self.memory_blocks.insert(*value, block_id); + block_id + } + /// Creates an `AcirVar` corresponding to a parameter witness to appears in the abi. A range /// constraint is added if the numeric type requires it. /// @@ -500,7 +523,7 @@ impl Context { dfg: &DataFlowGraph, ) -> Result<(), RuntimeError> { let array = dfg.resolve(array); - let block_id = BlockId(array.to_usize() as u32); + let block_id = self.block_id(&array); if !self.initialized_arrays.contains(&block_id) { match &dfg[array] { Value::Array { array, .. } => { @@ -548,11 +571,9 @@ impl Context { ) -> Result<(), InternalError> { // Fetch the internal SSA ID for the array let array = dfg.resolve(array); - let array_ssa_id = array.to_usize() as u32; - // Use the SSA ID to create a block ID - // There is currently a 1-1 mapping from array SSA ID to block ID - let block_id = BlockId(array_ssa_id); + // Use the SSA ID to get or create its block ID + let block_id = self.block_id(&array); // Every array has a length in its type, so we fetch that from // the SSA IR. @@ -586,8 +607,7 @@ impl Context { .instruction_results(instruction) .first() .expect("Array set does not have one result"); - let result_array_id = result_id.to_usize() as u32; - let result_block_id = BlockId(result_array_id); + let result_block_id = self.block_id(result_id); // Initialize the new array with the values from the old array let init_values = try_vecmap(0..len, |i| { From 35404ba9b2916cebf35519546eec0f0ae54b5516 Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Wed, 2 Aug 2023 23:09:01 +0300 Subject: [PATCH 7/9] =?UTF-8?q?feat:=20Initial=20work=20on=20rewriting=20c?= =?UTF-8?q?losures=20to=20regular=20functions=20with=20hi=E2=80=A6=20(#195?= =?UTF-8?q?9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Initial work on rewriting closures to regular functions with hidden env This commit implements the following mechanism: On a line where a lambda expression is encountered, we initialize a tuple for the captured lambda environment and we rewrite the lambda to a regular function taking this environment as an additional parameter. All calls to the closure are then modified to insert this hidden parameter. In other words, the following code: ``` let x = some_value; let closure = |a| x + a; println(closure(10)); println(closure(20)); ``` is rewritten to: ``` fn closure(env: (Field,), a: Field) -> Field { env.0 + a } let x = some_value; let closure_env = (x,); println(closure(closure_env, 10)); println(closure(closure_env, 20)); ``` In the presence of nested closures, we propagate the captured variables implicitly through all intermediate closures: ``` let x = some_value; let closure = |a, c| # here, `x` is initialized from the hidden env of the outer closure let inner_closure = |b| a + b + x inner_closure(c) ``` To make these transforms possible, the following changes were made to the logic of the HIR resolver and the monomorphization pass: * In the HIR resolver pass, the code determines the precise list of variables captured by each lambda. Along with the list, we compute the index of each captured var within the parent closure's environment (when the capture is propagated). * Introduction of a new `Closure` type in order to be able to recognize the call-sites that need the automatic environment variable treatment. It's a bit unfortunate that the Closure type is defined within the `AST` modules that are used to describe the output of the monomorphization pass, because we aim to eliminate all closures during the pass. A better solution would have been possible if the type check pass after HIR resolution was outputting types specific to the HIR pass (then the closures would exist only within this separate non-simplified type system). * The majority of the work is in the Lambda processing step in the monomorphizer which performs the necessary transformations based on the above information. Remaining things to do: * There are a number of pending TODO items for various minor unresolved loose ends in the code. * There are a lot of possible additional tests to be written. * Update docs * refactor: use panic, instead of println+assert Co-authored-by: jfecher * test: add an initial monomorphization rewrite test a lot of the machinery is copied from similar existing tests the original authors also note some of those can be refactored in something reusable * fix: address some PR comments: comment/refactor/small fixes * fix: use an unified Function object, fix some problems, comments * fix: fix code, addressing `cargo clippy` warnings * fix: replace type_of usage and remove it, as hinted in review * test: move closure-related tests to test_data * test: update closure rewrite test output * chore: apply cargo fmt changes * test: capture some variables in some tests, fix warnings, add a TODO add a TODO about returning closures * test: add simplification of #1088 as a resolve test, enable another test * fix: fix unify for closures, fix display for fn/closure types * test: update closure tests after resolving mutable bug * fix: address some review comments for closure PR: fixes/cleanup * refactor: cleanup, remove a line Co-authored-by: jfecher * refactor: cleanup Co-authored-by: jfecher * fix: fix bind_function_type env_type handling type variable binding * test: improve higher_order_fn_selector test * fix: remove skip_params/additional param logic from typechecking/display * fix: don't use closure capture logic for lambdas without captures * fix: apply cargo fmt & clippy * chore: apply cargo fmt * test: fix closure rewrite test: actually capture * chore: remove type annotation for `params` * chore: run cargo fmt --------- Co-authored-by: jfecher Co-authored-by: Alex Vitkov --- .../test_data/closures_mut_ref/Nargo.toml | 6 + .../test_data/closures_mut_ref/Prover.toml | 1 + .../test_data/closures_mut_ref/src/main.nr | 20 + .../higher_order_fn_selector/Nargo.toml | 6 + .../higher_order_fn_selector/src/main.nr | 39 ++ .../higher_order_functions/Nargo.toml | 6 + .../higher_order_functions/Prover.toml | 0 .../higher_order_functions/src/main.nr | 87 ++++ .../higher_order_functions/target/c.json | 1 + .../higher_order_functions/target/main.json | 1 + .../higher_order_functions/target/witness.tr | Bin 0 -> 112 bytes .../tests/test_data/inner_outer_cl/Nargo.toml | 6 + .../test_data/inner_outer_cl/src/main.nr | 12 + .../tests/test_data/ret_fn_ret_cl/Nargo.toml | 6 + .../tests/test_data/ret_fn_ret_cl/Prover.toml | 1 + .../tests/test_data/ret_fn_ret_cl/src/main.nr | 39 ++ .../src/ssa/ssa_gen/context.rs | 2 +- .../src/hir/def_collector/dc_crate.rs | 4 +- .../src/hir/resolution/resolver.rs | 345 ++++++++++++-- .../noirc_frontend/src/hir/type_check/expr.rs | 73 +-- .../noirc_frontend/src/hir/type_check/mod.rs | 29 +- crates/noirc_frontend/src/hir_def/expr.rs | 16 + crates/noirc_frontend/src/hir_def/function.rs | 4 +- crates/noirc_frontend/src/hir_def/types.rs | 52 ++- .../src/monomorphization/ast.rs | 17 +- .../src/monomorphization/mod.rs | 428 +++++++++++++++++- crates/noirc_frontend/src/node_interner.rs | 2 +- 27 files changed, 1078 insertions(+), 125 deletions(-) create mode 100644 crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json create mode 100644 crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr create mode 100644 crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr create mode 100644 crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml new file mode 100644 index 00000000000..c829bb160b1 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "closures_mut_ref" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml new file mode 100644 index 00000000000..11497a473bc --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml @@ -0,0 +1 @@ +x = "0" diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr new file mode 100644 index 00000000000..ae990e004fd --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr @@ -0,0 +1,20 @@ +use dep::std; + +fn main(mut x: Field) { + let one = 1; + let add1 = |z| { + *z = *z + one; + }; + + let two = 2; + let add2 = |z| { + *z = *z + two; + }; + + add1(&mut x); + assert(x == 1); + + add2(&mut x); + assert(x == 3); + +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml new file mode 100644 index 00000000000..3c2277e35a5 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_fn_selector" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr new file mode 100644 index 00000000000..767cff0c409 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr @@ -0,0 +1,39 @@ +use dep::std; + +fn g(x: &mut Field) -> () { + *x *= 2; +} + +fn h(x: &mut Field) -> () { + *x *= 3; +} + +fn selector(flag: &mut bool) -> fn(&mut Field) -> () { + let my_func = if *flag { + g + } else { + h + }; + + // Flip the flag for the next function call + *flag = !(*flag); + my_func +} + +fn main() { + + let mut flag: bool = true; + + let mut x: Field = 100; + let returned_func = selector(&mut flag); + returned_func(&mut x); + + assert(x == 200); + + let mut y: Field = 100; + let returned_func2 = selector(&mut flag); + returned_func2(&mut y); + + assert(y == 300); + +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml new file mode 100644 index 00000000000..cf7526abc7f --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_functions" +authors = [""] +compiler_version = "0.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr new file mode 100644 index 00000000000..fefd23b7dbc --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr @@ -0,0 +1,87 @@ +use dep::std; + +fn main() -> pub Field { + let f = if 3 * 7 > 200 as u32 { foo } else { bar }; + assert(f()[1] == 2); + // Lambdas: + assert(twice(|x| x * 2, 5) == 20); + assert((|x, y| x + y + 1)(2, 3) == 6); + + // nested lambdas + assert((|a, b| { + a + (|c| c + 2)(b) + })(0, 1) == 3); + + + // Closures: + let a = 42; + let g = || a; + assert(g() == 42); + + // When you copy mutable variables, + // the capture of the copies shouldn't change: + let mut x = 2; + x = x + 1; + let z = x; + + // Add extra mutations to ensure we can mutate x without the + // captured z changing. + x = x + 1; + assert((|y| y + z)(1) == 4); + + // When you capture mutable variables, + // again, the captured variable doesn't change: + let closure_capturing_mutable = (|y| y + x); + assert(closure_capturing_mutable(1) == 5); + x += 1; + assert(closure_capturing_mutable(1) == 5); + + let ret = twice(add1, 3); + + test_array_functions(); + ret +} + +/// Test the array functions in std::array +fn test_array_functions() { + let myarray: [i32; 3] = [1, 2, 3]; + assert(myarray.any(|n| n > 2)); + + let evens: [i32; 3] = [2, 4, 6]; + assert(evens.all(|n| n > 1)); + + assert(evens.fold(0, |a, b| a + b) == 12); + assert(evens.reduce(|a, b| a + b) == 12); + + // TODO: is this a sort_via issue with the new backend, + // or something more general? + // + // currently it fails only with `--experimental-ssa` with + // "not yet implemented: Cast into signed" + // but it worked with the original ssa backend + // (before dropping it) + // + // opened #2121 for it + // https://github.com/noir-lang/noir/issues/2121 + + // let descending = myarray.sort_via(|a, b| a > b); + // assert(descending == [3, 2, 1]); + + assert(evens.map(|n| n / 2) == myarray); +} + +fn foo() -> [u32; 2] { + [1, 3] +} + +fn bar() -> [u32; 2] { + [3, 2] +} + +fn add1(x: Field) -> Field { + x + 1 +} + +fn twice(f: fn(Field) -> Field, x: Field) -> Field { + f(f(x)) +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json new file mode 100644 index 00000000000..c1233b8160b --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[],"param_witnesses":{},"return_type":null,"return_witnesses":[]},"bytecode":[155,194,56,97,194,4,0],"proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json new file mode 100644 index 00000000000..8d7a1566313 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[{"name":"x","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"y","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"z","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"}],"param_witnesses":{"x":[1],"y":[2],"z":[3]},"return_type":null,"return_witnesses":[]},"bytecode":"H4sIAAAAAAAA/9WUTW6DMBSEJ/yFhoY26bYLjoAxBLPrVYpK7n+EgmoHamWXeShYQsYSvJ+Z9/kDwCf+1m58ArsXi3PgnUN7dt/u7P9fdi8fW8rlATduCW89GFe5l2iMES90YBd+EyTyjIjtGYIm+HF1eanroa0GpdV3WXW9acq66S9GGdWY5qcyWg+mNm3Xd23ZqVoP6tp0+moDJ5AxNOTUWdk6VUTsOSb6wtRPCuDYziaZAzGA92OMFCsAPCUqMAOcQg5gZwIb4BdsA+A9seeU6AtTPymAUzubZA7EAD6MMTKsAPCUqMAMcAY5gJ0JbIBfsQ2AD8SeM6IvTP2kAM7sbJI5EAP4OMbIsQLAU6ICM8A55AB2JrABfsM2AD4Se86Jvjy5freeQ2LPObGud6J+Ce5ADz6LzJqX9Z4W75HdgzszkQj0BC+Pr6PohSpl0kkg7hm84Zfq+8z36N/l9OyaLtcv2EfpKJUUAAA=","proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr b/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr new file mode 100644 index 0000000000000000000000000000000000000000..a539f87a55498eeaff3e546ac9126cea0091fa70 GIT binary patch literal 112 zcmV-$0FVD4iwFP!00002|E<$W3cw%?h2hTg=t&aVF5LAhrT4#sir&CKAZGQE2Z Field { + x + 1 +} + +fn ret_fn() -> fn(Field) -> Field { + f +} + +// TODO: in the advanced implicitly generic function with closures branch +// which would support higher-order functions in a better way +// support returning closures: +// +// fn ret_closure() -> fn(Field) -> Field { +// let y = 1; +// let inner_closure = |z| -> Field{ +// z + y +// }; +// inner_closure +// } + +fn ret_lambda() -> fn(Field) -> Field { + let cl = |z: Field| -> Field { + z + 1 + }; + cl +} + +fn main(x : Field) { + let result_fn = ret_fn(); + assert(result_fn(x) == x + 1); + + // let result_closure = ret_closure(); + // assert(result_closure(x) == x + 1); + + let result_lambda = ret_lambda(); + assert(result_lambda(x) == x + 1); +} diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs index 3e0bbff2a83..c3578e5ee7e 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -218,7 +218,7 @@ impl<'a> FunctionContext<'a> { } ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"), ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"), - ast::Type::Function(_, _) => Type::Function, + ast::Type::Function(_, _, _) => Type::Function, ast::Type::Slice(element) => { let element_types = Self::convert_type(element).flatten(); Type::Slice(Rc::new(element_types)) diff --git a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs index 76fbea289be..2beebf6871c 100644 --- a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -12,8 +12,8 @@ use crate::hir::type_check::{type_check_func, TypeChecker}; use crate::hir::Context; use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TypeAliasId}; use crate::{ - ExpressionKind, Generics, Ident, LetStatement, NoirFunction, NoirStruct, NoirTypeAlias, - ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, Literal, + ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct, + NoirTypeAlias, ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, }; use fm::FileId; use iter_extended::vecmap; diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 8b4f97dbd8e..681c853899f 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -12,10 +12,10 @@ // // XXX: Resolver does not check for unused functions use crate::hir_def::expr::{ - HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, - HirConstructorExpression, HirExpression, HirForExpression, HirIdent, HirIfExpression, - HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, - HirMethodCallExpression, HirPrefixExpression, + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, + HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, + HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }; use crate::token::Attribute; use regex::Regex; @@ -58,6 +58,13 @@ type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; +pub struct LambdaContext { + captures: Vec, + /// the index in the scope tree + /// (sometimes being filled by ScopeTree's find method) + scope_index: usize, +} + /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 /// definition in scope, and to convert the AST into the HIR. /// @@ -81,12 +88,10 @@ pub struct Resolver<'a> { /// were declared in. generics: Vec<(Rc, TypeVariable, Span)>, - /// Lambdas share the function scope of the function they're defined in, - /// so to identify whether they use any variables from the parent function - /// we keep track of the scope index a variable is declared in. When a lambda - /// is declared we push a scope and set this lambda_index to the scope index. - /// Any variable from a scope less than that must be from the parent function. - lambda_index: usize, + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, } /// ResolverMetas are tagged onto each definition to track how many times they are used @@ -112,7 +117,7 @@ impl<'a> Resolver<'a> { self_type: None, generics: Vec::new(), errors: Vec::new(), - lambda_index: 0, + lambda_stack: Vec::new(), file, } } @@ -125,10 +130,6 @@ impl<'a> Resolver<'a> { self.errors.push(err); } - fn current_lambda_index(&self) -> usize { - self.scopes.current_scope_index() - } - /// Resolving a function involves interning the metadata /// interning any statements inside of the function /// and interning the function itself @@ -279,25 +280,25 @@ impl<'a> Resolver<'a> { // // If a variable is not found, then an error is logged and a dummy id // is returned, for better error reporting UX - fn find_variable_or_default(&mut self, name: &Ident) -> HirIdent { + fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { self.find_variable(name).unwrap_or_else(|error| { self.push_err(error); let id = DefinitionId::dummy_id(); let location = Location::new(name.span(), self.file); - HirIdent { location, id } + (HirIdent { location, id }, 0) }) } - fn find_variable(&mut self, name: &Ident) -> Result { + fn find_variable(&mut self, name: &Ident) -> Result<(HirIdent, usize), ResolverError> { // Find the definition for this Ident let scope_tree = self.scopes.current_scope_tree(); let variable = scope_tree.find(&name.0.contents); let location = Location::new(name.span(), self.file); - if let Some((variable_found, _)) = variable { + if let Some((variable_found, scope)) = variable { variable_found.num_times_used += 1; let id = variable_found.ident.id; - Ok(HirIdent { location, id }) + Ok((HirIdent { location, id }, scope)) } else { Err(ResolverError::VariableNotDeclared { name: name.0.contents.clone(), @@ -363,7 +364,8 @@ impl<'a> Resolver<'a> { UnresolvedType::Function(args, ret) => { let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); - Type::Function(args, ret) + let env = Box::new(Type::Unit); + Type::Function(args, ret, env) } UnresolvedType::MutableReference(element) => { Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) @@ -517,24 +519,24 @@ impl<'a> Resolver<'a> { } } - fn get_ident_from_path(&mut self, path: Path) -> HirIdent { + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { let location = Location::new(path.span(), self.file); let error = match path.as_ident().map(|ident| self.find_variable(ident)) { - Some(Ok(ident)) => return ident, + Some(Ok(found)) => return found, // Try to look it up as a global, but still issue the first error if we fail Some(Err(error)) => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(_) => error, }, None => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(error) => error, }, }; self.push_err(error); let id = DefinitionId::dummy_id(); - HirIdent { location, id } + (HirIdent { location, id }, 0) } /// Translates an UnresolvedType to a Type @@ -705,7 +707,7 @@ impl<'a> Resolver<'a> { }); } - let mut typ = Type::Function(parameter_types, return_type); + let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit)); if !generics.is_empty() { typ = Type::Forall(generics, Box::new(typ)); @@ -837,12 +839,14 @@ impl<'a> Resolver<'a> { Self::find_numeric_generics_in_type(field, found); } } - Type::Function(parameters, return_type) => { + + Type::Function(parameters, return_type, _env) => { for parameter in parameters { Self::find_numeric_generics_in_type(parameter, found); } Self::find_numeric_generics_in_type(return_type, found); } + Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { if let Type::NamedGeneric(type_variable, name) = generic { @@ -915,7 +919,7 @@ impl<'a> Resolver<'a> { fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { match lvalue { LValue::Ident(ident) => { - HirLValue::Ident(self.find_variable_or_default(&ident), Type::Error) + HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error) } LValue::MemberAccess { object, field_name } => { let object = Box::new(self.resolve_lvalue(*object)); @@ -933,6 +937,39 @@ impl<'a> Resolver<'a> { } } + fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let pos = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if pos.is_none() { + self.lambda_stack[lambda_index] + .captures + .push(HirCapturedVar { ident: hir_ident, transitive_capture_index }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(pos.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )); + } + } + } + } + pub fn resolve_expression(&mut self, expr: Expression) -> ExprId { let hir_expr = match expr.kind { ExpressionKind::Literal(literal) => HirExpression::Literal(match literal { @@ -965,7 +1002,20 @@ impl<'a> Resolver<'a> { // Otherwise, then it is referring to an Identifier // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; // If the expression is a singular indent, we search the resolver's current scope as normal. - let hir_ident = self.get_ident_from_path(path); + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(_) => {} + DefinitionKind::Global(_) => {} + DefinitionKind::GenericType(_) => {} + // We ignore the above definition kinds because only local variables can be captured by closures. + DefinitionKind::Local(_) => { + self.resolve_local_variable(hir_ident, var_scope_index); + } + } + } + HirExpression::Ident(hir_ident) } ExpressionKind::Prefix(prefix) => { @@ -1087,8 +1137,9 @@ impl<'a> Resolver<'a> { // We must stay in the same function scope as the parent function to allow for closures // to capture variables. This is currently limited to immutable variables. ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| { - let new_index = this.current_lambda_index(); - let old_index = std::mem::replace(&mut this.lambda_index, new_index); + let scope_index = this.scopes.current_scope_index(); + + this.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let parameters = vecmap(lambda.parameters, |(pattern, typ)| { let parameter = DefinitionKind::Local(None); @@ -1098,8 +1149,14 @@ impl<'a> Resolver<'a> { let return_type = this.resolve_inferred_type(lambda.return_type); let body = this.resolve_expression(lambda.body); - this.lambda_index = old_index; - HirExpression::Lambda(HirLambda { parameters, return_type, body }) + let lambda_context = this.lambda_stack.pop().unwrap(); + + HirExpression::Lambda(HirLambda { + parameters, + return_type, + body, + captures: lambda_context.captures, + }) }), }; @@ -1411,6 +1468,7 @@ pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result< #[cfg(test)] mod test { + use core::panic; use std::collections::HashMap; use fm::FileId; @@ -1419,10 +1477,14 @@ mod test { use crate::hir::def_map::{ModuleData, ModuleId, ModuleOrigin}; use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::import::PathResolutionError; + use crate::hir::resolution::resolver::StmtId; use crate::graph::CrateId; + use crate::hir_def::expr::HirExpression; use crate::hir_def::function::HirFunction; + use crate::hir_def::stmt::HirStatement; use crate::node_interner::{FuncId, NodeInterner}; + use crate::ParsedModule; use crate::{ hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId}, parse_program, Path, @@ -1432,29 +1494,24 @@ mod test { // func_namespace is used to emulate the fact that functions can be imported // and functions can be forward declared - fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + fn init_src_code_resolution( + src: &str, + ) -> (ParsedModule, NodeInterner, HashMap, FileId, TestPathResolver) { let (program, errors) = parse_program(src); - assert!(errors.is_empty()); - - let mut interner = NodeInterner::default(); - - let func_ids = vecmap(&func_namespace, |name| { - let id = interner.push_fn(HirFunction::empty()); - interner.push_function_definition(name.to_string(), id); - id - }); - - let mut path_resolver = TestPathResolver(HashMap::new()); - for (name, id) in func_namespace.into_iter().zip(func_ids) { - path_resolver.insert_func(name.to_owned(), id); + if !errors.is_empty() { + panic!("Unexpected parse errors in test code: {:?}", errors); } + let interner: NodeInterner = NodeInterner::default(); + let mut def_maps: HashMap = HashMap::new(); let file = FileId::default(); let mut modules = arena::Arena::new(); modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + let path_resolver = TestPathResolver(HashMap::new()); + def_maps.insert( CrateId::dummy_id(), CrateDefMap { @@ -1465,10 +1522,30 @@ mod test { }, ); + (program, interner, def_maps, file, path_resolver) + } + + // func_namespace is used to emulate the fact that functions can be imported + // and functions can be forward declared + fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.to_string(), id); + id + }); + + for (name, id) in func_namespace.into_iter().zip(func_ids) { + path_resolver.insert_func(name.to_owned(), id); + } + let mut errors = Vec::new(); for func in program.functions { let id = interner.push_fn(HirFunction::empty()); interner.push_function_definition(func.name().to_string(), id); + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); let (_, _, err) = resolver.resolve_function(func, id, ModuleId::dummy_id()); errors.extend(err); @@ -1477,6 +1554,81 @@ mod test { errors } + fn get_program_captures(src: &str) -> Vec> { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let mut all_captures: Vec> = Vec::new(); + for func in program.functions { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(func.name().clone().to_string(), id); + path_resolver.insert_func(func.name().to_owned(), id); + + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, _, _) = resolver.resolve_function(func, id, ModuleId::dummy_id()); + + // Iterate over function statements and apply filtering function + parse_statement_blocks( + hir_func.block(&interner).statements(), + &interner, + &mut all_captures, + ); + } + all_captures + } + + fn parse_statement_blocks( + stmts: &[StmtId], + interner: &NodeInterner, + result: &mut Vec>, + ) { + let mut expr: HirExpression; + + for stmt_id in stmts.iter() { + let hir_stmt = interner.statement(stmt_id); + match hir_stmt { + HirStatement::Expression(expr_id) => { + expr = interner.expression(&expr_id); + } + HirStatement::Let(let_stmt) => { + expr = interner.expression(&let_stmt.expression); + } + HirStatement::Assign(assign_stmt) => { + expr = interner.expression(&assign_stmt.expression); + } + HirStatement::Constrain(constr_stmt) => { + expr = interner.expression(&constr_stmt.0); + } + HirStatement::Semi(semi_expr) => { + expr = interner.expression(&semi_expr); + } + HirStatement::Error => panic!("Invalid HirStatement!"), + } + get_lambda_captures(expr, &interner, result); // TODO: dyn filter function as parameter + } + } + + fn get_lambda_captures( + expr: HirExpression, + interner: &NodeInterner, + result: &mut Vec>, + ) { + if let HirExpression::Lambda(lambda_expr) = expr { + let mut cur_capture = Vec::new(); + + for capture in lambda_expr.captures.iter() { + cur_capture.push(interner.definition(capture.ident.id).name.clone()); + } + result.push(cur_capture); + + // Check for other captures recursively within the lambda body + let hir_body_expr = interner.expression(&lambda_expr.body); + if let HirExpression::Block(block_expr) = hir_body_expr.clone() { + parse_statement_blocks(block_expr.statements(), interner, result); + } + } + } + #[test] fn resolve_empty_function() { let src = " @@ -1656,9 +1808,103 @@ mod test { x } "#; + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + } + + #[test] + fn resolve_basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + panic!("Unexpected errors: {:?}", errors); + } + } + + #[test] + fn resolve_simplified_closure() { + // based on bug https://github.com/noir-lang/noir/issues/1088 + + let src = r#"fn do_closure(x: Field) -> Field { + let y = x; + let ret_capture = || { + y + }; + ret_capture() + } + + fn main(x: Field) { + assert(do_closure(x) == 100); + } + + "#; + let parsed_captures = get_program_captures(src); + let mut expected_captures = vec![]; + expected_captures.push(vec!["y".to_string()]); + assert_eq!(expected_captures, parsed_captures); + } + + #[test] + fn resolve_complex_closures() { + let src = r#" + fn main(x: Field) -> pub Field { + let closure_without_captures = |x| x + x; + let a = closure_without_captures(1); + + let closure_capturing_a_param = |y| y + x; + let b = closure_capturing_a_param(2); + + let closure_capturing_a_local_var = |y| y + b; + let c = closure_capturing_a_local_var(3); + + let closure_with_transitive_captures = |y| { + let d = 5; + let nested_closure = |z| { + let doubly_nested_closure = |w| w + x + b; + a + z + y + d + x + doubly_nested_closure(4) + x + y + }; + let res = nested_closure(5); + res + }; + + a + b + c + closure_with_transitive_captures(6) + } + "#; let errors = resolve_src_code(src, vec!["main", "foo"]); assert!(errors.is_empty()); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + + let expected_captures = vec![ + vec![], + vec!["x".to_string()], + vec!["b".to_string()], + vec!["x".to_string(), "b".to_string(), "a".to_string()], + vec![ + "x".to_string(), + "b".to_string(), + "a".to_string(), + "y".to_string(), + "d".to_string(), + ], + vec!["x".to_string(), "b".to_string()], + ]; + + let parsed_captures = get_program_captures(src); + + assert_eq!(expected_captures, parsed_captures); } #[test] @@ -1694,6 +1940,9 @@ mod test { } } + // possible TODO: Create a more sophisticated set of search functions over the HIR, so we can check + // that the correct variables are captured in each closure + fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 24ac5f3443e..6c111a1d6a0 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -279,6 +279,12 @@ impl<'interner> TypeChecker<'interner> { Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) } HirExpression::Lambda(lambda) => { + let captured_vars = + vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); + + let env_type: Type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; + let params = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); typ @@ -294,7 +300,8 @@ impl<'interner> TypeChecker<'interner> { expr_span: span, } }); - Type::Function(params, Box::new(lambda.return_type)) + + Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } }; @@ -319,9 +326,9 @@ impl<'interner> TypeChecker<'interner> { argument_types: &mut [(Type, ExprId, noirc_errors::Span)], ) { let expected_object_type = match function_type { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), typ => unreachable!("Unexpected type for function: {typ}"), }, typ => unreachable!("Unexpected type for function: {typ}"), @@ -870,6 +877,35 @@ impl<'interner> TypeChecker<'interner> { } } + fn bind_function_type_impl( + &mut self, + fn_params: &Vec, + fn_ret: &Type, + callsite_args: &Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + if fn_params.len() != callsite_args.len() { + self.errors.push(TypeCheckError::ParameterCountMismatch { + expected: fn_params.len(), + found: callsite_args.len(), + span, + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { + arg.make_subtype_of(param, *arg_span, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + } + }); + } + + fn_ret.clone() + } + fn bind_function_type( &mut self, function: Type, @@ -886,38 +922,17 @@ impl<'interner> TypeChecker<'interner> { let ret = self.interner.next_type_variable(); let args = vecmap(args, |(arg, _, _)| arg); - let expected = Type::Function(args, Box::new(ret.clone())); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); if let Err(error) = binding.borrow_mut().bind_to(expected, span) { self.errors.push(error); } ret } - Type::Function(parameters, ret) => { - if parameters.len() != args.len() { - self.errors.push(TypeCheckError::ParameterCountMismatch { - expected: parameters.len(), - found: args.len(), - span, - }); - return Type::Error; - } - - for (param, (arg, arg_id, arg_span)) in parameters.iter().zip(args) { - arg.make_subtype_with_coercions( - param, - arg_id, - self.interner, - &mut self.errors, - || TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: arg_span, - }, - ); - } - - *ret + Type::Function(parameters, ret, _env) => { + // ignoring env for subtype on purpose + self.bind_function_type_impl(parameters.as_ref(), ret.as_ref(), args.as_ref(), span) } Type::Error => Type::Error, found => { diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index 26d0e36abf9..1883c0abf62 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -152,6 +152,7 @@ impl<'interner> TypeChecker<'interner> { #[cfg(test)] mod test { use std::collections::HashMap; + use std::vec; use fm::FileId; use iter_extended::vecmap; @@ -245,7 +246,11 @@ mod test { contract_function_type: None, is_internal: None, is_unconstrained: false, - typ: Type::Function(vec![Type::field(None), Type::field(None)], Box::new(Type::Unit)), + typ: Type::Function( + vec![Type::field(None), Type::field(None)], + Box::new(Type::Unit), + Box::new(Type::Unit), + ), parameters: vec![ Param(Identifier(x), Type::field(None), noirc_abi::AbiVisibility::Private), Param(Identifier(y), Type::field(None), noirc_abi::AbiVisibility::Private), @@ -314,7 +319,29 @@ mod test { type_check_src_code(src, vec![String::from("main"), String::from("foo")]); } + #[test] + fn basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + } + #[test] + fn closure_with_no_args() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = || x; + closure() + } + "#; + + type_check_src_code(src, vec![String::from("main")]); + } // This is the same Stub that is in the resolver, maybe we can pull this out into a test module and re-use? struct TestPathResolver(HashMap); diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index db7db0a803d..fd980328f5f 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -197,9 +197,25 @@ impl HirBlockExpression { } } +/// A variable captured inside a closure +#[derive(Debug, Clone)] +pub struct HirCapturedVar { + pub ident: HirIdent, + + /// This will be None when the capture refers to a local variable declared + /// in the same scope as the closure. In a closure-inside-another-closure + /// scenarios, we might have a transitive captures of variables that must + /// be propagated during the construction of each closure. In this case, + /// we store the index of the captured variable in the environment of our + /// direct parent closure. We do this in order to simplify the HIR to AST + /// transformation in the monomorphization pass. + pub transitive_capture_index: Option, +} + #[derive(Debug, Clone)] pub struct HirLambda { pub parameters: Vec<(HirPattern, Type)>, pub return_type: Type, pub body: ExprId, + pub captures: Vec, } diff --git a/crates/noirc_frontend/src/hir_def/function.rs b/crates/noirc_frontend/src/hir_def/function.rs index a69e8bb08b5..225731626f0 100644 --- a/crates/noirc_frontend/src/hir_def/function.rs +++ b/crates/noirc_frontend/src/hir_def/function.rs @@ -180,9 +180,9 @@ impl FuncMeta { /// Gives the (uninstantiated) return type of this function. pub fn return_type(&self) -> &Type { match &self.typ { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, _ => unreachable!(), }, _ => unreachable!(), diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index ff0a4e53fae..d77b8033ba1 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -70,8 +70,11 @@ pub enum Type { /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. NamedGeneric(TypeVariable, Rc), - /// A functions with arguments, and a return type. - Function(Vec, Box), + /// A functions with arguments, a return type and environment. + /// the environment should be `Unit` by default, + /// for closures it should contain a `Tuple` type with the captured + /// variable types. + Function(Vec, Box, Box), /// &mut T MutableReference(Box), @@ -697,9 +700,10 @@ impl Type { Type::Tuple(fields) => { fields.iter().any(|field| field.contains_numeric_typevar(target_id)) } - Type::Function(parameters, return_type) => { + Type::Function(parameters, return_type, env) => { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) + || env.contains_numeric_typevar(target_id) } Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { @@ -797,9 +801,15 @@ impl std::fmt::Display for Type { let typevars = vecmap(typevars, |(var, _)| var.to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) } - Type::Function(args, ret) => { - let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) + Type::Function(args, ret, env) => { + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), + }; + + let args = vecmap(args.iter(), ToString::to_string); + + write!(f, "fn({}) -> {ret}{closure_env_text}", args.join(", ")) } Type::MutableReference(element) => { write!(f, "&mut {element}") @@ -1196,9 +1206,9 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { if params_a.len() == params_b.len() { - for (a, b) in params_a.iter().zip(params_b) { + for (a, b) in params_a.iter().zip(params_b.iter()) { a.try_unify(b, span)?; } @@ -1403,7 +1413,7 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b) { a.is_subtype_of(b, span)?; @@ -1505,7 +1515,7 @@ impl Type { Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), - Type::Function(_, _) => unreachable!(), + Type::Function(_, _, _) => unreachable!(), Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), Type::NotConstant => unreachable!(), } @@ -1620,10 +1630,11 @@ impl Type { let typ = Box::new(typ.substitute(type_bindings)); Type::Forall(typevars.clone(), typ) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, |arg| arg.substitute(type_bindings)); let ret = Box::new(ret.substitute(type_bindings)); - Type::Function(args, ret) + let env = Box::new(env.substitute(type_bindings)); + Type::Function(args, ret, env) } Type::MutableReference(element) => { Type::MutableReference(Box::new(element.substitute(type_bindings))) @@ -1660,8 +1671,10 @@ impl Type { Type::Forall(typevars, typ) => { !typevars.iter().any(|(id, _)| *id == target_id) && typ.occurs(target_id) } - Type::Function(args, ret) => { - args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) + Type::Function(args, ret, env) => { + args.iter().any(|arg| arg.occurs(target_id)) + || ret.occurs(target_id) + || env.occurs(target_id) } Type::MutableReference(element) => element.occurs(target_id), @@ -1706,11 +1719,13 @@ impl Type { self.clone() } - Function(args, ret) => { + Function(args, ret, env) => { let args = vecmap(args, |arg| arg.follow_bindings()); let ret = Box::new(ret.follow_bindings()); - Function(args, ret) + let env = Box::new(env.follow_bindings()); + Function(args, ret, env) } + MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), // Expect that this function should only be called on instantiated types @@ -1751,7 +1766,10 @@ fn convert_array_expression_to_slice( interner.push_expr_location(func, location.span, location.file); interner.push_expr_type(&call, target_type.clone()); - interner.push_expr_type(&func, Type::Function(vec![array_type], Box::new(target_type))); + interner.push_expr_type( + &func, + Type::Function(vec![array_type], Box::new(target_type), Box::new(Type::Unit)), + ); } impl BinaryTypeOperator { diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 7ad05f09231..33c3bbebff4 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -29,7 +29,6 @@ pub enum Expression { Tuple(Vec), ExtractTupleField(Box, usize), Call(Call), - Let(Let), Constrain(Box, Location), Assign(Assign), @@ -103,6 +102,12 @@ pub struct Binary { pub location: Location, } +#[derive(Debug, Clone)] +pub struct Lambda { + pub function: Ident, + pub env: Ident, +} + #[derive(Debug, Clone)] pub struct If { pub condition: Box, @@ -213,7 +218,7 @@ pub enum Type { Tuple(Vec), Slice(Box), MutableReference(Box), - Function(/*args:*/ Vec, /*ret:*/ Box), + Function(/*args:*/ Vec, /*ret:*/ Box, /*env:*/ Box), } impl Type { @@ -324,9 +329,13 @@ impl std::fmt::Display for Type { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), + }; + write!(f, "fn({}) -> {}{}", args.join(", "), ret, closure_env_text) } Type::Slice(element) => write!(f, "[{element}"), Type::MutableReference(element) => write!(f, "&mut {element}"), diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index dbe2ee080bf..c8167baf6bb 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -19,6 +19,7 @@ use crate::{ expr::*, function::{FuncMeta, Param, Parameters}, stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, + types, }, node_interner::{self, DefinitionKind, NodeInterner, StmtId}, token::Attribute, @@ -30,6 +31,11 @@ use self::ast::{Definition, FuncId, Function, LocalId, Program}; pub mod ast; pub mod printer; +struct LambdaContext { + env_ident: Box, + captures: Vec, +} + /// The context struct for the monomorphization pass. /// /// This struct holds the FIFO queue of functions to monomorphize, which is added to @@ -58,6 +64,8 @@ struct Monomorphizer<'interner> { /// Used to reference existing definitions in the HIR interner: &'interner NodeInterner, + lambda_envs_stack: Vec, + next_local_id: u32, next_function_id: u32, } @@ -103,6 +111,7 @@ impl<'interner> Monomorphizer<'interner> { next_local_id: 0, next_function_id: 0, interner, + lambda_envs_stack: Vec::new(), } } @@ -348,7 +357,7 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Constructor(constructor) => self.constructor(constructor, expr), - HirExpression::Lambda(lambda) => self.lambda(lambda), + HirExpression::Lambda(lambda) => self.lambda(lambda, expr), HirExpression::MethodCall(_) => { unreachable!("Encountered HirExpression::MethodCall during monomorphization") @@ -541,6 +550,15 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Block(definitions) } + /// Find a captured variable in the innermost closure + fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option { + let ctx = self.lambda_envs_stack.last()?; + ctx.captures + .iter() + .position(|capture| capture.ident.id == id) + .map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index)) + } + /// A local (ie non-global) ident only fn local_ident(&mut self, ident: &HirIdent) -> Option { let definition = self.interner.definition(ident.id); @@ -564,14 +582,25 @@ impl<'interner> Monomorphizer<'interner> { let definition = self.lookup_function(*func_id, expr_id, &typ); let typ = Self::convert_type(&typ); - let ident = ast::Ident { location, mutable, definition, name, typ }; - ast::Expression::Ident(ident) + let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; + let ident_expression = ast::Expression::Ident(ident); + if self.is_function_closure_type(&typ) { + ast::Expression::Tuple(vec![ + ast::Expression::ExtractTupleField( + Box::new(ident_expression.clone()), + 0usize, + ), + ast::Expression::ExtractTupleField(Box::new(ident_expression), 1usize), + ]) + } else { + ident_expression + } } DefinitionKind::Global(expr_id) => self.expr(*expr_id), - DefinitionKind::Local(_) => { + DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { let ident = self.local_ident(&ident).unwrap(); ast::Expression::Ident(ident) - } + }), DefinitionKind::GenericType(type_variable) => { let value = match &*type_variable.borrow() { TypeBinding::Unbound(_) => { @@ -657,10 +686,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) } - HirType::Function(args, ret) => { + HirType::Function(args, ret, env) => { let args = vecmap(args, Self::convert_type); let ret = Box::new(Self::convert_type(ret)); - ast::Type::Function(args, ret) + let env = Box::new(Self::convert_type(env)); + ast::Type::Function(args, ret, env) } HirType::MutableReference(element) => { @@ -677,19 +707,44 @@ impl<'interner> Monomorphizer<'interner> { } } + fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool { + let t = Self::convert_type(&self.interner.id_type(raw_func_id)); + if self.is_function_closure_type(&t) { + true + } else if let ast::Type::Tuple(elements) = t { + if elements.len() == 2 { + matches!(elements[1], ast::Type::Function(_, _, _)) + } else { + false + } + } else { + false + } + } + + fn is_function_closure_type(&self, t: &ast::Type) -> bool { + if let ast::Type::Function(_, _, env) = t { + let e = (*env).clone(); + matches!(*e, ast::Type::Tuple(_captures)) + } else { + false + } + } + fn function_call( &mut self, call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let func = Box::new(self.expr(call.func)); + let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); + let func: Box; let return_type = self.interner.id_type(id); let return_type = Self::convert_type(&return_type); let location = call.location; - if let ast::Expression::Ident(ident) = func.as_ref() { + if let ast::Expression::Ident(ident) = original_func.as_ref() { if let Definition::Oracle(name) = &ident.definition { if name.as_str() == "println" { // Oracle calls are required to be wrapped in an unconstrained function @@ -699,12 +754,39 @@ impl<'interner> Monomorphizer<'interner> { } } - self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { - func, - arguments, - return_type, - location, - })) + let mut block_expressions = vec![]; + + let is_closure = self.is_function_closure(call.func); + if is_closure { + let extracted_func: ast::Expression; + let hir_call_func = self.interner.expression(&call.func); + if let HirExpression::Lambda(l) = hir_call_func { + let (setup, closure_variable) = self.lambda_with_setup(l, call.func); + block_expressions.push(setup); + extracted_func = closure_variable; + } else { + extracted_func = *original_func; + } + func = Box::new(ast::Expression::ExtractTupleField( + Box::new(extracted_func.clone()), + 1usize, + )); + let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + arguments.insert(0, env_argument); + } else { + func = original_func.clone(); + }; + + let call = self + .try_evaluate_call(&func, &return_type) + .unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location })); + + if !block_expressions.is_empty() { + block_expressions.push(call); + ast::Expression::Block(block_expressions) + } else { + call + } } /// Adds a function argument that contains type metadata that is required to tell @@ -914,7 +996,16 @@ impl<'interner> Monomorphizer<'interner> { } } - fn lambda(&mut self, lambda: HirLambda) -> ast::Expression { + fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { + if lambda.captures.is_empty() { + self.lambda_no_capture(lambda) + } else { + let (setup, closure_variable) = self.lambda_with_setup(lambda, expr); + ast::Expression::Block(vec![setup, closure_variable]) + } + } + + fn lambda_no_capture(&mut self, lambda: HirLambda) -> ast::Expression { let ret_type = Self::convert_type(&lambda.return_type); let lambda_name = "lambda"; let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); @@ -935,7 +1026,8 @@ impl<'interner> Monomorphizer<'interner> { let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; self.push_function(id, function); - let typ = ast::Type::Function(parameter_types, Box::new(ret_type)); + let typ = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(ast::Type::Unit)); let name = lambda_name.to_owned(); ast::Expression::Ident(ast::Ident { @@ -947,6 +1039,133 @@ impl<'interner> Monomorphizer<'interner> { }) } + fn lambda_with_setup( + &mut self, + lambda: HirLambda, + expr: node_interner::ExprId, + ) -> (ast::Expression, ast::Expression) { + // returns (, ) + // which can be used directly in callsites or transformed + // directly to a single `Expression` + // for other cases by `lambda` which is called by `expr` + // + // it solves the problem of detecting special cases where + // we call something like + // `{let env$.. = ..;}.1({let env$.. = ..;}.0, ..)` + // which was leading to redefinition errors + // + // instead of detecting and extracting + // patterns in the resulting tree, + // which seems more fragile, we directly reuse the return parameters + // of this function in those cases + let ret_type = Self::convert_type(&lambda.return_type); + let lambda_name = "lambda"; + let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); + + // Manually convert to Parameters type so we can reuse the self.parameters method + let parameters = Parameters(vecmap(lambda.parameters, |(pattern, typ)| { + Param(pattern, typ, noirc_abi::AbiVisibility::Private) + })); + + let mut converted_parameters = self.parameters(parameters); + + let id = self.next_function_id(); + let name = lambda_name.to_owned(); + let return_type = ret_type.clone(); + + let env_local_id = self.next_local_id(); + let env_name = "env"; + let env_tuple = ast::Expression::Tuple(vecmap(&lambda.captures, |capture| { + match capture.transitive_capture_index { + Some(field_index) => match self.lambda_envs_stack.last() { + Some(lambda_ctx) => ast::Expression::ExtractTupleField( + lambda_ctx.env_ident.clone(), + field_index, + ), + None => unreachable!( + "Expected to find a parent closure environment, but found none" + ), + }, + None => { + let ident = self.local_ident(&capture.ident).unwrap(); + ast::Expression::Ident(ident) + } + } + })); + let expr_type = self.interner.id_type(expr); + let env_typ = if let types::Type::Function(_, _, function_env_type) = expr_type { + Self::convert_type(&function_env_type) + } else { + unreachable!("expected a Function type for a Lambda node") + }; + + let env_let_stmt = ast::Expression::Let(ast::Let { + id: env_local_id, + mutable: false, + name: env_name.to_string(), + expression: Box::new(env_tuple), + }); + + let location = None; // TODO: This should match the location of the lambda expression + let mutable = false; + let definition = Definition::Local(env_local_id); + + let env_ident = ast::Expression::Ident(ast::Ident { + location, + mutable, + definition, + name: env_name.to_string(), + typ: env_typ.clone(), + }); + + self.lambda_envs_stack.push(LambdaContext { + env_ident: Box::new(env_ident.clone()), + captures: lambda.captures, + }); + let body = self.expr(lambda.body); + self.lambda_envs_stack.pop(); + + let lambda_fn_typ: ast::Type = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(env_typ.clone())); + let lambda_fn = ast::Expression::Ident(ast::Ident { + definition: Definition::Function(id), + mutable: false, + location: None, // TODO: This should match the location of the lambda expression + name: name.clone(), + typ: lambda_fn_typ.clone(), + }); + + let mut parameters = vec![]; + parameters.push((env_local_id, true, env_name.to_string(), env_typ.clone())); + parameters.append(&mut converted_parameters); + + let unconstrained = false; + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); + let block_local_id = self.next_local_id(); + let block_ident_name = "closure_variable"; + let block_let_stmt = ast::Expression::Let(ast::Let { + id: block_local_id, + mutable: false, + name: block_ident_name.to_string(), + expression: Box::new(ast::Expression::Block(vec![env_let_stmt, lambda_value])), + }); + + let closure_definition = Definition::Local(block_local_id); + + let closure_ident = ast::Expression::Ident(ast::Ident { + location, + mutable: false, + definition: closure_definition, + name: block_ident_name.to_string(), + typ: ast::Type::Tuple(vec![env_typ, lambda_fn_typ]), + }); + + (block_let_stmt, closure_ident) + } + /// Implements std::unsafe::zeroed by returning an appropriate zeroed /// ast literal or collection node for the given type. Note that for functions /// there is no obvious zeroed value so this should be considered unsafe to use. @@ -984,8 +1203,8 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) => { ast::Expression::Tuple(vecmap(fields, |field| self.zeroed_value_of_type(field))) } - ast::Type::Function(parameter_types, ret_type) => { - self.create_zeroed_function(parameter_types, ret_type) + ast::Type::Function(parameter_types, ret_type, env) => { + self.create_zeroed_function(parameter_types, ret_type, env) } ast::Type::Slice(element_type) => { ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { @@ -1012,6 +1231,7 @@ impl<'interner> Monomorphizer<'interner> { &mut self, parameter_types: &[ast::Type], ret_type: &ast::Type, + env_type: &ast::Type, ) -> ast::Expression { let lambda_name = "zeroed_lambda"; @@ -1034,7 +1254,11 @@ impl<'interner> Monomorphizer<'interner> { mutable: false, location: None, name: lambda_name.to_owned(), - typ: ast::Type::Function(parameter_types.to_owned(), Box::new(ret_type.clone())), + typ: ast::Type::Function( + parameter_types.to_owned(), + Box::new(ret_type.clone()), + Box::new(env_type.clone()), + ), }) } } @@ -1072,3 +1296,167 @@ fn undo_instantiation_bindings(bindings: TypeBindings) { *var.borrow_mut() = TypeBinding::Unbound(id); } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use fm::FileId; + use iter_extended::vecmap; + + use crate::{ + graph::CrateId, + hir::{ + def_map::{ + CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId, ModuleOrigin, + }, + resolution::{ + import::PathResolutionError, path_resolver::PathResolver, resolver::Resolver, + }, + }, + hir_def::function::HirFunction, + node_interner::{FuncId, NodeInterner}, + parse_program, + }; + + use super::monomorphize; + + // TODO: refactor into a more general test utility? + // mostly copied from hir / type_check / mod.rs and adapted a bit + fn type_check_src_code(src: &str, func_namespace: Vec) -> (FuncId, NodeInterner) { + let (program, errors) = parse_program(src); + let mut interner = NodeInterner::default(); + + // Using assert_eq here instead of assert(errors.is_empty()) displays + // the whole vec if the assert fails rather than just two booleans + assert_eq!(errors, vec![]); + + let main_id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition("main".into(), main_id); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.into(), id); + id + }); + + let mut path_resolver = TestPathResolver(HashMap::new()); + for (name, id) in func_namespace.into_iter().zip(func_ids.clone()) { + path_resolver.insert_func(name.to_owned(), id); + } + + let mut def_maps: HashMap = HashMap::new(); + let file = FileId::default(); + + let mut modules = arena::Arena::new(); + modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + + def_maps.insert( + CrateId::dummy_id(), + CrateDefMap { + root: path_resolver.local_module_id(), + modules, + krate: CrateId::dummy_id(), + extern_prelude: HashMap::new(), + }, + ); + + let func_meta = vecmap(program.functions, |nf| { + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, func_meta, _resolver_errors) = + resolver.resolve_function(nf, main_id, ModuleId::dummy_id()); + // TODO: not sure why, we do get an error here, + // but otherwise seem to get an ok monomorphization result + // assert_eq!(resolver_errors, vec![]); + (hir_func, func_meta) + }); + + println!("Before update_fn"); + + for ((hir_func, meta), func_id) in func_meta.into_iter().zip(func_ids.clone()) { + interner.update_fn(func_id, hir_func); + interner.push_fn_meta(meta, func_id); + } + + println!("Before type_check_func"); + + // Type check section + let errors = crate::hir::type_check::type_check_func( + &mut interner, + func_ids.first().cloned().unwrap(), + ); + assert_eq!(errors, vec![]); + (func_ids.first().cloned().unwrap(), interner) + } + + // TODO: refactor into a more general test utility? + // TestPathResolver struct and impls copied from hir / type_check / mod.rs + struct TestPathResolver(HashMap); + + impl PathResolver for TestPathResolver { + fn resolve( + &self, + _def_maps: &HashMap, + path: crate::Path, + ) -> Result { + // Not here that foo::bar and hello::foo::bar would fetch the same thing + let name = path.segments.last().unwrap(); + let mod_def = self.0.get(&name.0.contents).cloned(); + mod_def.ok_or_else(move || PathResolutionError::Unresolved(name.clone())) + } + + fn local_module_id(&self) -> LocalModuleId { + // This is not LocalModuleId::dummy since we need to use this to index into a Vec + // later and do not want to push u32::MAX number of elements before we do. + LocalModuleId(arena::Index::from_raw_parts(0, 0)) + } + + fn module_id(&self) -> ModuleId { + ModuleId { krate: CrateId::dummy_id(), local_id: self.local_module_id() } + } + } + + impl TestPathResolver { + fn insert_func(&mut self, name: String, func_id: FuncId) { + self.0.insert(name, func_id.into()); + } + } + + // a helper test method + // TODO: maybe just compare trimmed src/expected + // for easier formatting? + fn check_rewrite(src: &str, expected: &str) { + let (func, interner) = type_check_src_code(src, vec!["main".to_string()]); + let program = monomorphize(func, &interner); + // println!("[{}]", program); + assert!(format!("{}", program) == expected); + } + + #[test] + fn simple_closure_with_no_captured_variables() { + let src = r#" + fn main() -> Field { + let x = 1; + let closure = || x; + closure() + } + "#; + + let expected_rewrite = r#"fn main$f0() -> Field { + let x$0 = 1; + let closure$3 = { + let closure_variable$2 = { + let env$1 = (x$l0); + (env$l1, lambda$f1) + }; + closure_variable$l2 + }; + closure$l3.1(closure$l3.0) +} +fn lambda$f1(mut env$l1: (Field)) -> Field { + env$l1.0 +} +"#; + check_rewrite(src, expected_rewrite); + } +} diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index f5fea5c1ea7..6b3d2757c14 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -672,7 +672,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::String(_) => Some(String), Type::Unit => Some(Unit), Type::Tuple(_) => Some(Tuple), - Type::Function(_, _) => Some(Function), + Type::Function(_, _, _) => Some(Function), Type::MutableReference(element) => get_type_method_key(element), // We do not support adding methods to these types From 602168cac35ecb336c6fd23c002bcfd5bea96bfb Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Thu, 3 Aug 2023 08:55:30 +0100 Subject: [PATCH 8/9] chore: clippy fix (#2136) --- crates/noirc_evaluator/src/ssa/acir_gen/mod.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs index 25a0c2ee2e8..f473becd966 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -736,13 +736,11 @@ impl Context { ) -> Result { match self.convert_value(value_id, dfg) { AcirValue::Var(acir_var, _) => Ok(acir_var), - AcirValue::Array(array) => { - return Err(InternalError::UnExpected { - expected: "a numeric value".to_string(), - found: format!("{array:?}"), - location: self.acir_context.get_location(), - }) - } + AcirValue::Array(array) => Err(InternalError::UnExpected { + expected: "a numeric value".to_string(), + found: format!("{array:?}"), + location: self.acir_context.get_location(), + }), AcirValue::DynamicArray(_) => Err(InternalError::UnExpected { expected: "a numeric value".to_string(), found: "an array".to_string(), From 8e976ea2104153b428d9df6e06ab53051e2832a7 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Thu, 3 Aug 2023 10:12:07 +0200 Subject: [PATCH 9/9] chore: replace usage of `Directive::Quotient` with brillig opcode (#1766) * Use brillig instead of inverse directive * *wip* * Remove usage of quotient directive in favor of brillig * chore: improve import * chore: remove unnecessary brillig quotient * chore: remove unnecessary comment change * chore: comment change * chore: push clone further up call stack * chore: correct comment on `brillig_quotient` * chore: improve docs for `directive_quotient` * chore: ignore pseudocode entirely --------- Co-authored-by: TomAFrench --- .../brillig/brillig_gen/brillig_directive.rs | 60 +++++++++++++++- .../ssa/acir_gen/acir_ir/generated_acir.rs | 68 +++++++++---------- 2 files changed, 93 insertions(+), 35 deletions(-) diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs index 219a954a595..93e760f9737 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs @@ -1,4 +1,6 @@ -use acvm::acir::brillig::{BinaryFieldOp, Opcode as BrilligOpcode, RegisterIndex, Value}; +use acvm::acir::brillig::{ + BinaryFieldOp, BinaryIntOp, Opcode as BrilligOpcode, RegisterIndex, Value, +}; /// Generates brillig bytecode which computes the inverse of its input if not null, and zero else. pub(crate) fn directive_invert() -> Vec { @@ -29,3 +31,59 @@ pub(crate) fn directive_invert() -> Vec { BrilligOpcode::Stop, ] } + +/// Generates brillig bytecode which computes `a / b` and returns the quotient and remainder. +/// It returns `(0,0)` if the predicate is null. +/// +/// +/// This is equivalent to the Noir (psuedo)code +/// +/// ```ignore +/// fn quotient(a: T, b: T, predicate: bool) -> (T,T) { +/// if predicate != 0 { +/// (a/b, a-a/b*b) +/// } else { +/// (0,0) +/// } +/// } +/// ``` +pub(crate) fn directive_quotient(bit_size: u32) -> Vec { + // `a` is (0) (i.e register index 0) + // `b` is (1) + // `predicate` is (2) + vec![ + // If the predicate is zero, we jump to the exit segment + BrilligOpcode::JumpIfNot { condition: RegisterIndex::from(2), location: 6 }, + //q = a/b is set into register (3) + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::UnsignedDiv, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(3), + bit_size, + }, + //(1)= q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Mul, + lhs: RegisterIndex::from(3), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(1), + bit_size, + }, + //(1) = a-q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Sub, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(1), + bit_size, + }, + //(0) = q + BrilligOpcode::Mov { destination: RegisterIndex::from(0), source: RegisterIndex::from(3) }, + BrilligOpcode::Stop, + // Exit segment: we return 0,0 + BrilligOpcode::Const { destination: RegisterIndex::from(0), value: Value::from(0_usize) }, + BrilligOpcode::Const { destination: RegisterIndex::from(1), value: Value::from(0_usize) }, + BrilligOpcode::Stop, + ] +} diff --git a/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 738387fbaab..b425eab42d3 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -11,7 +11,7 @@ use acvm::acir::{ brillig::Opcode as BrilligOpcode, circuit::{ brillig::{Brillig as AcvmBrillig, BrilligInputs, BrilligOutputs}, - directives::{LogInfo, QuotientDirective}, + directives::LogInfo, opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode}, }, native_types::Witness, @@ -432,13 +432,13 @@ impl GeneratedAcir { } } - let (q_witness, r_witness) = self.quotient_directive( - lhs.clone(), - rhs.clone(), - Some(predicate.clone()), - max_q_bits, - max_rhs_bits, - )?; + let (q_witness, r_witness) = + self.brillig_quotient(lhs.clone(), rhs.clone(), predicate.clone(), max_bit_size + 1); + + // Apply range constraints to injected witness values. + // Constrains `q` to be 0 <= q < 2^{q_max_bits}, etc. + self.range_constraint(q_witness, max_q_bits)?; + self.range_constraint(r_witness, max_rhs_bits)?; // Constrain r < rhs self.bound_constraint_with_offset(&r_witness.into(), rhs, predicate, max_rhs_bits)?; @@ -457,6 +457,32 @@ impl GeneratedAcir { Ok((q_witness, r_witness)) } + /// Adds a brillig opcode which injects witnesses with values `q = a / b` and `r = a % b`. + /// + /// Suitable range constraints for `q` and `r` must be applied externally. + pub(crate) fn brillig_quotient( + &mut self, + lhs: Expression, + rhs: Expression, + predicate: Expression, + max_bit_size: u32, + ) -> (Witness, Witness) { + // Create the witness for the result + let q_witness = self.next_witness_index(); + let r_witness = self.next_witness_index(); + + let quotient_code = brillig_directive::directive_quotient(max_bit_size); + let inputs = vec![ + BrilligInputs::Single(lhs), + BrilligInputs::Single(rhs), + BrilligInputs::Single(predicate.clone()), + ]; + let outputs = vec![BrilligOutputs::Simple(q_witness), BrilligOutputs::Simple(r_witness)]; + self.brillig(Some(predicate), quotient_code, inputs, outputs); + + (q_witness, r_witness) + } + /// Generate constraints that are satisfied iff /// lhs < rhs , when offset is 1, or /// lhs <= rhs, when offset is 0 @@ -692,32 +718,6 @@ impl GeneratedAcir { Ok(()) } - /// Adds a directive which injects witnesses with values `q = a / b` and `r = a % b`. - /// - /// Suitable range constraints are also applied to `q` and `r`. - pub(crate) fn quotient_directive( - &mut self, - a: Expression, - b: Expression, - predicate: Option, - q_max_bits: u32, - r_max_bits: u32, - ) -> Result<(Witness, Witness), RuntimeError> { - let q_witness = self.next_witness_index(); - let r_witness = self.next_witness_index(); - - let directive = - Directive::Quotient(QuotientDirective { a, b, q: q_witness, r: r_witness, predicate }); - self.push_opcode(AcirOpcode::Directive(directive)); - - // Apply range constraints to injected witness values. - // Constrains `q` to be 0 <= q < 2^{q_max_bits}, etc. - self.range_constraint(q_witness, q_max_bits)?; - self.range_constraint(r_witness, r_max_bits)?; - - Ok((q_witness, r_witness)) - } - /// Returns a `Witness` that is constrained to be: /// - `1` if lhs >= rhs /// - `0` otherwise