Skip to content

Commit

Permalink
fix(ssa refactor): Prevent stores in 'then' branch from affecting the…
Browse files Browse the repository at this point in the history
… 'else' branch (#1827)

* Fix stores in 'then' branch affecting the 'else' branch

* Update crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs

* Update crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
jfecher and kevaundray authored Jun 26, 2023
1 parent 6fa751b commit e068fd4
Showing 1 changed file with 181 additions and 46 deletions.
227 changes: 181 additions & 46 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ impl<'f> Context<'f> {
let else_condition = self.insert_instruction(Instruction::Not(then_condition));
let zero = FieldElement::zero();

// Make sure the else branch sees the previous values of each store
// rather than any values created in the 'then' branch.
self.undo_stores_in_then_branch(&then_branch);

let else_branch =
self.inline_branch(block, else_block, old_condition, else_condition, zero);

Expand Down Expand Up @@ -572,9 +576,25 @@ impl<'f> Context<'f> {
/// instruction. If this ordering is changed, the ordering that store values are merged within
/// this function also needs to be changed to reflect that.
fn merge_stores(&mut self, then_branch: Branch, else_branch: Branch) {
let mut merge_store = |address, then_case, else_case, old_value| {
let then_condition = then_branch.condition;
let else_condition = else_branch.condition;
// Address -> (then_value, else_value, value_before_the_if)
let mut new_map = HashMap::with_capacity(then_branch.store_values.len());

for (address, store) in then_branch.store_values {
new_map.insert(address, (store.new_value, store.old_value, store.old_value));
}

for (address, store) in else_branch.store_values {
if let Some(entry) = new_map.get_mut(&address) {
entry.1 = store.new_value;
} else {
new_map.insert(address, (store.old_value, store.new_value, store.old_value));
}
}

let then_condition = then_branch.condition;
let else_condition = else_branch.condition;

for (address, (then_case, else_case, old_value)) in new_map {
let value = self.merge_values(then_condition, else_condition, then_case, else_case);
self.insert_instruction_with_typevars(Instruction::Store { address, value }, None);

Expand All @@ -583,14 +603,6 @@ impl<'f> Context<'f> {
} else {
self.store_values.insert(address, Store { old_value, new_value: value });
}
};

for (address, store) in then_branch.store_values {
merge_store(address, store.new_value, store.old_value, store.old_value);
}

for (address, store) in else_branch.store_values {
merge_store(address, store.old_value, store.new_value, store.old_value);
}
}

Expand Down Expand Up @@ -676,6 +688,14 @@ impl<'f> Context<'f> {
instruction
}
}

fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) {
for (address, store) in &then_branch.store_values {
let address = *address;
let value = store.old_value;
self.insert_instruction_with_typevars(Instruction::Store { address, value }, None);
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -836,36 +856,35 @@ mod test {
// v4 = load v1
// store Field 5 at v1
// v5 = not v0
// enable_side_effects v5
// store v4 at v1
// enable_side_effects u1 1
// v7 = mul v0, Field 5
// v8 = mul v5, v4
// v9 = add v7, v8
// store v9 at v1
// v6 = cast v0 as Field
// v7 = cast v5 as Field
// v8 = mul v6, Field 5
// v9 = mul v7, v4
// v10 = add v8, v9
// store v10 at v1
// return
// }
let ssa = ssa.flatten_cfg();
let main = ssa.main();

assert_eq!(main.reachable_blocks().len(), 1);

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 2);
assert_eq!(store_count, 3);
}

// Currently failing since the offsets create additions with different ValueIds which are
// treated wrongly as different addresses.
#[test]
fn merge_stores_with_offsets() {
fn merge_stores_with_else_block() {
// fn main f0 {
// b0(v0: u1, v1: ref):
// jmpif v0, then: b1, else: b2
// b1():
// v2 = add v1, 1
// store v2, Field 5
// store Field 5 in v1
// jmp b3()
// b2():
// v3 = add v1, 1
// store v3, Field 6
// store Field 6 in v1
// jmp b3()
// b3():
// return
Expand All @@ -883,16 +902,13 @@ mod test {
builder.terminate_with_jmpif(v0, b1, b2);

builder.switch_to_block(b1);
let one = builder.field_constant(1u128);
let v2 = builder.insert_binary(v1, BinaryOp::Add, one);
let five = builder.field_constant(5u128);
builder.insert_store(v2, five);
builder.insert_store(v1, five);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b2);
let v3 = builder.insert_binary(v1, BinaryOp::Add, one);
let six = builder.field_constant(6u128);
builder.insert_store(v3, six);
builder.insert_store(v1, six);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b3);
Expand All @@ -904,27 +920,25 @@ mod test {
// fn main f0 {
// b0(v0: u1, v1: reference):
// enable_side_effects v0
// v7 = add v1, Field 1
// v8 = load v7
// store Field 5 at v7
// v9 = not v0
// enable_side_effects v9
// v11 = add v1, Field 1
// v12 = load v11
// store Field 6 at v11
// enable_side_effects Field 1
// v13 = mul v0, Field 5
// v14 = mul v9, v8
// v15 = add v13, v14
// store v15 at v7
// v16 = mul v0, v12
// v17 = mul v9, Field 6
// v18 = add v16, v17
// store v18 at v11
// v5 = load v1
// store Field 5 at v1
// v6 = not v0
// store v5 at v1
// enable_side_effects v6
// v8 = load v1
// store Field 6 at v1
// enable_side_effects u1 1
// v9 = cast v0 as Field
// v10 = cast v6 as Field
// v11 = mul v9, Field 5
// v12 = mul v10, Field 6
// v13 = add v11, v12
// store v13 at v1
// return
// }
let ssa = ssa.flatten_cfg();
let main = ssa.main();
println!("{ssa}");
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
Expand Down Expand Up @@ -1333,4 +1347,125 @@ mod test {
}
assert_eq!(constrain_count, 1);
}

#[test]
fn undo_stores() {
// Regression test for #1826. Ensures the `else` branch does not see the stores of the
// `then` branch.
//
// fn main f1 {
// b0():
// v0 = allocate
// store Field 0 at v0
// v2 = allocate
// store Field 2 at v2
// v4 = load v2
// v5 = lt v4, Field 2
// jmpif v5 then: b1, else: b2
// b1():
// v24 = load v0
// v25 = load v2
// v26 = mul v25, Field 10
// v27 = add v24, v26
// store v27 at v0
// v28 = load v2
// v29 = add v28, Field 1
// store v29 at v2
// jmp b5()
// b5():
// v14 = load v0
// return v14
// b2():
// v6 = load v2
// v8 = lt v6, Field 4
// jmpif v8 then: b3, else: b4
// b3():
// v16 = load v0
// v17 = load v2
// v19 = mul v17, Field 100
// v20 = add v16, v19
// store v20 at v0
// v21 = load v2
// v23 = add v21, Field 1
// store v23 at v2
// jmp b4()
// b4():
// jmp b5()
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let b1 = builder.insert_block();
let b2 = builder.insert_block();
let b3 = builder.insert_block();
let b4 = builder.insert_block();
let b5 = builder.insert_block();

let zero = builder.field_constant(0u128);
let one = builder.field_constant(1u128);
let two = builder.field_constant(2u128);
let four = builder.field_constant(4u128);
let ten = builder.field_constant(10u128);
let one_hundred = builder.field_constant(100u128);

let v0 = builder.insert_allocate();
builder.insert_store(v0, zero);
let v2 = builder.insert_allocate();
builder.insert_store(v2, two);
let v4 = builder.insert_load(v2, Type::field());
let v5 = builder.insert_binary(v4, BinaryOp::Lt, two);
builder.terminate_with_jmpif(v5, b1, b2);

builder.switch_to_block(b1);
let v24 = builder.insert_load(v0, Type::field());
let v25 = builder.insert_load(v2, Type::field());
let v26 = builder.insert_binary(v25, BinaryOp::Mul, ten);
let v27 = builder.insert_binary(v24, BinaryOp::Add, v26);
builder.insert_store(v0, v27);
let v28 = builder.insert_load(v2, Type::field());
let v29 = builder.insert_binary(v28, BinaryOp::Add, one);
builder.insert_store(v2, v29);
builder.terminate_with_jmp(b5, vec![]);

builder.switch_to_block(b5);
let v14 = builder.insert_load(v0, Type::field());
builder.terminate_with_return(vec![v14]);

builder.switch_to_block(b2);
let v6 = builder.insert_load(v2, Type::field());
let v8 = builder.insert_binary(v6, BinaryOp::Lt, four);
builder.terminate_with_jmpif(v8, b3, b4);

builder.switch_to_block(b3);
let v16 = builder.insert_load(v0, Type::field());
let v17 = builder.insert_load(v2, Type::field());
let v19 = builder.insert_binary(v17, BinaryOp::Mul, one_hundred);
let v20 = builder.insert_binary(v16, BinaryOp::Add, v19);
builder.insert_store(v0, v20);
let v21 = builder.insert_load(v2, Type::field());
let v23 = builder.insert_binary(v21, BinaryOp::Add, one);
builder.insert_store(v2, v23);
builder.terminate_with_jmp(b4, vec![]);

builder.switch_to_block(b4);
builder.terminate_with_jmp(b5, vec![]);

let ssa = builder.finish().flatten_cfg().mem2reg().fold_constants();

let main = ssa.main();

// The return value should be 200, not 310
match main.dfg[main.entry_block()].terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
match main.dfg.get_numeric_constant(return_values[0]) {
Some(constant) => {
let value = constant.to_u128();
assert_eq!(value, 200);
}
None => unreachable!("Expected constant 200 for return value"),
}
}
_ => unreachable!(),
}
}
}

0 comments on commit e068fd4

Please sign in to comment.