diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index 02737f5645b..ae55f85d897 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -113,8 +113,12 @@ impl Context { // We track per block whether an IncrementRc instruction has a paired DecrementRc instruction // with the same value but no array set in between. // If we see an inc/dec RC pair within a block we can safely remove both instructions. - let mut inc_rcs: HashMap> = HashMap::default(); - let mut inc_rcs_to_remove = HashSet::default(); + let mut rcs_with_possible_pairs: HashMap> = HashMap::default(); + let mut rc_pairs_to_remove = HashSet::default(); + // We also separately track all IncrementRc instructions and all arrays which have been mutably borrowed. + // If an array has not been mutably borrowed we can then safely remove all IncrementRc instructions on that array. + let mut inc_rcs: HashMap> = HashMap::default(); + let mut borrowed_arrays: HashSet = HashSet::default(); // Indexes of instructions that might be out of bounds. // We'll remove those, but before that we'll insert bounds checks for them. @@ -146,14 +150,29 @@ impl Context { self.track_inc_rcs_to_remove( *instruction_id, function, + &mut rcs_with_possible_pairs, + &mut rc_pairs_to_remove, &mut inc_rcs, - &mut inc_rcs_to_remove, + &mut borrowed_arrays, ); } - for id in inc_rcs_to_remove { - self.instructions_to_remove.insert(id); - } + let non_mutated_arrays = + inc_rcs + .keys() + .filter_map(|value| { + if !borrowed_arrays.contains(value) { + Some(&inc_rcs[value]) + } else { + None + } + }) + .flatten() + .copied() + .collect::>(); + + self.instructions_to_remove.extend(non_mutated_arrays); + self.instructions_to_remove.extend(rc_pairs_to_remove); // If there are some instructions that might trigger an out of bounds error, // first add constrain checks. Then run the DIE pass again, which will remove those @@ -181,36 +200,50 @@ impl Context { &self, instruction_id: InstructionId, function: &Function, - inc_rcs: &mut HashMap>, + rcs_with_possible_pairs: &mut HashMap>, inc_rcs_to_remove: &mut HashSet, + inc_rcs: &mut HashMap>, + borrowed_arrays: &mut HashSet, ) { let instruction = &function.dfg[instruction_id]; // DIE loops over a block in reverse order, so we insert an RC instruction for possible removal // when we see a DecrementRc and check whether it was possibly mutated when we see an IncrementRc. match instruction { Instruction::IncrementRc { value } => { - if let Some(inc_rc) = pop_rc_for(*value, function, inc_rcs) { + if let Some(inc_rc) = pop_rc_for(*value, function, rcs_with_possible_pairs) { if !inc_rc.possibly_mutated { inc_rcs_to_remove.insert(inc_rc.id); inc_rcs_to_remove.insert(instruction_id); } } + + inc_rcs.entry(*value).or_default().insert(instruction_id); } Instruction::DecrementRc { value } => { let typ = function.dfg.type_of_value(*value); // We assume arrays aren't mutated until we find an array_set - let inc_rc = + let dec_rc = RcInstruction { id: instruction_id, array: *value, possibly_mutated: false }; - inc_rcs.entry(typ).or_default().push(inc_rc); + rcs_with_possible_pairs.entry(typ).or_default().push(dec_rc); } Instruction::ArraySet { array, .. } => { let typ = function.dfg.type_of_value(*array); - if let Some(inc_rcs) = inc_rcs.get_mut(&typ) { - for inc_rc in inc_rcs { - inc_rc.possibly_mutated = true; + if let Some(dec_rcs) = rcs_with_possible_pairs.get_mut(&typ) { + for dec_rc in dec_rcs { + dec_rc.possibly_mutated = true; } } + + borrowed_arrays.insert(*array); + } + Instruction::Store { value, .. } => { + // We are very conservative and say that any store of an array value means it has the potential + // to be mutated. This is done due to the tracking of mutable borrows still being per block. + let typ = function.dfg.type_of_value(*value); + if matches!(&typ, Type::Array(..) | Type::Slice(..)) { + borrowed_arrays.insert(*value); + } } _ => {} } @@ -572,6 +605,8 @@ fn apply_side_effects( mod test { use std::sync::Arc; + use im::vector; + use crate::ssa::{ function_builder::FunctionBuilder, ir::{ @@ -779,4 +814,143 @@ mod test { assert_eq!(main.dfg[main.entry_block()].instructions().len(), 3); } + + #[test] + fn keep_inc_rc_on_borrowed_array_store() { + // acir(inline) fn main f0 { + // b0(): + // v2 = allocate + // inc_rc [u32 0, u32 0] + // store [u32 0, u32 0] at v2 + // inc_rc [u32 0, u32 0] + // jmp b1() + // b1(): + // v3 = load v2 + // v5 = array_set v3, index u32 0, value u32 1 + // return v5 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let zero = builder.numeric_constant(0u128, Type::unsigned(32)); + let array_type = Type::Array(Arc::new(vec![Type::unsigned(32)]), 2); + let array = builder.array_constant(vector![zero, zero], array_type.clone()); + let v2 = builder.insert_allocate(array_type.clone()); + builder.increment_array_reference_count(array); + builder.insert_store(v2, array); + builder.increment_array_reference_count(array); + + let b1 = builder.insert_block(); + builder.terminate_with_jmp(b1, vec![]); + builder.switch_to_block(b1); + + let v3 = builder.insert_load(v2, array_type); + let one = builder.numeric_constant(1u128, Type::unsigned(32)); + let v5 = builder.insert_array_set(v3, zero, one); + builder.terminate_with_return(vec![v5]); + + let ssa = builder.finish(); + let main = ssa.main(); + + // The instruction count never includes the terminator instruction + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 4); + assert_eq!(main.dfg[b1].instructions().len(), 2); + + // We expect the output to be unchanged + let ssa = ssa.dead_instruction_elimination(); + let main = ssa.main(); + + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 4); + assert_eq!(main.dfg[b1].instructions().len(), 2); + } + + #[test] + fn keep_inc_rc_on_borrowed_array_set() { + // acir(inline) fn main f0 { + // b0(v0: [u32; 2]): + // inc_rc v0 + // v3 = array_set v0, index u32 0, value u32 1 + // inc_rc v0 + // inc_rc v0 + // inc_rc v0 + // v4 = array_get v3, index u32 1 + // return v4 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let array_type = Type::Array(Arc::new(vec![Type::unsigned(32)]), 2); + let v0 = builder.add_parameter(array_type.clone()); + builder.increment_array_reference_count(v0); + let zero = builder.numeric_constant(0u128, Type::unsigned(32)); + let one = builder.numeric_constant(1u128, Type::unsigned(32)); + let v3 = builder.insert_array_set(v0, zero, one); + builder.increment_array_reference_count(v0); + builder.increment_array_reference_count(v0); + builder.increment_array_reference_count(v0); + + let v4 = builder.insert_array_get(v3, one, Type::unsigned(32)); + + builder.terminate_with_return(vec![v4]); + + let ssa = builder.finish(); + let main = ssa.main(); + + // The instruction count never includes the terminator instruction + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 6); + + // We expect the output to be unchanged + let ssa = ssa.dead_instruction_elimination(); + let main = ssa.main(); + + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 6); + } + + #[test] + fn remove_inc_rcs_that_are_never_mutably_borrowed() { + // acir(inline) fn main f0 { + // b0(v0: [Field; 2]): + // inc_rc v0 + // inc_rc v0 + // inc_rc v0 + // v2 = array_get v0, index u32 0 + // inc_rc v0 + // return v2 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let v0 = builder.add_parameter(Type::Array(Arc::new(vec![Type::field()]), 2)); + builder.increment_array_reference_count(v0); + builder.increment_array_reference_count(v0); + builder.increment_array_reference_count(v0); + + let zero = builder.numeric_constant(0u128, Type::unsigned(32)); + let v2 = builder.insert_array_get(v0, zero, Type::field()); + builder.increment_array_reference_count(v0); + builder.terminate_with_return(vec![v2]); + + let ssa = builder.finish(); + let main = ssa.main(); + + // The instruction count never includes the terminator instruction + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 5); + + // Expected output: + // + // acir(inline) fn main f0 { + // b0(v0: [Field; 2]): + // v2 = array_get v0, index u32 0 + // return v2 + // } + let ssa = ssa.dead_instruction_elimination(); + let main = ssa.main(); + + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 1); + assert!(matches!(&main.dfg[instructions[0]], Instruction::ArrayGet { .. })); + } }