Skip to content

Commit

Permalink
fix: Avoid non-determinism in defunctionalization (#2069)
Browse files Browse the repository at this point in the history
fix: avoid non-determinism in defunctionalize
  • Loading branch information
sirasistant authored Jul 27, 2023
1 parent 7d61d98 commit 898a9fa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl std::fmt::Display for RuntimeType {
/// within Call instructions.
pub(crate) type FunctionId = Id<Function>;

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) struct Signature {
pub(crate) params: Vec<Type>,
pub(crate) returns: Vec<Type>,
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_evaluator/src/ssa_refactor/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ use iter_extended::vecmap;
///
/// Fields do not have a notion of ordering, so this distinction
/// is reasonable.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum NumericType {
Signed { bit_size: u32 },
Unsigned { bit_size: u32 },
NativeField,
}

/// All types representable in the IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum Type {
/// Represents numeric types in the IR, including field elements
Numeric(NumericType),
Expand Down
27 changes: 12 additions & 15 deletions crates/noirc_evaluator/src/ssa_refactor/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! with a non-literal target can be replaced with a call to an apply function.
//! The apply function is a dispatch function that takes the function id as a parameter
//! and dispatches to the correct target.
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -47,7 +47,6 @@ struct ApplyFunction {
/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants
#[derive(Debug, Clone)]
struct DefunctionalizationContext {
fn_to_runtime: HashMap<FunctionId, RuntimeType>,
apply_functions: HashMap<Signature, ApplyFunction>,
}

Expand All @@ -57,10 +56,8 @@ impl Ssa {
let variants = find_variants(&self);

let apply_functions = create_apply_functions(&mut self, variants);
let fn_to_runtime =
self.functions.iter().map(|(func_id, func)| (*func_id, func.runtime())).collect();

let context = DefunctionalizationContext { fn_to_runtime, apply_functions };
let context = DefunctionalizationContext { apply_functions };

context.defunctionalize_all(&mut self);
self
Expand Down Expand Up @@ -157,23 +154,23 @@ impl DefunctionalizationContext {
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> HashMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: HashSet<Signature> = HashSet::new();
let mut functions_as_values: HashSet<FunctionId> = HashSet::new();
fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: BTreeSet<Signature> = BTreeSet::new();
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

for function in ssa.functions.values() {
functions_as_values.extend(find_functions_as_values(function));
dynamic_dispatches.extend(find_dynamic_dispatches(function));
}

let mut signature_to_functions_as_value: HashMap<Signature, Vec<FunctionId>> = HashMap::new();
let mut signature_to_functions_as_value: BTreeMap<Signature, Vec<FunctionId>> = BTreeMap::new();

for function_id in functions_as_values {
let signature = ssa.functions[&function_id].signature();
signature_to_functions_as_value.entry(signature).or_default().push(function_id);
}

let mut variants = HashMap::new();
let mut variants = BTreeMap::new();

for dispatch_signature in dynamic_dispatches {
let mut target_fns = vec![];
Expand All @@ -189,8 +186,8 @@ fn find_variants(ssa: &Ssa) -> HashMap<Signature, Vec<FunctionId>> {
}

/// Finds all literal functions used as values in the given function
fn find_functions_as_values(func: &Function) -> HashSet<FunctionId> {
let mut functions_as_values: HashSet<FunctionId> = HashSet::new();
fn find_functions_as_values(func: &Function) -> BTreeSet<FunctionId> {
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

let mut process_value = |value_id: ValueId| {
if let Value::Function(id) = func.dfg[value_id] {
Expand Down Expand Up @@ -220,8 +217,8 @@ fn find_functions_as_values(func: &Function) -> HashSet<FunctionId> {
}

/// Finds all dynamic dispatch signatures in the given function
fn find_dynamic_dispatches(func: &Function) -> HashSet<Signature> {
let mut dispatches = HashSet::new();
fn find_dynamic_dispatches(func: &Function) -> BTreeSet<Signature> {
let mut dispatches = BTreeSet::new();

for block_id in func.reachable_blocks() {
let block = &func.dfg[block_id];
Expand All @@ -246,7 +243,7 @@ fn find_dynamic_dispatches(func: &Function) -> HashSet<Signature> {

fn create_apply_functions(
ssa: &mut Ssa,
variants_map: HashMap<Signature, Vec<FunctionId>>,
variants_map: BTreeMap<Signature, Vec<FunctionId>>,
) -> HashMap<Signature, ApplyFunction> {
let mut apply_functions = HashMap::new();
for (signature, variants) in variants_map.into_iter() {
Expand Down

0 comments on commit 898a9fa

Please sign in to comment.