Skip to content

Commit

Permalink
feat: allow main to be a brillig function (#1861)
Browse files Browse the repository at this point in the history
feat(brillig): wrap brillig fns to be top level
  • Loading branch information
sirasistant authored Jul 4, 2023
1 parent cb607f5 commit 1330a2a
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "1"
array = ["4", "5", "6"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Tests a very simple program.
//
// The feature being tested is brillig as the entry point.

unconstrained fn main(array: [Field; 3], x: pub Field) -> pub [Field; 2] {
[array[x], array[x + 1]]
}
5 changes: 1 addition & 4 deletions crates/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ impl Ssa {

let mut brillig = Brillig::default();
for brillig_function in brillig_functions {
// TODO: document why we are skipping the `main_id` for Brillig functions
if brillig_function.id() != self.main_id {
brillig.compile(brillig_function);
}
brillig.compile(brillig_function);
}

brillig
Expand Down
37 changes: 20 additions & 17 deletions crates/noirc_evaluator/src/ssa_refactor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use noirc_abi::Abi;

use noirc_frontend::monomorphization::ast::Program;

use self::{abi_gen::gen_abi, acir_gen::GeneratedAcir, ssa_gen::Ssa};
use self::{abi_gen::gen_abi, acir_gen::GeneratedAcir, ir::function::RuntimeType, ssa_gen::Ssa};

mod abi_gen;
mod acir_gen;
Expand All @@ -34,23 +34,26 @@ pub(crate) fn optimize_into_acir(
print_ssa_passes: bool,
) -> GeneratedAcir {
let abi_distinctness = program.return_distinctness;
let ssa = ssa_gen::generate_ssa(program).print(print_ssa_passes, "Initial SSA:");
let mut ssa = ssa_gen::generate_ssa(program).print(print_ssa_passes, "Initial SSA:");
let brillig = ssa.to_brillig();
ssa.inline_functions()
.print(print_ssa_passes, "After Inlining:")
.unroll_loops()
.print(print_ssa_passes, "After Unrolling:")
.simplify_cfg()
.print(print_ssa_passes, "After Simplifying:")
.flatten_cfg()
.print(print_ssa_passes, "After Flattening:")
.mem2reg()
.print(print_ssa_passes, "After Mem2Reg:")
.fold_constants()
.print(print_ssa_passes, "After Constant Folding:")
.dead_instruction_elimination()
.print(print_ssa_passes, "After Dead Instruction Elimination:")
.into_acir(brillig, abi_distinctness, allow_log_ops)
if let RuntimeType::Acir = ssa.main().runtime() {
ssa = ssa
.inline_functions()
.print(print_ssa_passes, "After Inlining:")
.unroll_loops()
.print(print_ssa_passes, "After Unrolling:")
.simplify_cfg()
.print(print_ssa_passes, "After Simplifying:")
.flatten_cfg()
.print(print_ssa_passes, "After Flattening:")
.mem2reg()
.print(print_ssa_passes, "After Mem2Reg:")
.fold_constants()
.print(print_ssa_passes, "After Constant Folding:")
.dead_instruction_elimination()
.print(print_ssa_passes, "After Dead Instruction Elimination:");
}
ssa.into_acir(brillig, abi_distinctness, allow_log_ops)
}

/// Compiles the Program into ACIR and applies optimizations to the arithmetic gates
Expand Down
92 changes: 79 additions & 13 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use self::acir_ir::{
use super::{
ir::{
dfg::DataFlowGraph,
function::RuntimeType,
function::{Function, RuntimeType},
instruction::{
Binary, BinaryOp, Instruction, InstructionId, Intrinsic, TerminatorInstruction,
},
Expand All @@ -21,7 +21,10 @@ use super::{
},
ssa_gen::Ssa,
};
use acvm::{acir::native_types::Expression, FieldElement};
use acvm::{
acir::{brillig_vm::Opcode, native_types::Expression},
FieldElement,
};
use iter_extended::vecmap;

pub(crate) use acir_ir::generated_acir::GeneratedAcir;
Expand Down Expand Up @@ -105,22 +108,63 @@ impl Ssa {

impl Context {
/// Converts SSA into ACIR
fn convert_ssa(mut self, ssa: Ssa, brillig: Brillig, allow_log_ops: bool) -> GeneratedAcir {
fn convert_ssa(self, ssa: Ssa, brillig: Brillig, allow_log_ops: bool) -> GeneratedAcir {
let main_func = ssa.main();
match main_func.runtime() {
RuntimeType::Acir => self.convert_acir_main(main_func, &ssa, brillig, allow_log_ops),
RuntimeType::Brillig => self.convert_brillig_main(main_func, brillig),
}
}

fn convert_acir_main(
mut self,
main_func: &Function,
ssa: &Ssa,
brillig: Brillig,
allow_log_ops: bool,
) -> GeneratedAcir {
let dfg = &main_func.dfg;
let entry_block = &dfg[main_func.entry_block()];

self.convert_ssa_block_params(entry_block.parameters(), dfg);

for instruction_id in entry_block.instructions() {
self.convert_ssa_instruction(*instruction_id, dfg, &ssa, &brillig, allow_log_ops);
self.convert_ssa_instruction(*instruction_id, dfg, ssa, &brillig, allow_log_ops);
}

self.convert_ssa_return(entry_block.terminator().unwrap(), dfg);

self.acir_context.finish()
}

fn convert_brillig_main(mut self, main_func: &Function, brillig: Brillig) -> GeneratedAcir {
let dfg = &main_func.dfg;

let inputs = vecmap(dfg[main_func.entry_block()].parameters(), |param_id| {
let typ = dfg.type_of_value(*param_id);
self.create_value_from_type(&typ, &mut |this, _| this.acir_context.add_variable())
});

let outputs: Vec<AcirType> = vecmap(self.get_return_values(main_func), |result_id| {
dfg.type_of_value(result_id).into()
});

let code = self.gen_brillig_for(main_func, &brillig);

let output_values = self.acir_context.brillig(None, code, inputs, outputs);
let output_vars: Vec<_> = output_values
.iter()
.flat_map(|value| value.clone().flatten())
.map(|value| value.0)
.collect();

for acir_var in output_vars {
self.acir_context.return_var(acir_var);
}

self.acir_context.finish()
}

/// Adds and binds `AcirVar`s for each numeric block parameter or block parameter array element.
fn convert_ssa_block_params(&mut self, params: &[ValueId], dfg: &DataFlowGraph) {
for param_id in params {
Expand Down Expand Up @@ -216,15 +260,7 @@ impl Context {
RuntimeType::Brillig => {
let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg));

// Create the entry point artifact
let mut entry_point = BrilligArtifact::to_entry_point_artifact(&brillig[*id]);
// Link the entry point with all dependencies
while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() {
let artifact = &brillig.find_by_function_label(unresolved_fn_label.clone()).expect("Cannot find linked fn {unresolved_fn_label}");
entry_point.link_with(unresolved_fn_label, artifact);
}
// Generate the final bytecode
let code = entry_point.finish();
let code = self.gen_brillig_for(func, brillig);

let outputs: Vec<AcirType> = vecmap(result_ids, |result_id| dfg.type_of_value(*result_id).into());

Expand Down Expand Up @@ -300,6 +336,20 @@ impl Context {
}
}

fn gen_brillig_for(&self, func: &Function, brillig: &Brillig) -> Vec<Opcode> {
// Create the entry point artifact
let mut entry_point = BrilligArtifact::to_entry_point_artifact(&brillig[func.id()]);
// Link the entry point with all dependencies
while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() {
let artifact = &brillig
.find_by_function_label(unresolved_fn_label.clone())
.expect("Cannot find linked fn {unresolved_fn_label}");
entry_point.link_with(unresolved_fn_label, artifact);
}
// Generate the final bytecode
entry_point.finish()
}

/// Handles an ArrayGet or ArraySet instruction.
/// To set an index of the array (and create a new array in doing so), pass Some(value) for
/// store_value. To just retrieve an index of the array, pass None for store_value.
Expand Down Expand Up @@ -366,6 +416,22 @@ impl Context {
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
}

/// Finds the return values of a given function
fn get_return_values(&self, func: &Function) -> Vec<ValueId> {
let blocks = func.reachable_blocks();
let mut function_return_values = None;
for block in blocks {
let terminator = func.dfg[block].terminator();
if let Some(TerminatorInstruction::Return { return_values }) = terminator {
function_return_values = Some(return_values);
break;
}
}
function_return_values
.expect("Expected a return instruction, as block is finished construction")
.clone()
}

/// Converts an SSA terminator's return values into their ACIR representations
fn convert_ssa_return(&mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph) {
let return_values = match terminator {
Expand Down
8 changes: 6 additions & 2 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ pub(crate) fn generate_ssa(program: Program) -> Ssa {
// Queue the main function for compilation
context.get_or_queue_function(main_id);

let mut function_context =
FunctionContext::new(main.name.clone(), &main.parameters, RuntimeType::Acir, &context);
let mut function_context = FunctionContext::new(
main.name.clone(),
&main.parameters,
if main.unconstrained { RuntimeType::Brillig } else { RuntimeType::Acir },
&context,
);
function_context.codegen_function_body(&main.body);

// Main has now been compiled and any other functions referenced within have been added to the
Expand Down

0 comments on commit 1330a2a

Please sign in to comment.