diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs index 2931aee1dfd..9c8b3c6ee91 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs @@ -266,6 +266,10 @@ impl<'f> Context<'f> { let else_condition = self.insert_instruction(Instruction::Not(then_condition)); let zero = FieldElement::zero(); + // Make sure the else branch sees the previous values of each store + // rather than any values created in the 'then' branch. + self.undo_stores_in_then_branch(&then_branch); + let else_branch = self.inline_branch(block, else_block, old_condition, else_condition, zero); @@ -572,9 +576,25 @@ impl<'f> Context<'f> { /// instruction. If this ordering is changed, the ordering that store values are merged within /// this function also needs to be changed to reflect that. fn merge_stores(&mut self, then_branch: Branch, else_branch: Branch) { - let mut merge_store = |address, then_case, else_case, old_value| { - let then_condition = then_branch.condition; - let else_condition = else_branch.condition; + // Address -> (then_value, else_value, value_before_the_if) + let mut new_map = HashMap::with_capacity(then_branch.store_values.len()); + + for (address, store) in then_branch.store_values { + new_map.insert(address, (store.new_value, store.old_value, store.old_value)); + } + + for (address, store) in else_branch.store_values { + if let Some(entry) = new_map.get_mut(&address) { + entry.1 = store.new_value; + } else { + new_map.insert(address, (store.old_value, store.new_value, store.old_value)); + } + } + + let then_condition = then_branch.condition; + let else_condition = else_branch.condition; + + for (address, (then_case, else_case, old_value)) in new_map { let value = self.merge_values(then_condition, else_condition, then_case, else_case); self.insert_instruction_with_typevars(Instruction::Store { address, value }, None); @@ -583,14 +603,6 @@ impl<'f> Context<'f> { } else { self.store_values.insert(address, Store { old_value, new_value: value }); } - }; - - for (address, store) in then_branch.store_values { - merge_store(address, store.new_value, store.old_value, store.old_value); - } - - for (address, store) in else_branch.store_values { - merge_store(address, store.old_value, store.new_value, store.old_value); } } @@ -676,6 +688,14 @@ impl<'f> Context<'f> { instruction } } + + fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) { + for (address, store) in &then_branch.store_values { + let address = *address; + let value = store.old_value; + self.insert_instruction_with_typevars(Instruction::Store { address, value }, None); + } + } } #[cfg(test)] @@ -836,36 +856,35 @@ mod test { // v4 = load v1 // store Field 5 at v1 // v5 = not v0 - // enable_side_effects v5 + // store v4 at v1 // enable_side_effects u1 1 - // v7 = mul v0, Field 5 - // v8 = mul v5, v4 - // v9 = add v7, v8 - // store v9 at v1 + // v6 = cast v0 as Field + // v7 = cast v5 as Field + // v8 = mul v6, Field 5 + // v9 = mul v7, v4 + // v10 = add v8, v9 + // store v10 at v1 // return // } let ssa = ssa.flatten_cfg(); let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 1); let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. })); - assert_eq!(store_count, 2); + assert_eq!(store_count, 3); } - // Currently failing since the offsets create additions with different ValueIds which are - // treated wrongly as different addresses. #[test] - fn merge_stores_with_offsets() { + fn merge_stores_with_else_block() { // fn main f0 { // b0(v0: u1, v1: ref): // jmpif v0, then: b1, else: b2 // b1(): - // v2 = add v1, 1 - // store v2, Field 5 + // store Field 5 in v1 // jmp b3() // b2(): - // v3 = add v1, 1 - // store v3, Field 6 + // store Field 6 in v1 // jmp b3() // b3(): // return @@ -883,16 +902,13 @@ mod test { builder.terminate_with_jmpif(v0, b1, b2); builder.switch_to_block(b1); - let one = builder.field_constant(1u128); - let v2 = builder.insert_binary(v1, BinaryOp::Add, one); let five = builder.field_constant(5u128); - builder.insert_store(v2, five); + builder.insert_store(v1, five); builder.terminate_with_jmp(b3, vec![]); builder.switch_to_block(b2); - let v3 = builder.insert_binary(v1, BinaryOp::Add, one); let six = builder.field_constant(6u128); - builder.insert_store(v3, six); + builder.insert_store(v1, six); builder.terminate_with_jmp(b3, vec![]); builder.switch_to_block(b3); @@ -904,27 +920,25 @@ mod test { // fn main f0 { // b0(v0: u1, v1: reference): // enable_side_effects v0 - // v7 = add v1, Field 1 - // v8 = load v7 - // store Field 5 at v7 - // v9 = not v0 - // enable_side_effects v9 - // v11 = add v1, Field 1 - // v12 = load v11 - // store Field 6 at v11 - // enable_side_effects Field 1 - // v13 = mul v0, Field 5 - // v14 = mul v9, v8 - // v15 = add v13, v14 - // store v15 at v7 - // v16 = mul v0, v12 - // v17 = mul v9, Field 6 - // v18 = add v16, v17 - // store v18 at v11 + // v5 = load v1 + // store Field 5 at v1 + // v6 = not v0 + // store v5 at v1 + // enable_side_effects v6 + // v8 = load v1 + // store Field 6 at v1 + // enable_side_effects u1 1 + // v9 = cast v0 as Field + // v10 = cast v6 as Field + // v11 = mul v9, Field 5 + // v12 = mul v10, Field 6 + // v13 = add v11, v12 + // store v13 at v1 // return // } let ssa = ssa.flatten_cfg(); let main = ssa.main(); + println!("{ssa}"); assert_eq!(main.reachable_blocks().len(), 1); let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. })); @@ -1333,4 +1347,125 @@ mod test { } assert_eq!(constrain_count, 1); } + + #[test] + fn undo_stores() { + // Regression test for #1826. Ensures the `else` branch does not see the stores of the + // `then` branch. + // + // fn main f1 { + // b0(): + // v0 = allocate + // store Field 0 at v0 + // v2 = allocate + // store Field 2 at v2 + // v4 = load v2 + // v5 = lt v4, Field 2 + // jmpif v5 then: b1, else: b2 + // b1(): + // v24 = load v0 + // v25 = load v2 + // v26 = mul v25, Field 10 + // v27 = add v24, v26 + // store v27 at v0 + // v28 = load v2 + // v29 = add v28, Field 1 + // store v29 at v2 + // jmp b5() + // b5(): + // v14 = load v0 + // return v14 + // b2(): + // v6 = load v2 + // v8 = lt v6, Field 4 + // jmpif v8 then: b3, else: b4 + // b3(): + // v16 = load v0 + // v17 = load v2 + // v19 = mul v17, Field 100 + // v20 = add v16, v19 + // store v20 at v0 + // v21 = load v2 + // v23 = add v21, Field 1 + // store v23 at v2 + // jmp b4() + // b4(): + // jmp b5() + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + let b4 = builder.insert_block(); + let b5 = builder.insert_block(); + + let zero = builder.field_constant(0u128); + let one = builder.field_constant(1u128); + let two = builder.field_constant(2u128); + let four = builder.field_constant(4u128); + let ten = builder.field_constant(10u128); + let one_hundred = builder.field_constant(100u128); + + let v0 = builder.insert_allocate(); + builder.insert_store(v0, zero); + let v2 = builder.insert_allocate(); + builder.insert_store(v2, two); + let v4 = builder.insert_load(v2, Type::field()); + let v5 = builder.insert_binary(v4, BinaryOp::Lt, two); + builder.terminate_with_jmpif(v5, b1, b2); + + builder.switch_to_block(b1); + let v24 = builder.insert_load(v0, Type::field()); + let v25 = builder.insert_load(v2, Type::field()); + let v26 = builder.insert_binary(v25, BinaryOp::Mul, ten); + let v27 = builder.insert_binary(v24, BinaryOp::Add, v26); + builder.insert_store(v0, v27); + let v28 = builder.insert_load(v2, Type::field()); + let v29 = builder.insert_binary(v28, BinaryOp::Add, one); + builder.insert_store(v2, v29); + builder.terminate_with_jmp(b5, vec![]); + + builder.switch_to_block(b5); + let v14 = builder.insert_load(v0, Type::field()); + builder.terminate_with_return(vec![v14]); + + builder.switch_to_block(b2); + let v6 = builder.insert_load(v2, Type::field()); + let v8 = builder.insert_binary(v6, BinaryOp::Lt, four); + builder.terminate_with_jmpif(v8, b3, b4); + + builder.switch_to_block(b3); + let v16 = builder.insert_load(v0, Type::field()); + let v17 = builder.insert_load(v2, Type::field()); + let v19 = builder.insert_binary(v17, BinaryOp::Mul, one_hundred); + let v20 = builder.insert_binary(v16, BinaryOp::Add, v19); + builder.insert_store(v0, v20); + let v21 = builder.insert_load(v2, Type::field()); + let v23 = builder.insert_binary(v21, BinaryOp::Add, one); + builder.insert_store(v2, v23); + builder.terminate_with_jmp(b4, vec![]); + + builder.switch_to_block(b4); + builder.terminate_with_jmp(b5, vec![]); + + let ssa = builder.finish().flatten_cfg().mem2reg().fold_constants(); + + let main = ssa.main(); + + // The return value should be 200, not 310 + match main.dfg[main.entry_block()].terminator() { + Some(TerminatorInstruction::Return { return_values }) => { + match main.dfg.get_numeric_constant(return_values[0]) { + Some(constant) => { + let value = constant.to_u128(); + assert_eq!(value, 200); + } + None => unreachable!("Expected constant 200 for return value"), + } + } + _ => unreachable!(), + } + } }