Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(perf): Cache commutative instructions in constant folding #5832

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::ssa::{
basic_block::BasicBlockId,
dfg::{DataFlowGraph, InsertInstructionResult},
function::Function,
instruction::{Instruction, InstructionId},
instruction::{Binary, BinaryOp, Instruction, InstructionId},
types::Type,
value::{Value, ValueId},
},
Expand Down Expand Up @@ -192,8 +192,32 @@ impl Context {
}

// Resolve any inputs to ensure that we're comparing like-for-like instructions.
let instruction = instruction
.map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id));

// Sort operands of commutative instructions
if let Instruction::Binary(binary) = &instruction {
match binary.operator {
Comment on lines +198 to +200
Copy link
Contributor

@jfecher jfecher Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I meant sorting when the instructions are first created. Not during constant folding. Probably won't remove the size increase in the tests though.

BinaryOp::Add
| BinaryOp::Mul
| BinaryOp::Eq
| BinaryOp::And
| BinaryOp::Or
| BinaryOp::Xor => {
let (lhs, rhs) = if binary.lhs < binary.rhs {
(binary.lhs, binary.rhs)
} else {
(binary.rhs, binary.lhs)
};
let commutative_instruction =
Instruction::Binary(Binary { lhs, rhs, operator: binary.operator });
return commutative_instruction;
}
_ => {}
}
}

instruction
.map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id))
}

/// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations
Expand Down Expand Up @@ -264,8 +288,6 @@ impl Context {
}
}

// If the instruction doesn't have side-effects and if it won't interact with enable_side_effects during acir_gen,
// we cache the results so we can reuse them if the same instruction appears again later in the block.
if instruction.can_be_deduplicated(dfg, self.use_constraint_info) {
let use_predicate =
self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
Expand All @@ -274,7 +296,7 @@ impl Context {
instruction_result_cache
.entry(instruction)
.or_default()
.insert(predicate, instruction_results);
.insert(predicate, instruction_results.clone());
}
}

Expand All @@ -295,6 +317,31 @@ impl Context {
instruction: &Instruction,
side_effects_enabled_var: ValueId,
) -> Option<&'a Vec<ValueId>> {
// dbg!(instruction.clone());
// let new_instruction = if let Instruction::Binary(binary) = &instruction {
// match binary.operator {
// BinaryOp::Add
// | BinaryOp::Mul
// | BinaryOp::Eq
// | BinaryOp::And
// | BinaryOp::Or
// | BinaryOp::Xor => {
// let (lhs, rhs) = if binary.lhs < binary.rhs {
// (binary.lhs, binary.rhs)
// } else {
// (binary.rhs, binary.lhs)
// };
// let commutative_instruction =
// Instruction::Binary(Binary { lhs, rhs, operator: binary.operator });
// // instruction_result_cache.get(&commutative_instruction)
// Some(commutative_instruction)
// }
// _ => None,
// }
// } else {
// None
// };

let results_for_instruction = instruction_result_cache.get(instruction);

// See if there's a cached version with no predicate first
Expand Down Expand Up @@ -841,4 +888,68 @@ mod test {
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 10);
}

#[test]
fn deduplicated_commutative_instructions() {
// acir(inline) fn main f0 {
// b0(v0: u32):
// v2 = mul v0, u32 2
// v3 = mul u32 2, v0
// return v2, v3
// }
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);

let v0 = builder.add_parameter(Type::unsigned(32));
let v1 = builder.numeric_constant(2u128, Type::unsigned(32));

let v3 = builder.insert_binary(v0, BinaryOp::Mul, v1);
let v4 = builder.insert_binary(v1, BinaryOp::Mul, v0);

builder.terminate_with_return(vec![v3, v4]);

let ssa = builder.finish();

let main = ssa.main();
let entry_block = &main.dfg[main.entry_block()];
let instructions = entry_block.instructions();
assert_eq!(instructions.len(), 2);

if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator()
{
assert!(return_values[0] != return_values[1]);
} else {
panic!("Should have a return terminator");
}

// Expected output:
//
// acir(inline) fn main f0 {
// b0(v0: u32):
// v5 = mul v0, u32 2
// return v5, v5
// }
let ssa = ssa.fold_constants_using_constraints();

let main = ssa.main();
let entry_block = &main.dfg[main.entry_block()];
let instructions = entry_block.instructions();
assert_eq!(instructions.len(), 1);

if let Instruction::Binary(binary) = &main.dfg[instructions[0]] {
assert_eq!(binary.lhs, v0);
let constant_two =
main.dfg.get_numeric_constant(binary.rhs).expect("Should have a numeric constant");
assert_eq!(constant_two.to_u128(), 2u128);
}

if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator()
{
assert_eq!(main.dfg.resolve(return_values[0]), main.dfg.resolve(return_values[1]));
} else {
panic!("Should have a return terminator");
}
}
}
Loading