diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs index 9bbe59d1b1e..2931aee1dfd 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs @@ -269,7 +269,11 @@ impl<'f> Context<'f> { let else_branch = self.inline_branch(block, else_block, old_condition, else_condition, zero); + // We must remember to reset whether side effects are enabled when both branches + // end, in addition to resetting the value of old_condition since it is set to + // known to be true/false within the then/else branch respectively. self.insert_current_side_effects_enabled(); + self.inserter.map_value(old_condition, old_condition); // While there is a condition on the stack we don't compile outside the condition // until it is popped. This ensures we inline the full then and else branches @@ -494,10 +498,16 @@ impl<'f> Context<'f> { let old_stores = std::mem::take(&mut self.store_values); let old_allocations = std::mem::take(&mut self.local_allocations); - // Remember the old condition value is now known to be true/false within this branch - let known_value = - self.inserter.function.dfg.make_constant(condition_value, Type::bool()); - self.inserter.map_value(old_condition, known_value); + // Optimization: within the then branch we know the condition to be true, so replace + // any references of it within this branch with true. Likewise, do the same with false + // with the else branch. We must be careful not to replace the condition if it is a + // known constant, otherwise we can end up setting 1 = 0 or vice-versa. + if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() { + let known_value = + self.inserter.function.dfg.make_constant(condition_value, Type::bool()); + + self.inserter.map_value(old_condition, known_value); + } let final_block = self.inline_block(destination, &[]); @@ -670,11 +680,12 @@ impl<'f> Context<'f> { #[cfg(test)] mod test { + use std::rc::Rc; use crate::ssa_refactor::{ ir::{ dfg::DataFlowGraph, - function::RuntimeType, + function::{Function, RuntimeType}, instruction::{BinaryOp, Instruction, Intrinsic, TerminatorInstruction}, map::Id, types::Type, @@ -837,12 +848,7 @@ mod test { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 1); - let store_count = main.dfg[main.entry_block()] - .instructions() - .iter() - .filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. })) - .count(); - + let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. })); assert_eq!(store_count, 2); } @@ -921,13 +927,16 @@ mod test { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 1); - let store_count = main.dfg[main.entry_block()] + let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. })); + assert_eq!(store_count, 4); + } + + fn count_instruction(function: &Function, f: impl Fn(&Instruction) -> bool) -> usize { + function.dfg[function.entry_block()] .instructions() .iter() - .filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. })) - .count(); - - assert_eq!(store_count, 4); + .filter(|id| f(&function.dfg[**id])) + .count() } #[test] @@ -1196,4 +1205,132 @@ mod test { _ => Vec::new(), } } + + #[test] + fn should_not_merge_away_constraints() { + // Very simplified derived regression test for #1792 + // Tests that it does not simplify to a true constraint an always-false constraint + // The original function is replaced by the following: + // fn main f1 { + // b0(): + // jmpif u1 0 then: b1, else: b2 + // b1(): + // jmp b2() + // b2(): + // constrain u1 0 // was incorrectly removed + // return + // } + let main_id = Id::test_new(1); + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + + builder.insert_block(); // entry + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let v_false = builder.numeric_constant(0_u128, Type::bool()); + builder.terminate_with_jmpif(v_false, b1, b2); + + builder.switch_to_block(b1); + builder.terminate_with_jmp(b2, vec![]); + + builder.switch_to_block(b2); + builder.insert_constrain(v_false); // should not be removed + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().flatten_cfg(); + let main = ssa.main(); + + // Assert we have not incorrectly removed a constraint: + use Instruction::Constrain; + let constrain_count = count_instruction(main, |ins| matches!(ins, Constrain(_))); + assert_eq!(constrain_count, 1); + } + + #[test] + fn should_not_merge_incorrectly_to_false() { + // Regression test for #1792 + // Tests that it does not simplify a true constraint an always-false constraint + // fn main f1 { + // b0(): + // v4 = call pedersen([Field 0], u32 0) + // v5 = array_get v4, index Field 0 + // v6 = cast v5 as u32 + // v8 = mod v6, u32 2 + // v9 = cast v8 as u1 + // v10 = allocate + // store Field 0 at v10 + // jmpif v9 then: b1, else: b2 + // b1(): + // v14 = add v5, Field 1 + // store v14 at v10 + // jmp b3() + // b3(): + // v12 = eq v9, u1 1 + // constrain v12 + // return + // b2(): + // store Field 0 at v10 + // jmp b3() + // } + let main_id = Id::test_new(1); + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + + builder.insert_block(); // b0 + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + + let element_type = Rc::new(vec![Type::field()]); + let zero = builder.field_constant(0_u128); + let zero_array = builder.array_constant(im::Vector::unit(zero), element_type.clone()); + let i_zero = builder.numeric_constant(0_u128, Type::unsigned(32)); + let pedersen = + builder.import_intrinsic_id(Intrinsic::BlackBox(acvm::acir::BlackBoxFunc::Pedersen)); + let v4 = builder.insert_call( + pedersen, + vec![zero_array, i_zero], + vec![Type::Array(element_type, 2)], + )[0]; + let v5 = builder.insert_array_get(v4, zero, Type::field()); + let v6 = builder.insert_cast(v5, Type::unsigned(32)); + let i_two = builder.numeric_constant(2_u128, Type::unsigned(32)); + let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two); + let v9 = builder.insert_cast(v8, Type::bool()); + + let v10 = builder.insert_allocate(); + builder.insert_store(v10, zero); + + builder.terminate_with_jmpif(v9, b1, b2); + + builder.switch_to_block(b1); + let one = builder.field_constant(1_u128); + let v14 = builder.insert_binary(v5, BinaryOp::Add, one); + builder.insert_store(v10, v14); + builder.terminate_with_jmp(b3, vec![]); + + builder.switch_to_block(b2); + builder.insert_store(v10, zero); + builder.terminate_with_jmp(b3, vec![]); + + builder.switch_to_block(b3); + let b_true = builder.numeric_constant(1_u128, Type::unsigned(1)); + let v12 = builder.insert_binary(v9, BinaryOp::Eq, b_true); + builder.insert_constrain(v12); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().flatten_cfg(); + let main = ssa.main(); + + // Now assert that there is not an always-false constraint after flattening: + let mut constrain_count = 0; + for instruction in main.dfg[main.entry_block()].instructions() { + if let Instruction::Constrain(value) = main.dfg[*instruction] { + if let Some(constant) = main.dfg.get_numeric_constant(value) { + assert!(constant.is_one()); + } + constrain_count += 1; + } + } + assert_eq!(constrain_count, 1); + } }