From 467948f9ee9ae65b4e2badaa1d15835fced3e835 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:38:37 +0000 Subject: [PATCH] fix: prevent `Instruction::Constrain`s for non-primitive types (#3916) # Description ## Problem\* Followup to #3740 ## Summary\* #3740 fixed an issue where array equalities were making their way into SSA and not having side effect predicates applied correctly by applying the predicate to each of the array elements. We actually state that we do not want array equalities in SSA (in `insert_array_equality`) so the fundamental issue of array equalities in SSA still exists. We were allowing an implicit array equality to sneak into SSA by performing an optimization of `Constrain(Eq(x, y), 1)` into `Constrain(x, y)` during codegen. This meant that if `x` and `y` were arrays then we bypass the `insert_array_equality` function which the `Eq` instruction would call, which would have calculated a primitive predicate value for the constrain statement to act on. This PR removes the extra logic from the `flatten_cfg` pass (while adding an assert that we're only constraining primitive values) and removes the faulty optimization from SSA codegen. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/ssa/opt/flatten_cfg.rs | 105 ++++-------------- .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 24 +--- 2 files changed, 24 insertions(+), 105 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 0e7bfff7b6b..fdd7c66684c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -641,8 +641,25 @@ impl<'f> Context<'f> { match instruction { Instruction::Constrain(lhs, rhs, message) => { // Replace constraint `lhs == rhs` with `condition * lhs == condition * rhs`. - let lhs = self.handle_constrain_arg_side_effects(lhs, condition, &call_stack); - let rhs = self.handle_constrain_arg_side_effects(rhs, condition, &call_stack); + + // Condition needs to be cast to argument type in order to multiply them together. + let argument_type = self.inserter.function.dfg.type_of_value(lhs); + // Sanity check that we're not constraining non-primitive types + assert!(matches!(argument_type, Type::Numeric(_))); + + let casted_condition = self.insert_instruction( + Instruction::Cast(condition, argument_type), + call_stack.clone(), + ); + + let lhs = self.insert_instruction( + Instruction::binary(BinaryOp::Mul, lhs, casted_condition), + call_stack.clone(), + ); + let rhs = self.insert_instruction( + Instruction::binary(BinaryOp::Mul, rhs, casted_condition), + call_stack, + ); Instruction::Constrain(lhs, rhs, message) } @@ -673,90 +690,6 @@ impl<'f> Context<'f> { } } - /// Given the arguments of a constrain instruction, multiplying them by the branch's condition - /// requires special handling in the case of complex types. - fn handle_constrain_arg_side_effects( - &mut self, - argument: ValueId, - condition: ValueId, - call_stack: &CallStack, - ) -> ValueId { - let argument_type = self.inserter.function.dfg.type_of_value(argument); - - match &argument_type { - Type::Numeric(_) => { - // Condition needs to be cast to argument type in order to multiply them together. - let casted_condition = self.insert_instruction( - Instruction::Cast(condition, argument_type), - call_stack.clone(), - ); - - self.insert_instruction( - Instruction::binary(BinaryOp::Mul, argument, casted_condition), - call_stack.clone(), - ) - } - Type::Array(_, _) => { - self.handle_array_constrain_arg(argument_type, argument, condition, call_stack) - } - Type::Slice(_) => { - panic!("Cannot use slices directly in a constrain statement") - } - Type::Reference(_) => { - panic!("Cannot use references directly in a constrain statement") - } - Type::Function => { - panic!("Cannot use functions directly in a constrain statement") - } - } - } - - fn handle_array_constrain_arg( - &mut self, - typ: Type, - argument: ValueId, - condition: ValueId, - call_stack: &CallStack, - ) -> ValueId { - let mut new_array = im::Vector::new(); - - let (element_types, len) = match &typ { - Type::Array(elements, len) => (elements, *len), - _ => panic!("Expected array type"), - }; - - for i in 0..len { - for (element_index, element_type) in element_types.iter().enumerate() { - let index = ((i * element_types.len() + element_index) as u128).into(); - let index = self.inserter.function.dfg.make_constant(index, Type::field()); - - let typevars = Some(vec![element_type.clone()]); - - let mut get_element = |array, typevars| { - let get = Instruction::ArrayGet { array, index }; - self.inserter - .function - .dfg - .insert_instruction_and_results( - get, - self.inserter.function.entry_block(), - typevars, - CallStack::new(), - ) - .first() - }; - - let element = get_element(argument, typevars); - - new_array.push_back( - self.handle_constrain_arg_side_effects(element, condition, call_stack), - ); - } - } - - self.inserter.function.dfg.make_array(new_array, typ) - } - fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) { for (address, store) in &then_branch.store_values { let address = *address; diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index d7e6b8b0a3d..c00fbbbcb40 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -8,8 +8,8 @@ use context::SharedContext; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use noirc_frontend::{ - monomorphization::ast::{self, Binary, Expression, Program}, - BinaryOpKind, Visibility, + monomorphization::ast::{self, Expression, Program}, + Visibility, }; use crate::{ @@ -653,24 +653,10 @@ impl<'a> FunctionContext<'a> { location: Location, assert_message: Option, ) -> Result { - match expr { - // If we're constraining an equality to be true then constrain the two sides directly. - Expression::Binary(Binary { lhs, operator: BinaryOpKind::Equal, rhs, .. }) => { - let lhs = self.codegen_non_tuple_expression(lhs)?; - let rhs = self.codegen_non_tuple_expression(rhs)?; - self.builder.set_location(location).insert_constrain(lhs, rhs, assert_message); - } + let expr = self.codegen_non_tuple_expression(expr)?; + let true_literal = self.builder.numeric_constant(true, Type::bool()); + self.builder.set_location(location).insert_constrain(expr, true_literal, assert_message); - _ => { - let expr = self.codegen_non_tuple_expression(expr)?; - let true_literal = self.builder.numeric_constant(true, Type::bool()); - self.builder.set_location(location).insert_constrain( - expr, - true_literal, - assert_message, - ); - } - } Ok(Self::unit_value()) }