From 1d5d84dd6db650aa9c136d3e9746a6544cf13945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Fri, 7 Jul 2023 11:27:44 +0200 Subject: [PATCH] feat: defunctionalization pass for ssa refactor (#1870) * feat: defunctionalization pass * Apply suggestions from self code review * feat: optimize apply function generation & usage * feat: avoid unary apply fns * fix: clippy * style: cleanup after peer review * docs: updated comments on defunctionalize * style: apply suggestions from peer review * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * style: rename field * docs: fixed doc to avoid doctest * style: addressed pr comments * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * Update crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs Co-authored-by: jfecher * refactor: extract set type of value to the dfg * Update crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs --------- Co-authored-by: jfecher --- .../7_function/src/main.nr | 8 +- .../brillig_fns_as_values/Nargo.toml | 5 + .../brillig_fns_as_values/Prover.toml | 1 + .../brillig_fns_as_values/src/main.nr | 28 ++ .../src/brillig/brillig_gen/brillig_fn.rs | 13 +- .../src/brillig/brillig_ir/artifact.rs | 4 +- crates/noirc_evaluator/src/ssa_refactor.rs | 6 +- .../src/ssa_refactor/acir_gen/mod.rs | 25 +- .../src/ssa_refactor/ir/dfg.rs | 20 ++ .../src/ssa_refactor/ir/function.rs | 29 +- .../src/ssa_refactor/ir/map.rs | 5 + .../src/ssa_refactor/ir/printer.rs | 2 +- .../src/ssa_refactor/opt/defunctionalize.rs | 337 ++++++++++++++++++ .../src/ssa_refactor/opt/mod.rs | 1 + .../src/ssa_refactor/ssa_gen/program.rs | 11 + cspell.json | 4 + 16 files changed, 457 insertions(+), 42 deletions(-) create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml create mode 100644 crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr create mode 100644 crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr index 5a23b493871..26ecf6dda28 100644 --- a/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/7_function/src/main.nr @@ -88,12 +88,15 @@ fn test_multiple6(a: my2, b: my_struct, c: (my2, my_struct)) { } -fn foo(a: [Field]) -> [Field] { + +fn foo(a: [Field; N]) -> [Field; N] { a } -fn bar() -> [Field] { + +fn bar() -> [Field; 1] { foo([0]) } + fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { let mut ss: my_struct = my_struct { b: x, a: x+2, }; test_multiple4(ss); @@ -134,7 +137,6 @@ fn main(x: u32 , y: u32 , a: Field, arr1: [u32; 9], arr2: [u32; 9]) { assert(result[0] == arr1[0] as Field); } - // Issue #628 fn arr_to_field(arr: [u32; 9]) -> [Field; 9] { let mut as_field: [Field; 9] = [0 as Field; 9]; diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml new file mode 100644 index 00000000000..e0b467ce5da --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Nargo.toml @@ -0,0 +1,5 @@ +[package] +authors = [""] +compiler_version = "0.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml new file mode 100644 index 00000000000..11497a473bc --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/Prover.toml @@ -0,0 +1 @@ +x = "0" diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr new file mode 100644 index 00000000000..5af542301ec --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/brillig_fns_as_values/src/main.nr @@ -0,0 +1,28 @@ +struct MyStruct { + operation: fn (u32) -> u32, +} + +fn main(x: u32) { + assert(wrapper(increment, x) == x + 1); + assert(wrapper(decrement, x) == x - 1); + assert(wrapper_with_struct(MyStruct { operation: increment }, x) == x + 1); + assert(wrapper_with_struct(MyStruct { operation: decrement }, x) == x - 1); +} + +unconstrained fn wrapper(func: fn (u32) -> u32, param: u32) -> u32 { + func(param) +} + +unconstrained fn increment(x: u32) -> u32 { + x + 1 +} + +unconstrained fn decrement(x: u32) -> u32 { + x - 1 +} + +unconstrained fn wrapper_with_struct(my_struct: MyStruct, param: u32) -> u32 { + let func = my_struct.operation; + func(param) +} + diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 819f0ae26c7..a501e9117a2 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -9,7 +9,6 @@ use crate::{ }, ssa_refactor::ir::{ function::{Function, FunctionId}, - instruction::TerminatorInstruction, types::Type, value::ValueId, }, @@ -71,17 +70,7 @@ impl FunctionContext { /// Collects the return values of a given function pub(crate) fn return_values(func: &Function) -> Vec { - let blocks = func.reachable_blocks(); - let mut function_return_values = None; - for block in blocks { - let terminator = func.dfg[block].terminator(); - if let Some(TerminatorInstruction::Return { return_values }) = terminator { - function_return_values = Some(return_values); - break; - } - } - function_return_values - .expect("Expected a return instruction, as block is finished construction") + func.returns() .iter() .map(|&value_id| { let typ = func.dfg.type_of_value(value_id); diff --git a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 2eaeee8f636..71b06537bd5 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -152,7 +152,7 @@ impl BrilligArtifact { /// This method will offset the positions in the Brillig artifact to /// account for the fact that it is being appended to the end of this /// Brillig artifact (self). - pub(crate) fn link_with(&mut self, func_label: Label, obj: &BrilligArtifact) { + pub(crate) fn link_with(&mut self, obj: &BrilligArtifact) { // Add the unresolved jumps of the linked function to this artifact. self.add_unresolved_jumps_and_calls(obj); @@ -169,7 +169,7 @@ impl BrilligArtifact { self.byte_code.append(&mut byte_code); // Remove all resolved external calls and transform them to jumps - let is_resolved = |label: &Label| label == &func_label; + let is_resolved = |label: &Label| self.labels.get(label).is_some(); let resolved_external_calls = self .unresolved_external_call_labels diff --git a/crates/noirc_evaluator/src/ssa_refactor.rs b/crates/noirc_evaluator/src/ssa_refactor.rs index 34189252f37..a61eae4ca97 100644 --- a/crates/noirc_evaluator/src/ssa_refactor.rs +++ b/crates/noirc_evaluator/src/ssa_refactor.rs @@ -31,7 +31,11 @@ pub(crate) fn optimize_into_acir( print_ssa_passes: bool, ) -> GeneratedAcir { let abi_distinctness = program.return_distinctness; - let mut ssa = ssa_gen::generate_ssa(program).print(print_ssa_passes, "Initial SSA:"); + let mut ssa = ssa_gen::generate_ssa(program) + .print(print_ssa_passes, "Initial SSA:") + .defunctionalize() + .print(print_ssa_passes, "After Defunctionalization:"); + let brillig = ssa.to_brillig(); if let RuntimeType::Acir = ssa.main().runtime() { ssa = ssa diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 307bcce5a35..2b49bae9e80 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -148,9 +148,8 @@ impl Context { self.create_value_from_type(&typ, &mut |this, _| this.acir_context.add_variable()) }); - let outputs: Vec = vecmap(self.get_return_values(main_func), |result_id| { - dfg.type_of_value(result_id).into() - }); + let outputs: Vec = + vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into()); let code = self.gen_brillig_for(main_func, &brillig); @@ -351,7 +350,7 @@ impl Context { let artifact = &brillig .find_by_function_label(unresolved_fn_label.clone()) .expect("Cannot find linked fn {unresolved_fn_label}"); - entry_point.link_with(unresolved_fn_label, artifact); + entry_point.link_with(artifact); } // Generate the final bytecode entry_point.finish() @@ -423,22 +422,6 @@ impl Context { self.define_result(dfg, instruction, AcirValue::Var(result, typ)); } - /// Finds the return values of a given function - fn get_return_values(&self, func: &Function) -> Vec { - let blocks = func.reachable_blocks(); - let mut function_return_values = None; - for block in blocks { - let terminator = func.dfg[block].terminator(); - if let Some(TerminatorInstruction::Return { return_values }) = terminator { - function_return_values = Some(return_values); - break; - } - } - function_return_values - .expect("Expected a return instruction, as block is finished construction") - .clone() - } - /// Converts an SSA terminator's return values into their ACIR representations fn convert_ssa_return(&mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph) { let return_values = match terminator { @@ -786,7 +769,7 @@ impl Context { } /// Convert a Vec into a Vec using the given result ids. - /// If the type of a result id is an array, several acirvars are collected into + /// If the type of a result id is an array, several acir vars are collected into /// a single AcirValue::Array of the same length. fn convert_vars_to_values( vars: Vec, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs index 393b85fdd2f..9104b65d16f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dfg.rs @@ -109,6 +109,11 @@ impl DataFlowGraph { self.blocks.iter() } + /// Iterate over every Value in this DFG in no particular order, including unused Values + pub(crate) fn values_iter(&self) -> impl ExactSizeIterator { + self.values.iter() + } + /// Returns the parameters of the given block pub(crate) fn block_parameters(&self, block: BasicBlockId) -> &[ValueId] { self.blocks[block].parameters() @@ -169,6 +174,21 @@ impl DataFlowGraph { } } + /// Set the type of value_id to the target_type. + pub(crate) fn set_type_of_value(&mut self, value_id: ValueId, target_type: Type) { + let value = &mut self.values[value_id]; + match value { + Value::Instruction { typ, .. } + | Value::Param { typ, .. } + | Value::NumericConstant { typ, .. } => { + *typ = target_type; + } + _ => { + unreachable!("ICE: Cannot set type of {:?}", value); + } + } + } + /// If `original_value_id`'s underlying `Value` has been substituted for that of another /// `ValueId`, this function will return the `ValueId` from which the substitution was taken. /// If `original_value_id`'s underlying `Value` has not been substituted, the same `ValueId` diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs index 8fe2fe745ff..76395ea74ab 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function.rs @@ -2,11 +2,12 @@ use std::collections::HashSet; use super::basic_block::BasicBlockId; use super::dfg::DataFlowGraph; +use super::instruction::TerminatorInstruction; use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir, @@ -60,7 +61,7 @@ impl Function { /// Runtime type of the function. pub(crate) fn runtime(&self) -> RuntimeType { - self.runtime.clone() + self.runtime } /// Set runtime type of the function. @@ -84,6 +85,21 @@ impl Function { self.dfg.block_parameters(self.entry_block) } + /// Returns the return types of this function. + pub(crate) fn returns(&self) -> &[ValueId] { + let blocks = self.reachable_blocks(); + let mut function_return_values = None; + for block in blocks { + let terminator = self.dfg[block].terminator(); + if let Some(TerminatorInstruction::Return { return_values }) = terminator { + function_return_values = Some(return_values); + break; + } + } + function_return_values + .expect("Expected a return instruction, as function construction is finished") + } + /// Collects all the reachable blocks of this function. /// /// Note that self.dfg.basic_blocks_iter() iterates over all blocks, @@ -102,6 +118,15 @@ impl Function { } } +impl std::fmt::Display for RuntimeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RuntimeType::Acir => write!(f, "acir"), + RuntimeType::Brillig => write!(f, "brillig"), + } + } +} + /// FunctionId is a reference for a function /// /// This Id is how each function refers to other functions diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs index e7f9d812de3..bb0da6a8558 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/map.rs @@ -25,6 +25,11 @@ impl Id { Self { index, _marker: std::marker::PhantomData } } + /// Returns the underlying index of this Id. + pub(crate) fn to_usize(self) -> usize { + self.index + } + /// Creates a test Id with the given index. /// The name of this function makes it apparent it should only /// be used for testing. Obtaining Ids in this way should be avoided diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs index 071f1a16029..f2fb90b3464 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/printer.rs @@ -15,7 +15,7 @@ use super::{ /// Helper function for Function's Display impl to pretty-print the function with the given formatter. pub(crate) fn display_function(function: &Function, f: &mut Formatter) -> Result { - writeln!(f, "fn {} {} {{", function.name(), function.id())?; + writeln!(f, "{} fn {} {} {{", function.runtime(), function.name(), function.id())?; display_block_with_successors(function, function.entry_block(), &mut HashSet::new(), f)?; write!(f, "}}") } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs new file mode 100644 index 00000000000..c31d0c58deb --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs @@ -0,0 +1,337 @@ +//! This module defines the defunctionalization pass for the SSA IR. +//! The purpose of this pass is to transforms all functions used as values into +//! constant numbers (fields) that represent the function id. That way all calls +//! with a non-literal target can be replaced with a call to an apply function. +//! The apply function is a dispatch function that takes the function id as a parameter +//! and dispatches to the correct target. +use std::collections::{HashMap, HashSet}; + +use acvm::FieldElement; +use iter_extended::vecmap; + +use crate::ssa_refactor::{ + ir::{ + basic_block::BasicBlockId, + function::{Function, FunctionId, RuntimeType}, + instruction::{BinaryOp, Instruction}, + types::{NumericType, Type}, + value::Value, + }, + ssa_builder::FunctionBuilder, + ssa_gen::Ssa, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct FunctionSignature { + parameters: Vec, + returns: Vec, + runtime: RuntimeType, +} + +impl FunctionSignature { + fn from(function: &Function) -> Self { + let parameters = vecmap(function.parameters(), |param| function.dfg.type_of_value(*param)); + let returns = vecmap(function.returns(), |ret| function.dfg.type_of_value(*ret)); + let runtime = function.runtime(); + Self { parameters, returns, runtime } + } +} + +/// Represents an 'apply' function created by this pass to dispatch higher order functions to. +/// Pseudocode of an `apply` function is given below: +/// ```text +/// fn apply(function_id: Field, arg1: Field, arg2: Field) -> Field { +/// match function_id { +/// 0 -> function0(arg1, arg2), +/// 1 -> function0(arg1, arg2), +/// ... +/// N -> functionN(arg1, arg2), +/// } +/// } +/// ``` +/// Apply functions generally take the function to apply as their first parameter. This is a Field value +/// obtained by converting the FunctionId into a Field. The remaining parameters of apply are the +/// arguments to forward to this function when calling it internally. +#[derive(Debug, Clone, Copy)] +struct ApplyFunction { + id: FunctionId, + dispatches_to_multiple_functions: bool, +} + +/// Performs defunctionalization on all functions +/// This is done by changing all functions as value to be a number (FieldElement) +/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants +#[derive(Debug, Clone)] +struct DefunctionalizationContext { + fn_to_runtime: HashMap, + variants: HashMap>, + apply_functions: HashMap, +} + +impl Ssa { + pub(crate) fn defunctionalize(mut self) -> Ssa { + // Find all functions used as value that share the same signature + let variants = find_variants(&self); + + let apply_functions = create_apply_functions(&mut self, &variants); + let fn_to_runtime = + self.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect(); + + let context = DefunctionalizationContext { fn_to_runtime, variants, apply_functions }; + + context.defunctionalize_all(&mut self); + self + } +} + +impl DefunctionalizationContext { + /// Defunctionalize all functions in the Ssa + fn defunctionalize_all(mut self, ssa: &mut Ssa) { + for function in ssa.functions.values_mut() { + self.defunctionalize(function); + } + } + + /// Defunctionalize a single function + fn defunctionalize(&mut self, func: &mut Function) { + let mut target_function_ids = HashSet::new(); + + for block_id in func.reachable_blocks() { + let block = &func.dfg[block_id]; + let instructions = block.instructions().to_vec(); + + for instruction_id in instructions { + let instruction = func.dfg[instruction_id].clone(); + let mut replacement_instruction = None; + // Operate on call instructions + let (target_func_id, mut arguments) = match instruction { + Instruction::Call { func: target_func_id, arguments } => { + (target_func_id, arguments) + } + _ => continue, + }; + + match func.dfg[target_func_id] { + // If the target is a function used as value + Value::Param { .. } | Value::Instruction { .. } => { + // Collect the argument types + let argument_types = vecmap(&arguments, |arg| func.dfg.type_of_value(*arg)); + + // Collect the result types + let result_types = + vecmap(func.dfg.instruction_results(instruction_id), |result| { + func.dfg.type_of_value(*result) + }); + // Find the correct apply function + let apply_function = self.get_apply_function(&FunctionSignature { + parameters: argument_types, + returns: result_types, + runtime: func.runtime(), + }); + target_function_ids.insert(apply_function.id); + + // Replace the instruction with a call to apply + let apply_function_value_id = func.dfg.import_function(apply_function.id); + if apply_function.dispatches_to_multiple_functions { + arguments.insert(0, target_func_id); + } + let func = apply_function_value_id; + replacement_instruction = Some(Instruction::Call { func, arguments }); + } + Value::Function(id) => { + target_function_ids.insert(id); + } + _ => {} + } + if let Some(new_instruction) = replacement_instruction { + func.dfg[instruction_id] = new_instruction; + } + } + } + + // Change the type of all the values that are not call targets to NativeField + let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id); + for value_id in value_ids { + if let Type::Function = &func.dfg[value_id].get_type() { + match &func.dfg[value_id] { + // If the value is a static function, transform it to the function id + Value::Function(id) => { + if !target_function_ids.contains(id) { + let new_value = + func.dfg.make_constant(function_id_to_field(*id), Type::field()); + func.dfg.set_value_from_id(value_id, new_value); + } + } + // If the value is a function used as value, just change the type of it + Value::Instruction { .. } | Value::Param { .. } => { + func.dfg.set_type_of_value(value_id, Type::field()); + } + _ => {} + } + } + } + } + + /// Returns the apply function for the given signature + fn get_apply_function(&self, signature: &FunctionSignature) -> ApplyFunction { + *self.apply_functions.get(signature).expect("Could not find apply function") + } +} + +/// Collects all functions used as a value by their signatures +fn find_variants(ssa: &Ssa) -> HashMap> { + let mut variants: HashMap> = HashMap::new(); + let mut functions_used_as_values = HashSet::new(); + + for function in ssa.functions.values() { + functions_used_as_values.extend(functions_as_values(function)); + } + + for function_id in functions_used_as_values { + let function = &ssa.functions[&function_id]; + let signature = FunctionSignature::from(function); + variants.entry(signature).or_default().push(function_id); + } + + variants +} + +/// Finds all literal functions used as values in the given function +fn functions_as_values(func: &Function) -> HashSet { + let mut literal_functions: HashSet<_> = func + .dfg + .values_iter() + .filter_map(|(id, _)| match func.dfg[id] { + Value::Function(id) => Some(id), + _ => None, + }) + .collect(); + + for block_id in func.reachable_blocks() { + let block = &func.dfg[block_id]; + for instruction_id in block.instructions() { + let instruction = &func.dfg[*instruction_id]; + let target_value = match instruction { + Instruction::Call { func, .. } => func, + _ => continue, + }; + let target_id = match func.dfg[*target_value] { + Value::Function(id) => id, + _ => continue, + }; + literal_functions.remove(&target_id); + } + } + literal_functions +} + +fn create_apply_functions( + ssa: &mut Ssa, + variants_map: &HashMap>, +) -> HashMap { + let mut apply_functions = HashMap::new(); + for (signature, variants) in variants_map.iter() { + let dispatches_to_multiple_functions = variants.len() > 1; + let id = if dispatches_to_multiple_functions { + create_apply_function(ssa, signature, variants) + } else { + variants[0] + }; + apply_functions + .insert(signature.clone(), ApplyFunction { id, dispatches_to_multiple_functions }); + } + apply_functions +} + +fn function_id_to_field(function_id: FunctionId) -> FieldElement { + (function_id.to_usize() as u128).into() +} + +/// Creates an apply function for the given signature and variants +fn create_apply_function( + ssa: &mut Ssa, + signature: &FunctionSignature, + function_ids: &[FunctionId], +) -> FunctionId { + assert!(!function_ids.is_empty()); + ssa.add_fn(|id| { + let mut function_builder = FunctionBuilder::new("apply".to_string(), id, signature.runtime); + let target_id = function_builder.add_parameter(Type::field()); + let params_ids = + vecmap(signature.parameters.clone(), |typ| function_builder.add_parameter(typ)); + + let mut previous_target_block = None; + for (index, function_id) in function_ids.iter().enumerate() { + let is_last = index == function_ids.len() - 1; + let mut next_function_block = None; + + let function_id_constant = function_builder.numeric_constant( + function_id_to_field(*function_id), + Type::Numeric(NumericType::NativeField), + ); + let condition = + function_builder.insert_binary(target_id, BinaryOp::Eq, function_id_constant); + + // If it's not the last function to dispatch, create an if statement + if !is_last { + next_function_block = Some(function_builder.insert_block()); + let executor_block = function_builder.insert_block(); + + function_builder.terminate_with_jmpif( + condition, + executor_block, + next_function_block.unwrap(), + ); + function_builder.switch_to_block(executor_block); + } else { + // Else just constrain the condition + function_builder.insert_constrain(condition); + } + // Find the target block or build it if necessary + let current_block = function_builder.current_block(); + + let target_block = build_return_block( + &mut function_builder, + current_block, + signature.returns.clone(), + previous_target_block, + ); + previous_target_block = Some(target_block); + + // Call the function + let target_function_value = function_builder.import_function(*function_id); + let call_results = function_builder + .insert_call(target_function_value, params_ids.clone(), signature.returns.clone()) + .to_vec(); + + // Jump to the target block for returning + function_builder.terminate_with_jmp(target_block, call_results); + + if let Some(next_block) = next_function_block { + // Switch to the next block for the else branch + function_builder.switch_to_block(next_block); + } + } + function_builder.current_function + }) +} + +/// Crates a return block, if no previous return exists, it will create a final return +/// Else, it will create a bypass return block that points to the previous return block +fn build_return_block( + builder: &mut FunctionBuilder, + previous_block: BasicBlockId, + passed_types: Vec, + target: Option, +) -> BasicBlockId { + let return_block = builder.insert_block(); + builder.switch_to_block(return_block); + + let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ)); + match target { + None => builder.terminate_with_return(params), + Some(target) => builder.terminate_with_jmp(target, params), + } + builder.switch_to_block(previous_block); + return_block +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs index 56c5fa689ad..0d4ad594486 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/mod.rs @@ -4,6 +4,7 @@ //! simpler form until the IR only has a single function remaining with 1 block within it. //! Generally, these passes are also expected to minimize the final amount of instructions. mod constant_folding; +mod defunctionalize; mod die; mod flatten_cfg; mod inlining; diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs index ba98c658505..aec0e4262c8 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/program.rs @@ -38,6 +38,17 @@ impl Ssa { pub(crate) fn main_mut(&mut self) -> &mut Function { self.functions.get_mut(&self.main_id).expect("ICE: Ssa should have a main function") } + + /// Adds a new function to the program + pub(crate) fn add_fn( + &mut self, + build_with_id: impl FnOnce(FunctionId) -> Function, + ) -> FunctionId { + let new_id = self.next_id.next(); + let function = build_with_id(new_id); + self.functions.insert(new_id, function); + new_id + } } impl Display for Ssa { diff --git a/cspell.json b/cspell.json index e10e700cdb6..92c3154f2b3 100644 --- a/cspell.json +++ b/cspell.json @@ -15,6 +15,9 @@ "combinators", "comptime", "cranelift", + "defunctionalize", + "defunctionalized", + "defunctionalization", "desugared", "endianness", "forall", @@ -45,6 +48,7 @@ "pedersen", "peekable", "preprocess", + "pseudocode", "schnorr", "sdiv", "signedness",