Skip to content

Commit

Permalink
Merge 565c521 into 3c778b7
Browse files Browse the repository at this point in the history
  • Loading branch information
vezenovm authored Aug 27, 2024
2 parents 3c778b7 + 565c521 commit 6605711
Showing 1 changed file with 116 additions and 5 deletions.
121 changes: 116 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::ssa::{
basic_block::BasicBlockId,
dfg::{DataFlowGraph, InsertInstructionResult},
function::Function,
instruction::{Instruction, InstructionId},
instruction::{Binary, BinaryOp, Instruction, InstructionId},
types::Type,
value::{Value, ValueId},
},
Expand Down Expand Up @@ -192,8 +192,32 @@ impl Context {
}

// Resolve any inputs to ensure that we're comparing like-for-like instructions.
let instruction = instruction
.map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id));

// Sort operands of commutative instructions
if let Instruction::Binary(binary) = &instruction {
match binary.operator {
BinaryOp::Add
| BinaryOp::Mul
| BinaryOp::Eq
| BinaryOp::And
| BinaryOp::Or
| BinaryOp::Xor => {
let (lhs, rhs) = if binary.lhs < binary.rhs {
(binary.lhs, binary.rhs)
} else {
(binary.rhs, binary.lhs)
};
let commutative_instruction =
Instruction::Binary(Binary { lhs, rhs, operator: binary.operator });
return commutative_instruction;
}
_ => {}
}
}

instruction
.map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id))
}

/// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations
Expand Down Expand Up @@ -264,8 +288,6 @@ impl Context {
}
}

// If the instruction doesn't have side-effects and if it won't interact with enable_side_effects during acir_gen,
// we cache the results so we can reuse them if the same instruction appears again later in the block.
if instruction.can_be_deduplicated(dfg, self.use_constraint_info) {
let use_predicate =
self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg);
Expand All @@ -274,7 +296,7 @@ impl Context {
instruction_result_cache
.entry(instruction)
.or_default()
.insert(predicate, instruction_results);
.insert(predicate, instruction_results.clone());
}
}

Expand All @@ -295,6 +317,31 @@ impl Context {
instruction: &Instruction,
side_effects_enabled_var: ValueId,
) -> Option<&'a Vec<ValueId>> {
// dbg!(instruction.clone());
// let new_instruction = if let Instruction::Binary(binary) = &instruction {
// match binary.operator {
// BinaryOp::Add
// | BinaryOp::Mul
// | BinaryOp::Eq
// | BinaryOp::And
// | BinaryOp::Or
// | BinaryOp::Xor => {
// let (lhs, rhs) = if binary.lhs < binary.rhs {
// (binary.lhs, binary.rhs)
// } else {
// (binary.rhs, binary.lhs)
// };
// let commutative_instruction =
// Instruction::Binary(Binary { lhs, rhs, operator: binary.operator });
// // instruction_result_cache.get(&commutative_instruction)
// Some(commutative_instruction)
// }
// _ => None,
// }
// } else {
// None
// };

let results_for_instruction = instruction_result_cache.get(instruction);

// See if there's a cached version with no predicate first
Expand Down Expand Up @@ -841,4 +888,68 @@ mod test {
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 10);
}

#[test]
fn deduplicated_commutative_instructions() {
// acir(inline) fn main f0 {
// b0(v0: u32):
// v2 = mul v0, u32 2
// v3 = mul u32 2, v0
// return v2, v3
// }
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);

let v0 = builder.add_parameter(Type::unsigned(32));
let v1 = builder.numeric_constant(2u128, Type::unsigned(32));

let v3 = builder.insert_binary(v0, BinaryOp::Mul, v1);
let v4 = builder.insert_binary(v1, BinaryOp::Mul, v0);

builder.terminate_with_return(vec![v3, v4]);

let ssa = builder.finish();

let main = ssa.main();
let entry_block = &main.dfg[main.entry_block()];
let instructions = entry_block.instructions();
assert_eq!(instructions.len(), 2);

if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator()
{
assert!(return_values[0] != return_values[1]);
} else {
panic!("Should have a return terminator");
}

// Expected output:
//
// acir(inline) fn main f0 {
// b0(v0: u32):
// v5 = mul v0, u32 2
// return v5, v5
// }
let ssa = ssa.fold_constants_using_constraints();

let main = ssa.main();
let entry_block = &main.dfg[main.entry_block()];
let instructions = entry_block.instructions();
assert_eq!(instructions.len(), 1);

if let Instruction::Binary(binary) = &main.dfg[instructions[0]] {
assert_eq!(binary.lhs, v0);
let constant_two =
main.dfg.get_numeric_constant(binary.rhs).expect("Should have a numeric constant");
assert_eq!(constant_two.to_u128(), 2u128);
}

if let TerminatorInstruction::Return { return_values, .. } = entry_block.unwrap_terminator()
{
assert_eq!(main.dfg.resolve(return_values[0]), main.dfg.resolve(return_values[1]));
} else {
panic!("Should have a return terminator");
}
}
}

0 comments on commit 6605711

Please sign in to comment.