Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Separate runtimes of SSA functions before inlining #5121

Merged
merged 15 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions acvm-repo/brillig_vm/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ impl<F: AcirField> Memory<F> {
}

pub fn read_slice(&self, addr: MemoryAddress, len: usize) -> &[MemoryValue<F>] {
// Allows to read a slice of uninitialized memory if the length is zero.
// Ideally we'd be able to read uninitialized memory in general (as read does)
// but that's not possible if we want to return a slice instead of owned data.
if len == 0 {
return &[];
}
jfecher marked this conversation as resolved.
Show resolved Hide resolved
&self.inner[addr.to_usize()..(addr.to_usize() + len)]
}

Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ pub(crate) fn optimize_into_acir(
let ssa = SsaBuilder::new(program, print_passes, force_brillig_output, print_timings)?
.run_pass(Ssa::defunctionalize, "After Defunctionalization:")
.run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:")
.run_pass(Ssa::inline_functions, "After Inlining:")
.run_pass(Ssa::separate_runtime, "After Runtime Separation:")
.run_pass(Ssa::resolve_is_unconstrained, "After Resolving IsUnconstrained:")
.run_pass(Ssa::inline_functions, "After Inlining:")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization")
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use noirc_errors::Location;
/// its blocks, instructions, and values. This struct is largely responsible for
/// owning most data in a function and handing out Ids to this data that can be
/// shared without worrying about ownership.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub(crate) struct DataFlowGraph {
/// All of the instructions in a function
instructions: DenseMap<Instruction>,
Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl Function {
Self { name, id, entry_block, dfg, runtime: RuntimeType::Acir(InlineType::default()) }
}

/// Creates a new function as a clone of the one passed in with the passed in id.
pub(crate) fn clone_with_id(id: FunctionId, another: &Function) -> Self {
let dfg = another.dfg.clone();
let entry_block = another.entry_block;
Self { name: another.name.clone(), id, entry_block, dfg, runtime: another.runtime }
}

/// The name of the function.
/// Used exclusively for debugging purposes.
pub(crate) fn name(&self) -> &str {
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl std::fmt::Display for Id<super::instruction::Instruction> {
/// access to indices is provided. Since IDs must be stable and correspond
/// to indices in the internal Vec, operations that would change element
/// ordering like pop, remove, swap_remove, etc, are not possible.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub(crate) struct DenseMap<T> {
storage: Vec<T>,
}
Expand Down
160 changes: 127 additions & 33 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const RECURSION_LIMIT: u32 = 1000;
impl Ssa {
/// Inline all functions within the IR.
///
/// In the case of recursive functions, this will attempt
/// In the case of recursive Acir functions, this will attempt
/// to recursively inline until the RECURSION_LIMIT is reached.
///
/// Functions are recursively inlined into main until either we finish
Expand All @@ -41,6 +41,8 @@ impl Ssa {
/// There are some attributes that allow inlining a function at a different step of codegen.
/// Currently this is just `InlineType::NoPredicates` for which we have a flag indicating
/// whether treating that inline functions. The default is to treat these functions as entry points.
///
/// This step should run after runtime separation, since it relies on the runtime of the called functions being final.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn inline_functions(self) -> Ssa {
Self::inline_functions_inner(self, true)
Expand All @@ -52,12 +54,17 @@ impl Ssa {
}

fn inline_functions_inner(mut self, no_predicates_is_entry_point: bool) -> Ssa {
let recursive_functions = find_all_recursive_functions(&self);
self.functions = btree_map(
get_entry_point_functions(&self, no_predicates_is_entry_point),
get_functions_to_inline_into(&self, no_predicates_is_entry_point),
|entry_point| {
let new_function =
InlineContext::new(&self, entry_point, no_predicates_is_entry_point)
.inline_all(&self);
let new_function = InlineContext::new(
&self,
entry_point,
no_predicates_is_entry_point,
recursive_functions.clone(),
)
.inline_all(&self);
(entry_point, new_function)
},
);
Expand All @@ -80,6 +87,8 @@ struct InlineContext {
entry_point: FunctionId,

no_predicates_is_entry_point: bool,
// We keep track of the recursive functions in the SSA to avoid inlining them in a brillig context.
recursive_functions: BTreeSet<FunctionId>,
}

/// The per-function inlining context contains information that is only valid for one function.
Expand Down Expand Up @@ -113,28 +122,101 @@ struct PerFunctionContext<'function> {
inlining_entry: bool,
}

/// The entry point functions are each function we should inline into - and each function that
/// should be left in the final program.
/// This is the `main` function, any Acir functions with a [fold inline type][InlineType::Fold],
/// and any brillig functions used.
fn get_entry_point_functions(
/// Utility function to find out the direct calls of a function.
fn called_functions(func: &Function) -> BTreeSet<FunctionId> {
let mut called_function_ids = BTreeSet::default();
for block_id in func.reachable_blocks() {
for instruction_id in func.dfg[block_id].instructions() {
let Instruction::Call { func: called_value_id, .. } = &func.dfg[*instruction_id] else {
continue;
};

if let Value::Function(function_id) = func.dfg[*called_value_id] {
called_function_ids.insert(function_id);
}
}
}

called_function_ids
}

// Recursively explore the SSA to find the functions that end up calling themselves
fn find_recursive_functions(
ssa: &Ssa,
current_function: FunctionId,
mut explored_functions: im::HashSet<FunctionId>,
recursive_functions: &mut BTreeSet<FunctionId>,
) {
if explored_functions.contains(&current_function) {
recursive_functions.insert(current_function);
return;
}

let called_functions = called_functions(&ssa.functions[&current_function]);

explored_functions.insert(current_function);

for called_function in called_functions {
find_recursive_functions(
ssa,
called_function,
explored_functions.clone(),
recursive_functions,
);
}
}

fn find_all_recursive_functions(ssa: &Ssa) -> BTreeSet<FunctionId> {
let mut recursive_functions = BTreeSet::default();
find_recursive_functions(ssa, ssa.main_id, im::HashSet::default(), &mut recursive_functions);
recursive_functions
}

/// The functions we should inline into (and that should be left in the final program) are:
/// - main
/// - Any Brillig function called from Acir
/// - Any Brillig recursive function (Acir recursive functions will be inlined into the main function)
/// - Any Acir functions with a [fold inline type][InlineType::Fold],
fn get_functions_to_inline_into(
ssa: &Ssa,
no_predicates_is_entry_point: bool,
) -> BTreeSet<FunctionId> {
let functions = ssa.functions.iter();
let mut entry_points = functions
.filter(|(_, function)| {
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points.
let no_predicates_is_entry_point =
no_predicates_is_entry_point && function.is_no_predicates();
function.runtime().is_entry_point() || no_predicates_is_entry_point
let mut brillig_entry_points = BTreeSet::default();
let mut acir_entry_points = BTreeSet::default();

for (func_id, function) in ssa.functions.iter() {
if function.runtime() == RuntimeType::Brillig {
continue;
}

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points.
let no_predicates_is_entry_point =
no_predicates_is_entry_point && function.is_no_predicates();
if function.runtime().is_entry_point() || no_predicates_is_entry_point {
acir_entry_points.insert(*func_id);
}

for called_function_id in called_functions(function) {
if ssa.functions[&called_function_id].runtime() == RuntimeType::Brillig {
brillig_entry_points.insert(called_function_id);
}
}
}

let brillig_recursive_functions: BTreeSet<_> = find_all_recursive_functions(ssa)
.into_iter()
.filter(|recursive_function_id| {
let function = &ssa.functions[&recursive_function_id];
function.runtime() == RuntimeType::Brillig
})
.map(|(id, _)| *id)
.collect::<BTreeSet<_>>();
.collect();

entry_points.insert(ssa.main_id);
entry_points
std::iter::once(ssa.main_id)
.chain(acir_entry_points)
.chain(brillig_entry_points)
.chain(brillig_recursive_functions)
.collect()
}

impl InlineContext {
Expand All @@ -147,6 +229,7 @@ impl InlineContext {
ssa: &Ssa,
entry_point: FunctionId,
no_predicates_is_entry_point: bool,
recursive_functions: BTreeSet<FunctionId>,
) -> InlineContext {
let source = &ssa.functions[&entry_point];
let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point);
Expand All @@ -157,6 +240,7 @@ impl InlineContext {
entry_point,
call_stack: CallStack::new(),
no_predicates_is_entry_point,
recursive_functions,
}
}

Expand Down Expand Up @@ -391,18 +475,10 @@ impl<'function> PerFunctionContext<'function> {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
let function = &ssa.functions[&func_id];
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points unless we are inlining into brillig.
let entry_point = &ssa.functions[&self.context.entry_point];
let no_predicates_is_entry_point =
self.context.no_predicates_is_entry_point
&& function.is_no_predicates()
&& !matches!(entry_point.runtime(), RuntimeType::Brillig);
if function.runtime().is_entry_point() || no_predicates_is_entry_point {
self.push_instruction(*id);
} else {
if self.should_inline_call(ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments);
} else {
self.push_instruction(*id);
}
}
None => self.push_instruction(*id),
Expand All @@ -412,6 +488,24 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool {
let function = &ssa.functions[&called_func_id];

if let RuntimeType::Acir(inline_type) = function.runtime() {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be marked as entry points.
let no_predicates_is_entry_point =
self.context.no_predicates_is_entry_point && function.is_no_predicates();
!inline_type.is_entry_point() && !no_predicates_is_entry_point
} else {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[&self.context.entry_point].runtime() == RuntimeType::Brillig
&& !self.context.recursive_functions.contains(&called_func_id)
}
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ mod remove_bit_shifts;
mod remove_enable_side_effects;
mod remove_if_else;
mod resolve_is_unconstrained;
mod runtime_separation;
mod simplify_cfg;
mod unrolling;
Loading
Loading