diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index caaba45398b..e2a7f51d0a0 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -152,9 +152,7 @@ fn find_recursive_functions( return; } - let called_functions = called_functions( - ssa.functions.get(¤t_function).expect("Function should exist in SSA"), - ); + let called_functions = called_functions(&ssa.functions[¤t_function]); explored_functions.insert(current_function); @@ -209,8 +207,7 @@ fn get_functions_to_inline_into( let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa) .into_iter() .filter(|recursive_function_id| { - let function = - ssa.functions.get(recursive_function_id).expect("Function should exist in SSA"); + let function = &ssa.functions[&recursive_function_id]; function.runtime() == RuntimeType::Brillig }) .collect(); @@ -478,28 +475,10 @@ impl<'function> PerFunctionContext<'function> { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { - let function = &ssa.functions[&func_id]; - - let should_retain_call = - if let RuntimeType::Acir(inline_type) = function.runtime() { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points. - let no_predicates_is_entry_point = - self.context.no_predicates_is_entry_point - && function.is_no_predicates(); - inline_type.is_entry_point() || no_predicates_is_entry_point - } else { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - ssa.functions[&self.context.entry_point].runtime() - != RuntimeType::Brillig - || self.context.recursive_functions.contains(&func_id) - }; - if should_retain_call { - self.push_instruction(*id); - } else { + if self.should_inline_call(ssa, func_id) { self.inline_function(ssa, *id, func_id, arguments); + } else { + self.push_instruction(*id); } } None => self.push_instruction(*id), @@ -509,6 +488,24 @@ impl<'function> PerFunctionContext<'function> { } } + fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool { + let function = &ssa.functions[&called_func_id]; + + if let RuntimeType::Acir(inline_type) = function.runtime() { + // If the called function is acir, we inline if it's not an entry point + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be marked as entry points. + let no_predicates_is_entry_point = + self.context.no_predicates_is_entry_point && function.is_no_predicates(); + !inline_type.is_entry_point() && !no_predicates_is_entry_point + } else { + // If the called function is brillig, we inline only if it's into brillig and the function is not recursive + ssa.functions[&self.context.entry_point].runtime() == RuntimeType::Brillig + && !self.context.recursive_functions.contains(&called_func_id) + } + } + /// Inline a function call and remember the inlined return values in the values map fn inline_function( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs index bc8bea88ba2..c0c9c0a1372 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs @@ -21,7 +21,6 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn separate_runtime(mut self) -> Self { RuntimeSeparatorContext::separate_runtime(&mut self); - self } } @@ -70,7 +69,7 @@ impl RuntimeSeparatorContext { } processed_functions.insert((within_brillig, current_func_id)); - let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + let func = &ssa.functions[¤t_func_id]; if func.runtime() == RuntimeType::Brillig { within_brillig = true; } @@ -79,8 +78,7 @@ impl RuntimeSeparatorContext { if within_brillig { for called_func_id in called_functions.iter() { - let called_func = - ssa.functions.get(called_func_id).expect("Function should exist in SSA"); + let called_func = &ssa.functions[&called_func_id]; if matches!(called_func.runtime(), RuntimeType::Acir(_)) { self.acir_functions_called_from_brillig.insert(*called_func_id); } @@ -164,7 +162,7 @@ fn collect_reachable_functions( } reachable_functions.insert(current_func_id); - let func = ssa.functions.get(¤t_func_id).expect("Function should exist in SSA"); + let func = &ssa.functions[¤t_func_id]; let called_functions = called_functions(func); for called_func_id in called_functions.iter() {