diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index c8f6d201d86..ba0a3e9063b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -29,7 +29,7 @@ use crate::ssa::{ basic_block::BasicBlockId, dfg::{DataFlowGraph, InsertInstructionResult}, function::Function, - instruction::{Instruction, InstructionId}, + instruction::{Binary, BinaryOp, Instruction, InstructionId}, types::Type, value::{Value, ValueId}, }, @@ -192,8 +192,32 @@ impl Context { } // Resolve any inputs to ensure that we're comparing like-for-like instructions. + let instruction = instruction + .map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id)); + + // Sort operands of commutative instructions + if let Instruction::Binary(binary) = &instruction { + match binary.operator { + BinaryOp::Add + | BinaryOp::Mul + | BinaryOp::Eq + | BinaryOp::And + | BinaryOp::Or + | BinaryOp::Xor => { + let (lhs, rhs) = if binary.lhs < binary.rhs { + (binary.lhs, binary.rhs) + } else { + (binary.rhs, binary.lhs) + }; + let commutative_instruction = + Instruction::Binary(Binary { lhs, rhs, operator: binary.operator }); + return commutative_instruction; + } + _ => {} + } + } + instruction - .map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id)) } /// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations @@ -264,8 +288,6 @@ impl Context { } } - // If the instruction doesn't have side-effects and if it won't interact with enable_side_effects during acir_gen, - // we cache the results so we can reuse them if the same instruction appears again later in the block. if instruction.can_be_deduplicated(dfg, self.use_constraint_info) { let use_predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); @@ -274,7 +296,7 @@ impl Context { instruction_result_cache .entry(instruction) .or_default() - .insert(predicate, instruction_results); + .insert(predicate, instruction_results.clone()); } } @@ -295,6 +317,31 @@ impl Context { instruction: &Instruction, side_effects_enabled_var: ValueId, ) -> Option<&'a Vec> { + // dbg!(instruction.clone()); + // let new_instruction = if let Instruction::Binary(binary) = &instruction { + // match binary.operator { + // BinaryOp::Add + // | BinaryOp::Mul + // | BinaryOp::Eq + // | BinaryOp::And + // | BinaryOp::Or + // | BinaryOp::Xor => { + // let (lhs, rhs) = if binary.lhs < binary.rhs { + // (binary.lhs, binary.rhs) + // } else { + // (binary.rhs, binary.lhs) + // }; + // let commutative_instruction = + // Instruction::Binary(Binary { lhs, rhs, operator: binary.operator }); + // // instruction_result_cache.get(&commutative_instruction) + // Some(commutative_instruction) + // } + // _ => None, + // } + // } else { + // None + // }; + let results_for_instruction = instruction_result_cache.get(instruction); // See if there's a cached version with no predicate first @@ -841,4 +888,68 @@ mod test { let instructions = main.dfg[main.entry_block()].instructions(); assert_eq!(instructions.len(), 10); } + + #[test] + fn deduplicated_commutative_instructions() { + // acir(inline) fn main f0 { + // b0(v0: u32): + // v2 = mul v0, u32 2 + // v3 = mul u32 2, v0 + // return v2, v3 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.add_parameter(Type::unsigned(32)); + let v1 = builder.numeric_constant(2u128, Type::unsigned(32)); + + let v3 = builder.insert_binary(v0, BinaryOp::Mul, v1); + let v4 = builder.insert_binary(v1, BinaryOp::Mul, v0); + + builder.terminate_with_return(vec![v3, v4]); + + let ssa = builder.finish(); + + let main = ssa.main(); + let entry_block = &main.dfg[main.entry_block()]; + let instructions = entry_block.instructions(); + assert_eq!(instructions.len(), 2); + + if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator() + { + assert!(return_values[0] != return_values[1]); + } else { + panic!("Should have a return terminator"); + } + + // Expected output: + // + // acir(inline) fn main f0 { + // b0(v0: u32): + // v5 = mul v0, u32 2 + // return v5, v5 + // } + let ssa = ssa.fold_constants_using_constraints(); + + let main = ssa.main(); + let entry_block = &main.dfg[main.entry_block()]; + let instructions = entry_block.instructions(); + assert_eq!(instructions.len(), 1); + + if let Instruction::Binary(binary) = &main.dfg[instructions[0]] { + assert_eq!(binary.lhs, v0); + let constant_two = + main.dfg.get_numeric_constant(binary.rhs).expect("Should have a numeric constant"); + assert_eq!(constant_two.to_u128(), 2u128); + } + + if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator() + { + assert_eq!(main.dfg.resolve(return_values[0]), main.dfg.resolve(return_values[1])); + } else { + panic!("Should have a return terminator"); + } + } }