Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant committed May 29, 2024
1 parent d28c834 commit 68a3d39
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 31 deletions.
49 changes: 23 additions & 26 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ fn find_recursive_functions(
return;
}

let called_functions = called_functions(
ssa.functions.get(&current_function).expect("Function should exist in SSA"),
);
let called_functions = called_functions(&ssa.functions[&current_function]);

explored_functions.insert(current_function);

Expand Down Expand Up @@ -209,8 +207,7 @@ fn get_functions_to_inline_into(
let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa)
.into_iter()
.filter(|recursive_function_id| {
let function =
ssa.functions.get(recursive_function_id).expect("Function should exist in SSA");
let function = &ssa.functions[&recursive_function_id];
function.runtime() == RuntimeType::Brillig
})
.collect();
Expand Down Expand Up @@ -478,28 +475,10 @@ impl<'function> PerFunctionContext<'function> {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
let function = &ssa.functions[&func_id];

let should_retain_call =
if let RuntimeType::Acir(inline_type) = function.runtime() {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points.
let no_predicates_is_entry_point =
self.context.no_predicates_is_entry_point
&& function.is_no_predicates();
inline_type.is_entry_point() || no_predicates_is_entry_point
} else {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[&self.context.entry_point].runtime()
!= RuntimeType::Brillig
|| self.context.recursive_functions.contains(&func_id)
};
if should_retain_call {
self.push_instruction(*id);
} else {
if self.should_inline_call(ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments);
} else {
self.push_instruction(*id);
}
}
None => self.push_instruction(*id),
Expand All @@ -509,6 +488,24 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool {
let function = &ssa.functions[&called_func_id];

if let RuntimeType::Acir(inline_type) = function.runtime() {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points.
let no_predicates_is_entry_point =
self.context.no_predicates_is_entry_point && function.is_no_predicates();
!inline_type.is_entry_point() && !no_predicates_is_entry_point
} else {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[&self.context.entry_point].runtime() == RuntimeType::Brillig
&& !self.context.recursive_functions.contains(&called_func_id)
}
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
Expand Down
8 changes: 3 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/runtime_separation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ impl Ssa {
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn separate_runtime(mut self) -> Self {
RuntimeSeparatorContext::separate_runtime(&mut self);

self
}
}
Expand Down Expand Up @@ -70,7 +69,7 @@ impl RuntimeSeparatorContext {
}
processed_functions.insert((within_brillig, current_func_id));

let func = ssa.functions.get(&current_func_id).expect("Function should exist in SSA");
let func = &ssa.functions[&current_func_id];
if func.runtime() == RuntimeType::Brillig {
within_brillig = true;
}
Expand All @@ -79,8 +78,7 @@ impl RuntimeSeparatorContext {

if within_brillig {
for called_func_id in called_functions.iter() {
let called_func =
ssa.functions.get(called_func_id).expect("Function should exist in SSA");
let called_func = &ssa.functions[&called_func_id];
if matches!(called_func.runtime(), RuntimeType::Acir(_)) {
self.acir_functions_called_from_brillig.insert(*called_func_id);
}
Expand Down Expand Up @@ -164,7 +162,7 @@ fn collect_reachable_functions(
}
reachable_functions.insert(current_func_id);

let func = ssa.functions.get(&current_func_id).expect("Function should exist in SSA");
let func = &ssa.functions[&current_func_id];
let called_functions = called_functions(func);

for called_func_id in called_functions.iter() {
Expand Down

0 comments on commit 68a3d39

Please sign in to comment.