Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Avoid non-determinism in defunctionalization #2069

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl std::fmt::Display for RuntimeType {
/// within Call instructions.
pub(crate) type FunctionId = Id<Function>;

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) struct Signature {
pub(crate) params: Vec<Type>,
pub(crate) returns: Vec<Type>,
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_evaluator/src/ssa_refactor/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ use iter_extended::vecmap;
///
/// Fields do not have a notion of ordering, so this distinction
/// is reasonable.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum NumericType {
Signed { bit_size: u32 },
Unsigned { bit_size: u32 },
NativeField,
}

/// All types representable in the IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum Type {
/// Represents numeric types in the IR, including field elements
Numeric(NumericType),
Expand Down
27 changes: 12 additions & 15 deletions crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! with a non-literal target can be replaced with a call to an apply function.
//! The apply function is a dispatch function that takes the function id as a parameter
//! and dispatches to the correct target.
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -47,7 +47,6 @@ struct ApplyFunction {
/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants
#[derive(Debug, Clone)]
struct DefunctionalizationContext {
fn_to_runtime: HashMap<FunctionId, RuntimeType>,
apply_functions: HashMap<Signature, ApplyFunction>,
}

Expand All @@ -57,10 +56,8 @@ impl Ssa {
let variants = find_variants(&self);

let apply_functions = create_apply_functions(&mut self, variants);
let fn_to_runtime =
self.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect();

let context = DefunctionalizationContext { fn_to_runtime, apply_functions };
let context = DefunctionalizationContext { apply_functions };

context.defunctionalize_all(&mut self);
self
Expand Down Expand Up @@ -157,23 +154,23 @@ impl DefunctionalizationContext {
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> HashMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: HashSet<Signature> = HashSet::new();
let mut functions_as_values: HashSet<FunctionId> = HashSet::new();
fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: BTreeSet<Signature> = BTreeSet::new();
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

for function in ssa.functions.values() {
functions_as_values.extend(find_functions_as_values(function));
dynamic_dispatches.extend(find_dynamic_dispatches(function));
}

let mut signature_to_functions_as_value: HashMap<Signature, Vec<FunctionId>> = HashMap::new();
let mut signature_to_functions_as_value: BTreeMap<Signature, Vec<FunctionId>> = BTreeMap::new();

for function_id in functions_as_values {
let signature = ssa.functions[&function_id].signature();
signature_to_functions_as_value.entry(signature).or_default().push(function_id);
}

let mut variants = HashMap::new();
let mut variants = BTreeMap::new();

for dispatch_signature in dynamic_dispatches {
let mut target_fns = vec![];
Expand All @@ -189,8 +186,8 @@ fn find_variants(ssa: &Ssa) -> HashMap<Signature, Vec<FunctionId>> {
}

/// Finds all literal functions used as values in the given function
fn find_functions_as_values(func: &Function) -> HashSet<FunctionId> {
let mut functions_as_values: HashSet<FunctionId> = HashSet::new();
fn find_functions_as_values(func: &Function) -> BTreeSet<FunctionId> {
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

let mut process_value = |value_id: ValueId| {
if let Value::Function(id) = func.dfg[value_id] {
Expand Down Expand Up @@ -220,8 +217,8 @@ fn find_functions_as_values(func: &Function) -> HashSet<FunctionId> {
}

/// Finds all dynamic dispatch signatures in the given function
fn find_dynamic_dispatches(func: &Function) -> HashSet<Signature> {
let mut dispatches = HashSet::new();
fn find_dynamic_dispatches(func: &Function) -> BTreeSet<Signature> {
let mut dispatches = BTreeSet::new();

for block_id in func.reachable_blocks() {
let block = &func.dfg[block_id];
Expand All @@ -246,7 +243,7 @@ fn find_dynamic_dispatches(func: &Function) -> HashSet<Signature> {

fn create_apply_functions(
ssa: &mut Ssa,
variants_map: HashMap<Signature, Vec<FunctionId>>,
variants_map: BTreeMap<Signature, Vec<FunctionId>>,
) -> HashMap<Signature, ApplyFunction> {
let mut apply_functions = HashMap::new();
for (signature, variants) in variants_map.into_iter() {
Expand Down