Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Reset condition value during flattening pass #1811

Merged
merged 5 commits into from
Jun 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 153 additions & 16 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ impl<'f> Context<'f> {
let else_branch =
self.inline_branch(block, else_block, old_condition, else_condition, zero);

// We must remember to reset whether side effects are enabled when both branches
// end, in addition to resetting the value of old_condition since it is set to
// known to be true/false within the then/else branch respectively.
self.insert_current_side_effects_enabled();
self.inserter.map_value(old_condition, old_condition);

// While there is a condition on the stack we don't compile outside the condition
// until it is popped. This ensures we inline the full then and else branches
Expand Down Expand Up @@ -494,10 +498,16 @@ impl<'f> Context<'f> {
let old_stores = std::mem::take(&mut self.store_values);
let old_allocations = std::mem::take(&mut self.local_allocations);

// Remember the old condition value is now known to be true/false within this branch
let known_value =
self.inserter.function.dfg.make_constant(condition_value, Type::bool());
self.inserter.map_value(old_condition, known_value);
// Optimization: within the then branch we know the condition to be true, so replace
// any references of it within this branch with true. Likewise, do the same with false
// with the else branch. We must be careful not to replace the condition if it is a
// known constant, otherwise we can end up setting 1 = 0 or vice-versa.
if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() {
let known_value =
self.inserter.function.dfg.make_constant(condition_value, Type::bool());

self.inserter.map_value(old_condition, known_value);
}

let final_block = self.inline_block(destination, &[]);

Expand Down Expand Up @@ -670,11 +680,12 @@ impl<'f> Context<'f> {

#[cfg(test)]
mod test {
use std::rc::Rc;

use crate::ssa_refactor::{
ir::{
dfg::DataFlowGraph,
function::RuntimeType,
function::{Function, RuntimeType},
instruction::{BinaryOp, Instruction, Intrinsic, TerminatorInstruction},
map::Id,
types::Type,
Expand Down Expand Up @@ -837,12 +848,7 @@ mod test {
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = main.dfg[main.entry_block()]
.instructions()
.iter()
.filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. }))
.count();

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 2);
}

Expand Down Expand Up @@ -921,13 +927,16 @@ mod test {
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = main.dfg[main.entry_block()]
let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 4);
}

fn count_instruction(function: &Function, f: impl Fn(&Instruction) -> bool) -> usize {
function.dfg[function.entry_block()]
.instructions()
.iter()
.filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. }))
.count();

assert_eq!(store_count, 4);
.filter(|id| f(&function.dfg[**id]))
.count()
}

#[test]
Expand Down Expand Up @@ -1196,4 +1205,132 @@ mod test {
_ => Vec::new(),
}
}

#[test]
fn should_not_merge_away_constraints() {
// Very simplified derived regression test for #1792
// Tests that it does not simplify to a true constraint an always-false constraint
// The original function is replaced by the following:
// fn main f1 {
// b0():
// jmpif u1 0 then: b1, else: b2
// b1():
// jmp b2()
// b2():
// constrain u1 0 // was incorrectly removed
// return
// }
let main_id = Id::test_new(1);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

builder.insert_block(); // entry

let b1 = builder.insert_block();
let b2 = builder.insert_block();
let v_false = builder.numeric_constant(0_u128, Type::bool());
jfecher marked this conversation as resolved.
Show resolved Hide resolved
builder.terminate_with_jmpif(v_false, b1, b2);

builder.switch_to_block(b1);
builder.terminate_with_jmp(b2, vec![]);

builder.switch_to_block(b2);
builder.insert_constrain(v_false); // should not be removed
builder.terminate_with_return(vec![]);

let ssa = builder.finish().flatten_cfg();
let main = ssa.main();

// Assert we have not incorrectly removed a constraint:
use Instruction::Constrain;
let constrain_count = count_instruction(main, |ins| matches!(ins, Constrain(_)));
assert_eq!(constrain_count, 1);
}

#[test]
fn should_not_merge_incorrectly_to_false() {
// Regression test for #1792
// Tests that it does not simplify a true constraint an always-false constraint
// fn main f1 {
// b0():
// v4 = call pedersen([Field 0], u32 0)
// v5 = array_get v4, index Field 0
// v6 = cast v5 as u32
// v8 = mod v6, u32 2
// v9 = cast v8 as u1
// v10 = allocate
// store Field 0 at v10
// jmpif v9 then: b1, else: b2
// b1():
// v14 = add v5, Field 1
// store v14 at v10
// jmp b3()
// b3():
// v12 = eq v9, u1 1
// constrain v12
// return
// b2():
// store Field 0 at v10
// jmp b3()
// }
let main_id = Id::test_new(1);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

builder.insert_block(); // b0
let b1 = builder.insert_block();
let b2 = builder.insert_block();
let b3 = builder.insert_block();

let element_type = Rc::new(vec![Type::field()]);
let zero = builder.field_constant(0_u128);
let zero_array = builder.array_constant(im::Vector::unit(zero), element_type.clone());
let i_zero = builder.numeric_constant(0_u128, Type::unsigned(32));
let pedersen =
builder.import_intrinsic_id(Intrinsic::BlackBox(acvm::acir::BlackBoxFunc::Pedersen));
let v4 = builder.insert_call(
pedersen,
vec![zero_array, i_zero],
vec![Type::Array(element_type, 2)],
)[0];
let v5 = builder.insert_array_get(v4, zero, Type::field());
let v6 = builder.insert_cast(v5, Type::unsigned(32));
let i_two = builder.numeric_constant(2_u128, Type::unsigned(32));
let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two);
let v9 = builder.insert_cast(v8, Type::bool());

let v10 = builder.insert_allocate();
builder.insert_store(v10, zero);

builder.terminate_with_jmpif(v9, b1, b2);

builder.switch_to_block(b1);
let one = builder.field_constant(1_u128);
let v14 = builder.insert_binary(v5, BinaryOp::Add, one);
builder.insert_store(v10, v14);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b2);
builder.insert_store(v10, zero);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b3);
let b_true = builder.numeric_constant(1_u128, Type::unsigned(1));
let v12 = builder.insert_binary(v9, BinaryOp::Eq, b_true);
builder.insert_constrain(v12);
builder.terminate_with_return(vec![]);

let ssa = builder.finish().flatten_cfg();
let main = ssa.main();

// Now assert that there is not an always-false constraint after flattening:
let mut constrain_count = 0;
for instruction in main.dfg[main.entry_block()].instructions() {
if let Instruction::Constrain(value) = main.dfg[*instruction] {
if let Some(constant) = main.dfg.get_numeric_constant(value) {
assert!(constant.is_one());
}
constrain_count += 1;
}
}
assert_eq!(constrain_count, 1);
}
}