Skip to content

Commit

Permalink
feat: defunctionalization pass for ssa refactor (#1870)
Browse files Browse the repository at this point in the history
* feat: defunctionalization pass

* Apply suggestions from self code review

* feat: optimize apply function generation & usage

* feat: avoid unary apply fns

* fix: clippy

* style: cleanup after peer review

* docs: updated comments on defunctionalize

* style: apply suggestions from peer review

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* style: rename field

* docs: fixed doc to avoid doctest

* style: addressed pr comments

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* refactor: extract set type of value to the dfg

* Update crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
sirasistant and jfecher authored Jul 7, 2023
1 parent b21d1e2 commit 1d5d84d
Show file tree
Hide file tree
Showing 16 changed files with 457 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ fn test_multiple6(a: my2, b: my_struct, c: (my2, my_struct)) {
}


fn foo(a: [Field]) -> [Field] {

fn foo<N>(a: [Field; N]) -> [Field; N] {
a
}
fn bar() -> [Field] {

fn bar() -> [Field; 1] {
foo([0])
}

fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) {
let mut ss: my_struct = my_struct { b: x, a: x+2, };
test_multiple4(ss);
Expand Down Expand Up @@ -134,7 +137,6 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) {
assert(result[0] == arr1[0] as Field);
}


// Issue #628
fn arr_to_field(arr: [u32; 9]) -> [Field; 9] {
let mut as_field: [Field; 9] = [0 as Field; 9];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
struct MyStruct {
operation: fn (u32) -> u32,
}

fn main(x: u32) {
assert(wrapper(increment, x) == x + 1);
assert(wrapper(decrement, x) == x - 1);
assert(wrapper_with_struct(MyStruct { operation: increment }, x) == x + 1);
assert(wrapper_with_struct(MyStruct { operation: decrement }, x) == x - 1);
}

unconstrained fn wrapper(func: fn (u32) -> u32, param: u32) -> u32 {
func(param)
}

unconstrained fn increment(x: u32) -> u32 {
x + 1
}

unconstrained fn decrement(x: u32) -> u32 {
x - 1
}

unconstrained fn wrapper_with_struct(my_struct: MyStruct, param: u32) -> u32 {
let func = my_struct.operation;
func(param)
}

13 changes: 1 addition & 12 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{
},
ssa_refactor::ir::{
function::{Function, FunctionId},
instruction::TerminatorInstruction,
types::Type,
value::ValueId,
},
Expand Down Expand Up @@ -71,17 +70,7 @@ impl FunctionContext {

/// Collects the return values of a given function
pub(crate) fn return_values(func: &Function) -> Vec<BrilligParameter> {
let blocks = func.reachable_blocks();
let mut function_return_values = None;
for block in blocks {
let terminator = func.dfg[block].terminator();
if let Some(TerminatorInstruction::Return { return_values }) = terminator {
function_return_values = Some(return_values);
break;
}
}
function_return_values
.expect("Expected a return instruction, as block is finished construction")
func.returns()
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl BrilligArtifact {
/// This method will offset the positions in the Brillig artifact to
/// account for the fact that it is being appended to the end of this
/// Brillig artifact (self).
pub(crate) fn link_with(&mut self, func_label: Label, obj: &BrilligArtifact) {
pub(crate) fn link_with(&mut self, obj: &BrilligArtifact) {
// Add the unresolved jumps of the linked function to this artifact.
self.add_unresolved_jumps_and_calls(obj);

Expand All @@ -169,7 +169,7 @@ impl BrilligArtifact {
self.byte_code.append(&mut byte_code);

// Remove all resolved external calls and transform them to jumps
let is_resolved = |label: &Label| label == &func_label;
let is_resolved = |label: &Label| self.labels.get(label).is_some();

let resolved_external_calls = self
.unresolved_external_call_labels
Expand Down
6 changes: 5 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ pub(crate) fn optimize_into_acir(
print_ssa_passes: bool,
) -> GeneratedAcir {
let abi_distinctness = program.return_distinctness;
let mut ssa = ssa_gen::generate_ssa(program).print(print_ssa_passes, "Initial SSA:");
let mut ssa = ssa_gen::generate_ssa(program)
.print(print_ssa_passes, "Initial SSA:")
.defunctionalize()
.print(print_ssa_passes, "After Defunctionalization:");

let brillig = ssa.to_brillig();
if let RuntimeType::Acir = ssa.main().runtime() {
ssa = ssa
Expand Down
25 changes: 4 additions & 21 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ impl Context {
self.create_value_from_type(&typ, &mut |this, _| this.acir_context.add_variable())
});

let outputs: Vec<AcirType> = vecmap(self.get_return_values(main_func), |result_id| {
dfg.type_of_value(result_id).into()
});
let outputs: Vec<AcirType> =
vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into());

let code = self.gen_brillig_for(main_func, &brillig);

Expand Down Expand Up @@ -351,7 +350,7 @@ impl Context {
let artifact = &brillig
.find_by_function_label(unresolved_fn_label.clone())
.expect("Cannot find linked fn {unresolved_fn_label}");
entry_point.link_with(unresolved_fn_label, artifact);
entry_point.link_with(artifact);
}
// Generate the final bytecode
entry_point.finish()
Expand Down Expand Up @@ -423,22 +422,6 @@ impl Context {
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
}

/// Finds the return values of a given function
fn get_return_values(&self, func: &Function) -> Vec<ValueId> {
let blocks = func.reachable_blocks();
let mut function_return_values = None;
for block in blocks {
let terminator = func.dfg[block].terminator();
if let Some(TerminatorInstruction::Return { return_values }) = terminator {
function_return_values = Some(return_values);
break;
}
}
function_return_values
.expect("Expected a return instruction, as block is finished construction")
.clone()
}

/// Converts an SSA terminator's return values into their ACIR representations
fn convert_ssa_return(&mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph) {
let return_values = match terminator {
Expand Down Expand Up @@ -786,7 +769,7 @@ impl Context {
}

/// Convert a Vec<AcirVar> into a Vec<AcirValue> using the given result ids.
/// If the type of a result id is an array, several acirvars are collected into
/// If the type of a result id is an array, several acir vars are collected into
/// a single AcirValue::Array of the same length.
fn convert_vars_to_values(
vars: Vec<AcirVar>,
Expand Down
20 changes: 20 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ impl DataFlowGraph {
self.blocks.iter()
}

/// Iterate over every Value in this DFG in no particular order, including unused Values
pub(crate) fn values_iter(&self) -> impl ExactSizeIterator<Item = (ValueId, &Value)> {
self.values.iter()
}

/// Returns the parameters of the given block
pub(crate) fn block_parameters(&self, block: BasicBlockId) -> &[ValueId] {
self.blocks[block].parameters()
Expand Down Expand Up @@ -169,6 +174,21 @@ impl DataFlowGraph {
}
}

/// Set the type of value_id to the target_type.
pub(crate) fn set_type_of_value(&mut self, value_id: ValueId, target_type: Type) {
let value = &mut self.values[value_id];
match value {
Value::Instruction { typ, .. }
| Value::Param { typ, .. }
| Value::NumericConstant { typ, .. } => {
*typ = target_type;
}
_ => {
unreachable!("ICE: Cannot set type of {:?}", value);
}
}
}

/// If `original_value_id`'s underlying `Value` has been substituted for that of another
/// `ValueId`, this function will return the `ValueId` from which the substitution was taken.
/// If `original_value_id`'s underlying `Value` has not been substituted, the same `ValueId`
Expand Down
29 changes: 27 additions & 2 deletions crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use std::collections::HashSet;

use super::basic_block::BasicBlockId;
use super::dfg::DataFlowGraph;
use super::instruction::TerminatorInstruction;
use super::map::Id;
use super::types::Type;
use super::value::ValueId;

#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub(crate) enum RuntimeType {
// A noir function, to be compiled in ACIR and executed by ACVM
Acir,
Expand Down Expand Up @@ -60,7 +61,7 @@ impl Function {

/// Runtime type of the function.
pub(crate) fn runtime(&self) -> RuntimeType {
self.runtime.clone()
self.runtime
}

/// Set runtime type of the function.
Expand All @@ -84,6 +85,21 @@ impl Function {
self.dfg.block_parameters(self.entry_block)
}

/// Returns the return types of this function.
pub(crate) fn returns(&self) -> &[ValueId] {
let blocks = self.reachable_blocks();
let mut function_return_values = None;
for block in blocks {
let terminator = self.dfg[block].terminator();
if let Some(TerminatorInstruction::Return { return_values }) = terminator {
function_return_values = Some(return_values);
break;
}
}
function_return_values
.expect("Expected a return instruction, as function construction is finished")
}

/// Collects all the reachable blocks of this function.
///
/// Note that self.dfg.basic_blocks_iter() iterates over all blocks,
Expand All @@ -102,6 +118,15 @@ impl Function {
}
}

impl std::fmt::Display for RuntimeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RuntimeType::Acir => write!(f, "acir"),
RuntimeType::Brillig => write!(f, "brillig"),
}
}
}

/// FunctionId is a reference for a function
///
/// This Id is how each function refers to other functions
Expand Down
5 changes: 5 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ impl<T> Id<T> {
Self { index, _marker: std::marker::PhantomData }
}

/// Returns the underlying index of this Id.
pub(crate) fn to_usize(self) -> usize {
self.index
}

/// Creates a test Id with the given index.
/// The name of this function makes it apparent it should only
/// be used for testing. Obtaining Ids in this way should be avoided
Expand Down
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::{

/// Helper function for Function's Display impl to pretty-print the function with the given formatter.
pub(crate) fn display_function(function: &Function, f: &mut Formatter) -> Result {
writeln!(f, "fn {} {} {{", function.name(), function.id())?;
writeln!(f, "{} fn {} {} {{", function.runtime(), function.name(), function.id())?;
display_block_with_successors(function, function.entry_block(), &mut HashSet::new(), f)?;
write!(f, "}}")
}
Expand Down
Loading

0 comments on commit 1d5d84d

Please sign in to comment.