From a0ce520a7b4abf92ff16c54f2e1401c8a292b2dd Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 08:49:25 +0000 Subject: [PATCH 01/12] feat: Runtime separation extracted from inliner --- acvm-repo/brillig_vm/src/memory.rs | 3 + compiler/noirc_evaluator/src/ssa.rs | 3 +- compiler/noirc_evaluator/src/ssa/ir/dfg.rs | 2 +- .../noirc_evaluator/src/ssa/ir/function.rs | 7 + compiler/noirc_evaluator/src/ssa/ir/map.rs | 2 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 185 +++++++++++++++--- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 + .../src/ssa/opt/runtime_separation.rs | 185 ++++++++++++++++++ .../src/ssa/ssa_gen/program.rs | 8 + tooling/debugger/ignored-tests.txt | 27 +-- 10 files changed, 377 insertions(+), 46 deletions(-) create mode 100644 compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs diff --git a/acvm-repo/brillig_vm/src/memory.rs b/acvm-repo/brillig_vm/src/memory.rs index feeb3706bde..90abbbec760 100644 --- a/acvm-repo/brillig_vm/src/memory.rs +++ b/acvm-repo/brillig_vm/src/memory.rs @@ -318,6 +318,9 @@ impl Memory { } pub fn read_slice(&self, addr: MemoryAddress, len: usize) -> &[MemoryValue] { + if len == 0 { + return &[]; + } &self.inner[addr.to_usize()..(addr.to_usize() + len)] } diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index c2fe7878bf8..00aeed75a98 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -55,8 +55,9 @@ pub(crate) fn optimize_into_acir( let ssa = SsaBuilder::new(program, print_passes, force_brillig_output, print_timings)? .run_pass(Ssa::defunctionalize, "After Defunctionalization:") .run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:") - .run_pass(Ssa::inline_functions, "After Inlining:") + .run_pass(Ssa::separate_runtime, "After Runtime Separation:") .run_pass(Ssa::resolve_is_unconstrained, "After Resolving IsUnconstrained:") + .run_pass(Ssa::inline_functions, "After Inlining:") // Run mem2reg with the CFG separated into blocks .run_pass(Ssa::mem2reg, "After Mem2Reg:") .run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization") diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 85630b75614..07598569935 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -22,7 +22,7 @@ use noirc_errors::Location; /// its blocks, instructions, and values. This struct is largely responsible for /// owning most data in a function and handing out Ids to this data that can be /// shared without worrying about ownership. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub(crate) struct DataFlowGraph { /// All of the instructions in a function instructions: DenseMap, diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index a49e02b0380..c44824b464b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -64,6 +64,13 @@ impl Function { Self { name, id, entry_block, dfg, runtime: RuntimeType::Acir(InlineType::default()) } } + /// Creates a new function as a clone of the one passed in with the passed in id. + pub(crate) fn clone_with_id(id: FunctionId, another: &Function) -> Self { + let dfg = another.dfg.clone(); + let entry_block = another.entry_block; + Self { name: another.name.clone(), id, entry_block, dfg, runtime: another.runtime } + } + /// The name of the function. /// Used exclusively for debugging purposes. pub(crate) fn name(&self) -> &str { diff --git a/compiler/noirc_evaluator/src/ssa/ir/map.rs b/compiler/noirc_evaluator/src/ssa/ir/map.rs index b6055973f1c..3c3feabc390 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/map.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/map.rs @@ -115,7 +115,7 @@ impl std::fmt::Display for Id { /// access to indices is provided. Since IDs must be stable and correspond /// to indices in the internal Vec, operations that would change element /// ordering like pop, remove, swap_remove, etc, are not possible. -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct DenseMap { storage: Vec, } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 73dc3888184..9b75712988e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -18,6 +18,7 @@ use crate::ssa::{ ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; +use im::HashSet as ImmutableHashSet; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -26,7 +27,7 @@ const RECURSION_LIMIT: u32 = 1000; impl Ssa { /// Inline all functions within the IR. /// - /// In the case of recursive functions, this will attempt + /// In the case of recursive Acir functions, this will attempt /// to recursively inline until the RECURSION_LIMIT is reached. /// /// Functions are recursively inlined into main until either we finish @@ -40,6 +41,8 @@ impl Ssa { /// There are some attributes that allow inlining a function at a different step of codegen. /// Currently this is just `InlineType::NoPredicates` for which we have a flag indicating /// whether treating that inline functions. The default is to treat these functions as entry points. + /// + /// This step should run after runtime separation, since it relies on the runtime of the called functions being final. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn inline_functions(self) -> Ssa { Self::inline_functions_inner(self, true) @@ -51,12 +54,17 @@ impl Ssa { } fn inline_functions_inner(mut self, no_predicates_is_entry_point: bool) -> Ssa { + let recursive_functions = find_all_recursive_functions(&self); self.functions = btree_map( - get_entry_point_functions(&self, no_predicates_is_entry_point), + get_functions_to_inline_into(&self, no_predicates_is_entry_point), |entry_point| { - let new_function = - InlineContext::new(&self, entry_point, no_predicates_is_entry_point) - .inline_all(&self); + let new_function = InlineContext::new( + &self, + entry_point, + no_predicates_is_entry_point, + recursive_functions.clone(), + ) + .inline_all(&self); (entry_point, new_function) }, ); @@ -79,6 +87,8 @@ struct InlineContext { entry_point: FunctionId, no_predicates_is_entry_point: bool, + // We keep track of the recursive functions in the SSA to avoid inlining them in a brillig context. + recursive_functions: BTreeSet, } /// The per-function inlining context contains information that is only valid for one function. @@ -112,28 +122,143 @@ struct PerFunctionContext<'function> { inlining_entry: bool, } -/// The entry point functions are each function we should inline into - and each function that -/// should be left in the final program. -/// This is the `main` function, any Acir functions with a [fold inline type][InlineType::Fold], -/// and any brillig functions used. -fn get_entry_point_functions( +/// Utility function to find out the direct calls of a function. +fn called_functions(func: &Function) -> BTreeSet { + let mut called_function_ids = BTreeSet::default(); + for block_id in func.reachable_blocks() { + for instruction_id in func.dfg[block_id].instructions() { + let Instruction::Call { func: called_value_id, .. } = &func.dfg[*instruction_id] else { + continue; + }; + + if let Value::Function(function_id) = func.dfg[*called_value_id] { + called_function_ids.insert(function_id); + } + } + } + + called_function_ids +} + +/// Recursively explore the SSA to find the functions where we enter a brillig context. +fn find_entry_points_to_brillig_context( + ssa: &Ssa, + current_function: FunctionId, + explored_functions: &mut HashSet, + entry_points: &mut BTreeSet, +) { + if !explored_functions.insert(current_function) { + return; + } + + let function = &ssa.functions[¤t_function]; + if function.runtime() == RuntimeType::Brillig { + entry_points.insert(current_function); + return; + } + + let called_functions = called_functions(function); + + for called_function in called_functions { + find_entry_points_to_brillig_context( + ssa, + called_function, + explored_functions, + entry_points, + ); + } +} + +fn find_all_entry_points_to_brillig_context(ssa: &Ssa) -> BTreeSet { + let mut explored_functions = HashSet::default(); + let mut entry_points = BTreeSet::default(); + find_entry_points_to_brillig_context( + ssa, + ssa.main_id, + &mut explored_functions, + &mut entry_points, + ); + entry_points +} + +// Recursively explore the SSA to find the functions that end up calling themselves +fn find_recursive_functions( + ssa: &Ssa, + current_function: FunctionId, + mut explored_functions: ImmutableHashSet, + recursive_functions: &mut BTreeSet, +) { + if explored_functions.contains(¤t_function) { + recursive_functions.insert(current_function); + return; + } + + let called_functions = called_functions( + ssa.functions.get(¤t_function).expect("Function should exist in SSA"), + ); + + explored_functions.insert(current_function); + + for called_function in called_functions { + find_recursive_functions( + ssa, + called_function, + explored_functions.clone(), + recursive_functions, + ); + } +} + +fn find_all_recursive_functions(ssa: &Ssa) -> BTreeSet { + let mut recursive_functions = BTreeSet::default(); + find_recursive_functions( + ssa, + ssa.main_id, + ImmutableHashSet::default(), + &mut recursive_functions, + ); + recursive_functions +} + +/// The functions we should inline into (and that should be left in the final program) are: +/// - main +/// - Any Brillig function called from Acir +/// - Any Brillig recursive function (Acir recursive functions will be inlined into the main function) +/// - Any Acir functions with a [fold inline type][InlineType::Fold], +fn get_functions_to_inline_into( ssa: &Ssa, no_predicates_is_entry_point: bool, ) -> BTreeSet { + let brillig_entry_points = find_all_entry_points_to_brillig_context(ssa); let functions = ssa.functions.iter(); - let mut entry_points = functions + + let acir_entry_points: BTreeSet<_> = functions .filter(|(_, function)| { // If we have not already finished the flattening pass, functions marked // to not have predicates should be marked as entry points. let no_predicates_is_entry_point = no_predicates_is_entry_point && function.is_no_predicates(); - function.runtime().is_entry_point() || no_predicates_is_entry_point + function.runtime() != RuntimeType::Brillig && function.runtime().is_entry_point() + || no_predicates_is_entry_point }) .map(|(id, _)| *id) - .collect::>(); - - entry_points.insert(ssa.main_id); - entry_points + .collect(); + + let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa) + .iter() + .filter(|recursive_function_id| { + let function = + ssa.functions.get(recursive_function_id).expect("Function should exist in SSA"); + function.runtime() == RuntimeType::Brillig + }) + .copied() + .collect(); + + std::iter::once(ssa.main_id) + .chain(acir_entry_points) + .chain(brillig_entry_points) + .chain(brillig_recursive_functions) + .collect() } impl InlineContext { @@ -146,6 +271,7 @@ impl InlineContext { ssa: &Ssa, entry_point: FunctionId, no_predicates_is_entry_point: bool, + recursive_functions: BTreeSet, ) -> InlineContext { let source = &ssa.functions[&entry_point]; let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point); @@ -156,6 +282,7 @@ impl InlineContext { entry_point, call_stack: CallStack::new(), no_predicates_is_entry_point, + recursive_functions, } } @@ -391,14 +518,24 @@ impl<'function> PerFunctionContext<'function> { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { let function = &ssa.functions[&func_id]; - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points unless we are inlining into brillig. - let entry_point = &ssa.functions[&self.context.entry_point]; - let no_predicates_is_entry_point = - self.context.no_predicates_is_entry_point - && function.is_no_predicates() - && !matches!(entry_point.runtime(), RuntimeType::Brillig); - if function.runtime().is_entry_point() || no_predicates_is_entry_point { + + let should_retain_call = + if let RuntimeType::Acir(inline_type) = function.runtime() { + // If the called function is acir, we inline if it's not an entry point + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be marked as entry points. + let no_predicates_is_entry_point = + self.context.no_predicates_is_entry_point + && function.is_no_predicates(); + inline_type.is_entry_point() || no_predicates_is_entry_point + } else { + // If the called function is brillig, we inline only if it's into brillig and the function is not recursive + ssa.functions[&self.context.entry_point].runtime() + != RuntimeType::Brillig + || self.context.recursive_functions.contains(&func_id) + }; + if should_retain_call { self.push_instruction(*id); } else { self.inline_function(ssa, *id, func_id, arguments); diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index f6c3f022bfc..4e5fa262696 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -18,5 +18,6 @@ mod remove_bit_shifts; mod remove_enable_side_effects; mod remove_if_else; mod resolve_is_unconstrained; +mod runtime_separation; mod simplify_cfg; mod unrolling; diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs new file mode 100644 index 00000000000..ee26707e504 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -0,0 +1,185 @@ +use std::collections::BTreeSet; + +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; + +use crate::ssa::{ + ir::{ + function::{Function, FunctionId, RuntimeType}, + instruction::Instruction, + value::{Value, ValueId}, + }, + ssa_gen::Ssa, +}; + +impl Ssa { + /// This SSA step separates the runtime of the functions in the SSA. + /// After this step, all functions with runtime `Acir` will be converted to Acir and + /// the functions with runtime `Brillig` will be converted to Brillig. + /// It does so by cloning all ACIR functions called from a Brillig context + /// and changing the runtime of the cloned functions to Brillig. + /// This pass needs to run after functions as values have been resolved (defunctionalization). + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn separate_runtime(mut self) -> Self { + RuntimeSeparatorContext::separate_runtime(&mut self); + + self + } +} + +#[derive(Debug, Default)] +struct RuntimeSeparatorContext { + acir_functions_called_from_brillig: BTreeSet, + mapped_functions: HashMap, +} + +impl RuntimeSeparatorContext { + pub(crate) fn separate_runtime(ssa: &mut Ssa) { + let mut runtime_separator = RuntimeSeparatorContext::default(); + + // We first collect all the acir functions called from a brillig context by exploring the SSA recursively + let mut processed_functions = HashSet::default(); + runtime_separator.collect_acir_functions_called_from_brillig( + ssa, + ssa.main_id, + false, + &mut processed_functions, + ); + // Now we clone the relevant acir functions and change their runtime to brillig + runtime_separator.convert_acir_functions_called_from_brillig_to_brillig(ssa); + + // Now we update any calls within a brillig context to the mapped functions exploring the SSA recursively + let mut processed_functions = HashSet::default(); + runtime_separator.replace_calls_to_mapped_functions( + ssa, + ssa.main_id, + false, + &mut processed_functions, + ); + } + + fn collect_acir_functions_called_from_brillig( + &mut self, + ssa: &Ssa, + current_func_id: FunctionId, + mut within_brillig: bool, + processed_functions: &mut HashSet<(/* within_brillig */ bool, FunctionId)>, + ) { + // Processed functions needs the within brillig flag, since it is possible to call the same function from both brillig and acir + if processed_functions.contains(&(within_brillig, current_func_id)) { + return; + } + processed_functions.insert((within_brillig, current_func_id)); + + let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + if func.runtime() == RuntimeType::Brillig { + within_brillig = true; + } + + let called_functions = called_functions(func); + + if within_brillig { + for called_func_id in called_functions.iter() { + let called_func = + ssa.functions.get(called_func_id).expect("Function should exist in SSA"); + if matches!(called_func.runtime(), RuntimeType::Acir(_)) { + self.acir_functions_called_from_brillig.insert(*called_func_id); + } + } + } + + for called_func_id in called_functions.into_iter() { + self.collect_acir_functions_called_from_brillig( + ssa, + called_func_id, + within_brillig, + processed_functions, + ); + } + } + + fn convert_acir_functions_called_from_brillig_to_brillig(&mut self, ssa: &mut Ssa) { + for acir_func_id in self.acir_functions_called_from_brillig.iter() { + let cloned_id = ssa.clone_fn(*acir_func_id); + let new_func = + ssa.functions.get_mut(&cloned_id).expect("Cloned function should exist in SSA"); + new_func.set_runtime(RuntimeType::Brillig); + self.mapped_functions.insert(*acir_func_id, cloned_id); + } + } + + fn replace_calls_to_mapped_functions( + &mut self, + ssa: &mut Ssa, + current_func_id: FunctionId, + mut within_brillig: bool, + processed_functions: &mut HashSet, + ) { + // Processed functions no longer needs the within brillig flag since we've already cloned the acir functions called from brillig + if processed_functions.contains(¤t_func_id) { + return; + } + processed_functions.insert(current_func_id); + + let func = ssa.functions.get_mut(¤t_func_id).expect("Function should exist in SSA"); + if func.runtime() == RuntimeType::Brillig { + within_brillig = true; + } + + let called_functions_values = called_functions_values(func); + + // If we are within brillig, swap the called functions with the mapped functions + if within_brillig { + for called_func_value_id in called_functions_values.iter() { + let Value::Function(called_func_id) = &func.dfg[*called_func_value_id] else { + unreachable!("Value should be a function") + }; + if let Some(mapped_func_id) = self.mapped_functions.get(called_func_id) { + let new_target_value = Value::Function(*mapped_func_id); + let mapped_value_id = func.dfg.make_value(new_target_value); + func.dfg.set_value_from_id(*called_func_value_id, mapped_value_id); + } + } + } + + // Get the called functions again after the replacements + let called_functions = called_functions(func); + for called_func_id in called_functions.into_iter() { + self.replace_calls_to_mapped_functions( + ssa, + called_func_id, + within_brillig, + processed_functions, + ); + } + } +} + +// We only consider direct calls to functions since functions as values should have been resolved +fn called_functions_values(func: &Function) -> BTreeSet { + let mut called_function_ids = BTreeSet::default(); + for block_id in func.reachable_blocks() { + for instruction_id in func.dfg[block_id].instructions() { + let Instruction::Call { func: called_value_id, .. } = &func.dfg[*instruction_id] else { + continue; + }; + + if let Value::Function(_) = func.dfg[*called_value_id] { + called_function_ids.insert(*called_value_id); + } + } + } + + called_function_ids +} + +fn called_functions(func: &Function) -> BTreeSet { + called_functions_values(func) + .into_iter() + .map(|value_id| { + let Value::Function(func_id) = func.dfg[value_id] else { + unreachable!("Value should be a function") + }; + func_id + }) + .collect() +} diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index 21178c55c73..7a77aa76101 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -80,6 +80,14 @@ impl Ssa { self.functions.insert(new_id, function); new_id } + + /// Clones an already existing function with a fresh id + pub(crate) fn clone_fn(&mut self, existing_function_id: FunctionId) -> FunctionId { + let new_id = self.next_id.next(); + let function = Function::clone_with_id(new_id, &self.functions[&existing_function_id]); + self.functions.insert(new_id, function); + new_id + } } impl Display for Ssa { diff --git a/tooling/debugger/ignored-tests.txt b/tooling/debugger/ignored-tests.txt index a6d3c9a3a94..a9193896589 100644 --- a/tooling/debugger/ignored-tests.txt +++ b/tooling/debugger/ignored-tests.txt @@ -1,28 +1,17 @@ -array_dynamic_blackbox_input bigint -bit_shifts_comptime brillig_references brillig_to_bytes_integration debug_logs -double_verify_nested_proof -double_verify_proof -double_verify_proof_recursive -modulus -references -scalar_mul -signed_comparison -to_bytes_integration +fold_after_inlined_calls fold_basic fold_basic_nested_call fold_call_witness_condition -fold_after_inlined_calls -fold_numeric_generic_poseidon -no_predicates_basic -no_predicates_numeric_generic_poseidon -regression_4709 +fold_complex_outputs fold_distinct_return fold_fibonacci -fold_complex_outputs -slice_init_with_complex_type -hashmap -is_unconstrained \ No newline at end of file +fold_numeric_generic_poseidon +is_unconstrained +modulus +references +regression_4709 +to_bytes_integration \ No newline at end of file From 6669452c85b1c96589bf3291d4842327bd3791b3 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 10:34:30 +0000 Subject: [PATCH 02/12] refactor: improve runtime separation --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 67 +++++-------- .../src/ssa/opt/runtime_separation.rs | 96 +++++++++---------- 2 files changed, 70 insertions(+), 93 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 9b75712988e..1edcbbf37c8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -140,47 +140,6 @@ fn called_functions(func: &Function) -> BTreeSet { called_function_ids } -/// Recursively explore the SSA to find the functions where we enter a brillig context. -fn find_entry_points_to_brillig_context( - ssa: &Ssa, - current_function: FunctionId, - explored_functions: &mut HashSet, - entry_points: &mut BTreeSet, -) { - if !explored_functions.insert(current_function) { - return; - } - - let function = &ssa.functions[¤t_function]; - if function.runtime() == RuntimeType::Brillig { - entry_points.insert(current_function); - return; - } - - let called_functions = called_functions(function); - - for called_function in called_functions { - find_entry_points_to_brillig_context( - ssa, - called_function, - explored_functions, - entry_points, - ); - } -} - -fn find_all_entry_points_to_brillig_context(ssa: &Ssa) -> BTreeSet { - let mut explored_functions = HashSet::default(); - let mut entry_points = BTreeSet::default(); - find_entry_points_to_brillig_context( - ssa, - ssa.main_id, - &mut explored_functions, - &mut entry_points, - ); - entry_points -} - // Recursively explore the SSA to find the functions that end up calling themselves fn find_recursive_functions( ssa: &Ssa, @@ -229,10 +188,30 @@ fn get_functions_to_inline_into( ssa: &Ssa, no_predicates_is_entry_point: bool, ) -> BTreeSet { - let brillig_entry_points = find_all_entry_points_to_brillig_context(ssa); - let functions = ssa.functions.iter(); + let brillig_entry_points: BTreeSet<_> = ssa + .functions + .iter() + .flat_map(|(_, function)| { + if function.runtime() != RuntimeType::Brillig { + called_functions(function) + .into_iter() + .filter(|called_function_id| { + ssa.functions + .get(called_function_id) + .expect("Function should exist in SSA") + .runtime() + == RuntimeType::Brillig + }) + .collect() + } else { + vec![] + } + }) + .collect(); - let acir_entry_points: BTreeSet<_> = functions + let acir_entry_points: BTreeSet<_> = ssa + .functions + .iter() .filter(|(_, function)| { // If we have not already finished the flattening pass, functions marked // to not have predicates should be marked as entry points. diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index ee26707e504..c289824478a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -28,8 +28,12 @@ impl Ssa { #[derive(Debug, Default)] struct RuntimeSeparatorContext { + // Original functions to clone to brillig acir_functions_called_from_brillig: BTreeSet, + // Tracks the original => cloned version mapped_functions: HashMap, + // Some original functions might not be called from ACIR at all, we store the ones that are to delete the others. + mapped_functions_called_from_acir: HashSet, } impl RuntimeSeparatorContext { @@ -44,17 +48,15 @@ impl RuntimeSeparatorContext { false, &mut processed_functions, ); + // Now we clone the relevant acir functions and change their runtime to brillig runtime_separator.convert_acir_functions_called_from_brillig_to_brillig(ssa); - // Now we update any calls within a brillig context to the mapped functions exploring the SSA recursively - let mut processed_functions = HashSet::default(); - runtime_separator.replace_calls_to_mapped_functions( - ssa, - ssa.main_id, - false, - &mut processed_functions, - ); + // Now we update any calls within a brillig context to the mapped functions + runtime_separator.replace_calls_to_mapped_functions(ssa); + + // Some functions might be unreachable now (for example an acir function only called from brillig) + prune_unreachable_functions(ssa); } fn collect_acir_functions_called_from_brillig( @@ -107,50 +109,21 @@ impl RuntimeSeparatorContext { } } - fn replace_calls_to_mapped_functions( - &mut self, - ssa: &mut Ssa, - current_func_id: FunctionId, - mut within_brillig: bool, - processed_functions: &mut HashSet, - ) { - // Processed functions no longer needs the within brillig flag since we've already cloned the acir functions called from brillig - if processed_functions.contains(¤t_func_id) { - return; - } - processed_functions.insert(current_func_id); - - let func = ssa.functions.get_mut(¤t_func_id).expect("Function should exist in SSA"); - if func.runtime() == RuntimeType::Brillig { - within_brillig = true; - } - - let called_functions_values = called_functions_values(func); - - // If we are within brillig, swap the called functions with the mapped functions - if within_brillig { - for called_func_value_id in called_functions_values.iter() { - let Value::Function(called_func_id) = &func.dfg[*called_func_value_id] else { - unreachable!("Value should be a function") - }; - if let Some(mapped_func_id) = self.mapped_functions.get(called_func_id) { - let new_target_value = Value::Function(*mapped_func_id); - let mapped_value_id = func.dfg.make_value(new_target_value); - func.dfg.set_value_from_id(*called_func_value_id, mapped_value_id); + fn replace_calls_to_mapped_functions(&mut self, ssa: &mut Ssa) { + for (_function_id, func) in ssa.functions.iter_mut() { + if func.runtime() == RuntimeType::Brillig { + for called_func_value_id in called_functions_values(func).iter() { + let Value::Function(called_func_id) = &func.dfg[*called_func_value_id] else { + unreachable!("Value should be a function") + }; + if let Some(mapped_func_id) = self.mapped_functions.get(called_func_id) { + let new_target_value = Value::Function(*mapped_func_id); + let mapped_value_id = func.dfg.make_value(new_target_value); + func.dfg.set_value_from_id(*called_func_value_id, mapped_value_id); + } } } } - - // Get the called functions again after the replacements - let called_functions = called_functions(func); - for called_func_id in called_functions.into_iter() { - self.replace_calls_to_mapped_functions( - ssa, - called_func_id, - within_brillig, - processed_functions, - ); - } } } @@ -183,3 +156,28 @@ fn called_functions(func: &Function) -> BTreeSet { }) .collect() } + +fn collect_reachable_functions( + ssa: &Ssa, + current_func_id: FunctionId, + reachable_functions: &mut HashSet, +) { + if reachable_functions.contains(¤t_func_id) { + return; + } + reachable_functions.insert(current_func_id); + + let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + let called_functions = called_functions(func); + + for called_func_id in called_functions.iter() { + collect_reachable_functions(ssa, *called_func_id, reachable_functions); + } +} + +fn prune_unreachable_functions(ssa: &mut Ssa) { + let mut reachable_functions = HashSet::default(); + collect_reachable_functions(ssa, ssa.main_id, &mut reachable_functions); + + ssa.functions.retain(|id, _value| reachable_functions.contains(id)); +} From 93d416b69e1526a2b11bb5d28df8c09248b4e4c0 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 10:45:24 +0000 Subject: [PATCH 03/12] chore: remove copied --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 1edcbbf37c8..53ce4e533b4 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -224,13 +224,12 @@ fn get_functions_to_inline_into( .collect(); let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa) - .iter() + .into_iter() .filter(|recursive_function_id| { let function = ssa.functions.get(recursive_function_id).expect("Function should exist in SSA"); function.runtime() == RuntimeType::Brillig }) - .copied() .collect(); std::iter::once(ssa.main_id) From baffbac3a76966ec5ca51cc32f5f991e0d4f982e Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 11:41:29 +0000 Subject: [PATCH 04/12] remove unused property --- compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index c289824478a..5bcc83cf6b3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -12,7 +12,7 @@ use crate::ssa::{ }; impl Ssa { - /// This SSA step separates the runtime of the functions in the SSA. + /// This optimization step separates the runtime of the functions in the SSA. /// After this step, all functions with runtime `Acir` will be converted to Acir and /// the functions with runtime `Brillig` will be converted to Brillig. /// It does so by cloning all ACIR functions called from a Brillig context @@ -32,8 +32,6 @@ struct RuntimeSeparatorContext { acir_functions_called_from_brillig: BTreeSet, // Tracks the original => cloned version mapped_functions: HashMap, - // Some original functions might not be called from ACIR at all, we store the ones that are to delete the others. - mapped_functions_called_from_acir: HashSet, } impl RuntimeSeparatorContext { From ca848fe903c1e41821c334d28beecf3e4bd52de3 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 11:49:23 +0000 Subject: [PATCH 05/12] add test --- .../src/ssa/opt/runtime_separation.rs | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index 5bcc83cf6b3..060c50fc81f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -179,3 +179,52 @@ fn prune_unreachable_functions(ssa: &mut Ssa) { ssa.functions.retain(|id, _value| reachable_functions.contains(id)); } + +#[cfg(test)] +mod test { + use noirc_frontend::monomorphization::ast::InlineType; + + use crate::ssa::{ + function_builder::FunctionBuilder, + ir::{function::RuntimeType, map::Id, types::Type}, + }; + + #[test] + fn basic_runtime_separation() { + // brillig fn foo { + // b0(): + // v0 = call bar() + // return v0 + // } + // acir fn bar { + // b0(): + // return 72 + // } + let foo_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("foo".into(), foo_id); + builder.current_function.set_runtime(RuntimeType::Brillig); + + let bar_id = Id::test_new(1); + let bar = builder.import_function(bar_id); + let results = builder.insert_call(bar, Vec::new(), vec![Type::field()]).to_vec(); + builder.terminate_with_return(results); + + builder.new_function("bar".into(), bar_id, InlineType::default()); + let expected_return = 72u128; + let seventy_two = builder.field_constant(expected_return); + builder.terminate_with_return(vec![seventy_two]); + + let ssa = builder.finish(); + assert_eq!(ssa.functions.len(), 2); + + let separated = ssa.separate_runtime(); + + // The original bar function must have been pruned + assert_eq!(separated.functions.len(), 2); + + // All functions should be brillig now + for func in separated.functions.values() { + assert_eq!(func.runtime(), RuntimeType::Brillig); + } + } +} From 8d074c11ef8a797aea653bdea6187b45dba6f82c Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 11:51:58 +0000 Subject: [PATCH 06/12] remove unneeded mut --- compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index 060c50fc81f..f651ddf41a8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -107,7 +107,7 @@ impl RuntimeSeparatorContext { } } - fn replace_calls_to_mapped_functions(&mut self, ssa: &mut Ssa) { + fn replace_calls_to_mapped_functions(&self, ssa: &mut Ssa) { for (_function_id, func) in ssa.functions.iter_mut() { if func.runtime() == RuntimeType::Brillig { for called_func_value_id in called_functions_values(func).iter() { From 29d521a6c3e4ec8237bfb3b65aae49f7e475588c Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 13:18:48 +0000 Subject: [PATCH 07/12] test: add more complex unit test --- .../src/ssa/opt/runtime_separation.rs | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index f651ddf41a8..dc232398f00 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -182,11 +182,19 @@ fn prune_unreachable_functions(ssa: &mut Ssa) { #[cfg(test)] mod test { + use std::collections::BTreeSet; + use noirc_frontend::monomorphization::ast::InlineType; use crate::ssa::{ function_builder::FunctionBuilder, - ir::{function::RuntimeType, map::Id, types::Type}, + ir::{ + function::{Function, FunctionId, RuntimeType}, + map::Id, + types::Type, + }, + opt::runtime_separation::called_functions, + ssa_gen::Ssa, }; #[test] @@ -227,4 +235,87 @@ mod test { assert_eq!(func.runtime(), RuntimeType::Brillig); } } + + fn find_func_by_name<'ssa>( + ssa: &'ssa Ssa, + funcs: &BTreeSet, + name: &str, + ) -> &'ssa Function { + funcs + .iter() + .find_map(|id| { + let func = ssa.functions.get(id).unwrap(); + if func.name() == name { + Some(func) + } else { + None + } + }) + .unwrap() + } + + #[test] + fn same_function_shared_acir_brillig() { + // acir fn foo { + // b0(): + // v0 = call bar() + // v1 = call baz() + // return v0, v1 + // } + // brillig fn bar { + // b0(): + // v0 = call baz() + // return v0 + // } + // acir fn baz { + // b0(): + // return 72 + // } + let foo_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("foo".into(), foo_id); + + let bar_id = Id::test_new(1); + let baz_id = Id::test_new(2); + let bar = builder.import_function(bar_id); + let baz = builder.import_function(baz_id); + let v0 = builder.insert_call(bar, Vec::new(), vec![Type::field()]).to_vec(); + let v1 = builder.insert_call(baz, Vec::new(), vec![Type::field()]).to_vec(); + builder.terminate_with_return(vec![v0[0], v1[0]]); + + builder.new_brillig_function("bar".into(), bar_id); + let baz = builder.import_function(baz_id); + let v0 = builder.insert_call(baz, Vec::new(), vec![Type::field()]).to_vec(); + builder.terminate_with_return(v0); + + builder.new_function("baz".into(), baz_id, InlineType::default()); + let expected_return = 72u128; + let seventy_two = builder.field_constant(expected_return); + builder.terminate_with_return(vec![seventy_two]); + + let ssa = builder.finish(); + assert_eq!(ssa.functions.len(), 3); + + let separated = ssa.separate_runtime(); + + // The original baz function must have been duplicated + assert_eq!(separated.functions.len(), 4); + + let main_function = separated.functions.get(&separated.main_id).unwrap(); + assert_eq!(main_function.runtime(), RuntimeType::Acir(InlineType::Inline)); + + let main_calls = called_functions(main_function); + assert_eq!(main_calls.len(), 2); + + let bar = find_func_by_name(&separated, &main_calls, "bar"); + let baz_acir = find_func_by_name(&separated, &main_calls, "baz"); + + assert_eq!(baz_acir.runtime(), RuntimeType::Acir(InlineType::Inline)); + assert_eq!(bar.runtime(), RuntimeType::Brillig); + + let bar_calls = called_functions(bar); + assert_eq!(bar_calls.len(), 1); + + let baz_brillig = find_func_by_name(&separated, &bar_calls, "baz"); + assert_eq!(baz_brillig.runtime(), RuntimeType::Brillig); + } } From ef069fc962b87393d6c4d76a357879db20abb836 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Tue, 28 May 2024 13:21:08 +0000 Subject: [PATCH 08/12] add comments for clarity --- .../src/ssa/opt/runtime_separation.rs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index dc232398f00..0483721ab49 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -225,6 +225,16 @@ mod test { let ssa = builder.finish(); assert_eq!(ssa.functions.len(), 2); + // Expected result + // brillig fn foo { + // b0(): + // v0 = call bar() + // return v0 + // } + // brillig fn bar { + // b0(): + // return 72 + // } let separated = ssa.separate_runtime(); // The original bar function must have been pruned @@ -295,6 +305,26 @@ mod test { let ssa = builder.finish(); assert_eq!(ssa.functions.len(), 3); + // Expected result + // acir fn foo { + // b0(): + // v0 = call bar() + // v1 = call baz() <- baz_acir + // return v0, v1 + // } + // brillig fn bar { + // b0(): + // v0 = call baz() <- baz_brillig + // return v0 + // } + // acir fn baz { + // b0(): + // return 72 + // } + // brillig fn baz { + // b0(): + // return 72 + // } let separated = ssa.separate_runtime(); // The original baz function must have been duplicated From a848e7af39bf8c7ff415d7f83e69b25e470ca65d Mon Sep 17 00:00:00 2001 From: sirasistant Date: Wed, 29 May 2024 07:39:24 +0000 Subject: [PATCH 09/12] address PR comments --- acvm-repo/brillig_vm/src/memory.rs | 3 + .../noirc_evaluator/src/ssa/opt/inlining.rs | 64 +++++++------------ 2 files changed, 26 insertions(+), 41 deletions(-) diff --git a/acvm-repo/brillig_vm/src/memory.rs b/acvm-repo/brillig_vm/src/memory.rs index a8d70a36c80..6c6b5968cd7 100644 --- a/acvm-repo/brillig_vm/src/memory.rs +++ b/acvm-repo/brillig_vm/src/memory.rs @@ -294,6 +294,9 @@ impl Memory { } pub fn read_slice(&self, addr: MemoryAddress, len: usize) -> &[MemoryValue] { + // Allows to read a slice of uninitialized memory if the length is zero. + // Ideally we'd be able to read uninitialized memory in general (as read does) + // but that's not possible if we want to return a slice instead of owned data. if len == 0 { return &[]; } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 13b42f1c2a1..caaba45398b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -19,7 +19,6 @@ use crate::ssa::{ ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; -use im::HashSet as ImmutableHashSet; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -145,7 +144,7 @@ fn called_functions(func: &Function) -> BTreeSet { fn find_recursive_functions( ssa: &Ssa, current_function: FunctionId, - mut explored_functions: ImmutableHashSet, + mut explored_functions: im::HashSet, recursive_functions: &mut BTreeSet, ) { if explored_functions.contains(¤t_function) { @@ -171,12 +170,7 @@ fn find_recursive_functions( fn find_all_recursive_functions(ssa: &Ssa) -> BTreeSet { let mut recursive_functions = BTreeSet::default(); - find_recursive_functions( - ssa, - ssa.main_id, - ImmutableHashSet::default(), - &mut recursive_functions, - ); + find_recursive_functions(ssa, ssa.main_id, im::HashSet::default(), &mut recursive_functions); recursive_functions } @@ -189,40 +183,28 @@ fn get_functions_to_inline_into( ssa: &Ssa, no_predicates_is_entry_point: bool, ) -> BTreeSet { - let brillig_entry_points: BTreeSet<_> = ssa - .functions - .iter() - .flat_map(|(_, function)| { - if function.runtime() != RuntimeType::Brillig { - called_functions(function) - .into_iter() - .filter(|called_function_id| { - ssa.functions - .get(called_function_id) - .expect("Function should exist in SSA") - .runtime() - == RuntimeType::Brillig - }) - .collect() - } else { - vec![] - } - }) - .collect(); + let mut brillig_entry_points = BTreeSet::default(); + let mut acir_entry_points = BTreeSet::default(); - let acir_entry_points: BTreeSet<_> = ssa - .functions - .iter() - .filter(|(_, function)| { - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points. - let no_predicates_is_entry_point = - no_predicates_is_entry_point && function.is_no_predicates(); - function.runtime() != RuntimeType::Brillig && function.runtime().is_entry_point() - || no_predicates_is_entry_point - }) - .map(|(id, _)| *id) - .collect(); + for (func_id, function) in ssa.functions.iter() { + if function.runtime() == RuntimeType::Brillig { + continue; + } + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be marked as entry points. + let no_predicates_is_entry_point = + no_predicates_is_entry_point && function.is_no_predicates(); + if function.runtime().is_entry_point() || no_predicates_is_entry_point { + acir_entry_points.insert(*func_id); + } + + for called_function_id in called_functions(function) { + if ssa.functions[&called_function_id].runtime() == RuntimeType::Brillig { + brillig_entry_points.insert(called_function_id); + } + } + } let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa) .into_iter() From a1100469883a86f2f8ad048e04231b5822a292d0 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Wed, 29 May 2024 07:48:26 +0000 Subject: [PATCH 10/12] feat: added test case for recursion in acir inside brillig --- .../acir_inside_brillig_recursion/Nargo.toml | 6 ++++++ .../acir_inside_brillig_recursion/Prover.toml | 1 + .../acir_inside_brillig_recursion/src/main.nr | 15 +++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 test_programs/execution_success/acir_inside_brillig_recursion/Nargo.toml create mode 100644 test_programs/execution_success/acir_inside_brillig_recursion/Prover.toml create mode 100644 test_programs/execution_success/acir_inside_brillig_recursion/src/main.nr diff --git a/test_programs/execution_success/acir_inside_brillig_recursion/Nargo.toml b/test_programs/execution_success/acir_inside_brillig_recursion/Nargo.toml new file mode 100644 index 00000000000..462532bb484 --- /dev/null +++ b/test_programs/execution_success/acir_inside_brillig_recursion/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "acir_inside_brillig_recursion" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/execution_success/acir_inside_brillig_recursion/Prover.toml b/test_programs/execution_success/acir_inside_brillig_recursion/Prover.toml new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/test_programs/execution_success/acir_inside_brillig_recursion/Prover.toml @@ -0,0 +1 @@ + diff --git a/test_programs/execution_success/acir_inside_brillig_recursion/src/main.nr b/test_programs/execution_success/acir_inside_brillig_recursion/src/main.nr new file mode 100644 index 00000000000..92f8524a771 --- /dev/null +++ b/test_programs/execution_success/acir_inside_brillig_recursion/src/main.nr @@ -0,0 +1,15 @@ +fn main() { + assert_eq(fibonacci(3), fibonacci_hint(3)); +} + +unconstrained fn fibonacci_hint(x: u32) -> u32 { + fibonacci(x) +} + +fn fibonacci(x: u32) -> u32 { + if x <= 1 { + x + } else { + fibonacci(x - 1) + fibonacci(x - 2) + } +} From 31dcd1272e7fad8a5ebb54c9ccee5c9561e958ce Mon Sep 17 00:00:00 2001 From: sirasistant Date: Wed, 29 May 2024 09:32:00 +0000 Subject: [PATCH 11/12] refactor: use import function --- compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index 0483721ab49..bc8bea88ba2 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -115,8 +115,7 @@ impl RuntimeSeparatorContext { unreachable!("Value should be a function") }; if let Some(mapped_func_id) = self.mapped_functions.get(called_func_id) { - let new_target_value = Value::Function(*mapped_func_id); - let mapped_value_id = func.dfg.make_value(new_target_value); + let mapped_value_id = func.dfg.import_function(*mapped_func_id); func.dfg.set_value_from_id(*called_func_value_id, mapped_value_id); } } From 68a3d3981ef4937ba16103f148c325c39b7faf68 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Wed, 29 May 2024 16:06:36 +0000 Subject: [PATCH 12/12] address PR comments --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 49 +++++++++---------- .../src/ssa/opt/runtime_separation.rs | 8 ++- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index caaba45398b..e2a7f51d0a0 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -152,9 +152,7 @@ fn find_recursive_functions( return; } - let called_functions = called_functions( - ssa.functions.get(¤t_function).expect("Function should exist in SSA"), - ); + let called_functions = called_functions(&ssa.functions[¤t_function]); explored_functions.insert(current_function); @@ -209,8 +207,7 @@ fn get_functions_to_inline_into( let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa) .into_iter() .filter(|recursive_function_id| { - let function = - ssa.functions.get(recursive_function_id).expect("Function should exist in SSA"); + let function = &ssa.functions[&recursive_function_id]; function.runtime() == RuntimeType::Brillig }) .collect(); @@ -478,28 +475,10 @@ impl<'function> PerFunctionContext<'function> { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { - let function = &ssa.functions[&func_id]; - - let should_retain_call = - if let RuntimeType::Acir(inline_type) = function.runtime() { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points. - let no_predicates_is_entry_point = - self.context.no_predicates_is_entry_point - && function.is_no_predicates(); - inline_type.is_entry_point() || no_predicates_is_entry_point - } else { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - ssa.functions[&self.context.entry_point].runtime() - != RuntimeType::Brillig - || self.context.recursive_functions.contains(&func_id) - }; - if should_retain_call { - self.push_instruction(*id); - } else { + if self.should_inline_call(ssa, func_id) { self.inline_function(ssa, *id, func_id, arguments); + } else { + self.push_instruction(*id); } } None => self.push_instruction(*id), @@ -509,6 +488,24 @@ impl<'function> PerFunctionContext<'function> { } } + fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool { + let function = &ssa.functions[&called_func_id]; + + if let RuntimeType::Acir(inline_type) = function.runtime() { + // If the called function is acir, we inline if it's not an entry point + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be marked as entry points. + let no_predicates_is_entry_point = + self.context.no_predicates_is_entry_point && function.is_no_predicates(); + !inline_type.is_entry_point() && !no_predicates_is_entry_point + } else { + // If the called function is brillig, we inline only if it's into brillig and the function is not recursive + ssa.functions[&self.context.entry_point].runtime() == RuntimeType::Brillig + && !self.context.recursive_functions.contains(&called_func_id) + } + } + /// Inline a function call and remember the inlined return values in the values map fn inline_function( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index bc8bea88ba2..c0c9c0a1372 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -21,7 +21,6 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn separate_runtime(mut self) -> Self { RuntimeSeparatorContext::separate_runtime(&mut self); - self } } @@ -70,7 +69,7 @@ impl RuntimeSeparatorContext { } processed_functions.insert((within_brillig, current_func_id)); - let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + let func = &ssa.functions[¤t_func_id]; if func.runtime() == RuntimeType::Brillig { within_brillig = true; } @@ -79,8 +78,7 @@ impl RuntimeSeparatorContext { if within_brillig { for called_func_id in called_functions.iter() { - let called_func = - ssa.functions.get(called_func_id).expect("Function should exist in SSA"); + let called_func = &ssa.functions[&called_func_id]; if matches!(called_func.runtime(), RuntimeType::Acir(_)) { self.acir_functions_called_from_brillig.insert(*called_func_id); } @@ -164,7 +162,7 @@ fn collect_reachable_functions( } reachable_functions.insert(current_func_id); - let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + let func = &ssa.functions[¤t_func_id]; let called_functions = called_functions(func); for called_func_id in called_functions.iter() {