Skip to content

Commit

Permalink
feat: remove truncations which can be seen to be noops using type inf…
Browse files Browse the repository at this point in the history
…ormation (#3953)

…

# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

`Truncate`s which occur on outputs of `Binary(BinaryOp::Div)` can
sometimes be determined to be noops based on type information.

Put simply, if we have `(numerator / constant_denom) as T` then if
`type_of(numerator)::max / constant_denom` fits inside `T` then we don't
need to truncate the value.

## Additional Context

Testing this on noir-rsa takes us from

```
+---------+----------------------+--------------+----------------------+
| Package | Expression Width     | ACIR Opcodes | Backend Circuit Size |
+---------+----------------------+--------------+----------------------+
| dkim    | Bounded { width: 3 } | 1649234      | 3049613              |
+---------+----------------------+--------------+----------------------+
```

to

```
+---------+----------------------+--------------+----------------------+
| Package | Expression Width     | ACIR Opcodes | Backend Circuit Size |
+---------+----------------------+--------------+----------------------+
| dkim    | Bounded { width: 3 } | 1486175      | 2764246              |
+---------+----------------------+--------------+----------------------+
```


## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench authored Jan 10, 2024
1 parent 3b0f7d4 commit cc3c2c2
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 11 deletions.
45 changes: 36 additions & 9 deletions compiler/noirc_evaluator/src/ssa/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,16 +533,43 @@ impl Instruction {
let truncated = numeric_constant.to_u128() % integer_modulus;
SimplifiedTo(dfg.make_constant(truncated.into(), typ))
} else if let Value::Instruction { instruction, .. } = &dfg[dfg.resolve(*value)] {
if let Instruction::Truncate { bit_size: src_bit_size, .. } = &dfg[*instruction]
{
// If we're truncating the value to fit into the same or larger bit size then this is a noop.
if src_bit_size <= bit_size && src_bit_size <= max_bit_size {
SimplifiedTo(*value)
} else {
None
match &dfg[*instruction] {
Instruction::Truncate { bit_size: src_bit_size, .. } => {
// If we're truncating the value to fit into the same or larger bit size then this is a noop.
if src_bit_size <= bit_size && src_bit_size <= max_bit_size {
SimplifiedTo(*value)
} else {
None
}
}
} else {
None

Instruction::Binary(Binary {
lhs, rhs, operator: BinaryOp::Div, ..
}) if dfg.is_constant(*rhs) => {
// If we're truncating the result of a division by a constant denominator, we can
// reason about the maximum bit size of the result and whether a truncation is necessary.

let numerator_type = dfg.type_of_value(*lhs);
let max_numerator_bits = numerator_type.bit_size();

let divisor = dfg
.get_numeric_constant(*rhs)
.expect("rhs is checked to be constant.");
let divisor_bits = divisor.num_bits();

// 2^{max_quotient_bits} = 2^{max_numerator_bits} / 2^{divisor_bits}
// => max_quotient_bits = max_numerator_bits - divisor_bits
//
// In order for the truncation to be a noop, we then require `max_quotient_bits < bit_size`.
let max_quotient_bits = max_numerator_bits - divisor_bits;
if max_quotient_bits < *bit_size {
SimplifiedTo(*value)
} else {
None
}
}

_ => None,
}
} else {
None
Expand Down
115 changes: 113 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ mod test {
function_builder::FunctionBuilder,
ir::{
function::RuntimeType,
instruction::{BinaryOp, Instruction, TerminatorInstruction},
instruction::{Binary, BinaryOp, Instruction, TerminatorInstruction},
map::Id,
types::Type,
value::Value,
value::{Value, ValueId},
},
};

Expand Down Expand Up @@ -247,6 +247,117 @@ mod test {
}
}

#[test]
fn redundant_truncation() {
// fn main f0 {
// b0(v0: u16, v1: u16):
// v2 = div v0, v1
// v3 = truncate v2 to 8 bits, max_bit_size: 16
// return v3
// }
//
// After constructing this IR, we set the value of v1 to 2^8.
// The expected return afterwards should be v2.
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);
let v0 = builder.add_parameter(Type::unsigned(16));
let v1 = builder.add_parameter(Type::unsigned(16));

// Note that this constant guarantees that `v0/constant < 2^8`. We then do not need to truncate the result.
let constant = 2_u128.pow(8);
let constant = builder.numeric_constant(constant, Type::field());

let v2 = builder.insert_binary(v0, BinaryOp::Div, v1);
let v3 = builder.insert_truncate(v2, 8, 16);
builder.terminate_with_return(vec![v3]);

let mut ssa = builder.finish();
let main = ssa.main_mut();
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 2); // The final return is not counted

// Expected output:
//
// fn main f0 {
// b0(Field 2: Field):
// return Field 9
// }
main.dfg.set_value_from_id(v1, constant);

let ssa = ssa.fold_constants();
let main = ssa.main();

println!("{ssa}");

let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 1);
let instruction = &main.dfg[instructions[0]];

assert_eq!(
instruction,
&Instruction::Binary(Binary { lhs: v0, operator: BinaryOp::Div, rhs: constant })
);
}

#[test]
fn non_redundant_truncation() {
// fn main f0 {
// b0(v0: u16, v1: u16):
// v2 = div v0, v1
// v3 = truncate v2 to 8 bits, max_bit_size: 16
// return v3
// }
//
// After constructing this IR, we set the value of v1 to 2^8 - 1.
// This should not result in the truncation being removed.
let main_id = Id::test_new(0);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);
let v0 = builder.add_parameter(Type::unsigned(16));
let v1 = builder.add_parameter(Type::unsigned(16));

// Note that this constant does not guarantee that `v0/constant < 2^8`. We must then truncate the result.
let constant = 2_u128.pow(8) - 1;
let constant = builder.numeric_constant(constant, Type::field());

let v2 = builder.insert_binary(v0, BinaryOp::Div, v1);
let v3 = builder.insert_truncate(v2, 8, 16);
builder.terminate_with_return(vec![v3]);

let mut ssa = builder.finish();
let main = ssa.main_mut();
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 2); // The final return is not counted

// Expected output:
//
// fn main f0 {
// b0(v0: u16, Field 255: Field):
// v5 = div v0, Field 255
// v6 = truncate v5 to 8 bits, max_bit_size: 16
// return v6
// }
main.dfg.set_value_from_id(v1, constant);

let ssa = ssa.fold_constants();
let main = ssa.main();

let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 2);

assert_eq!(
&main.dfg[instructions[0]],
&Instruction::Binary(Binary { lhs: v0, operator: BinaryOp::Div, rhs: constant })
);
assert_eq!(
&main.dfg[instructions[1]],
&Instruction::Truncate { value: ValueId::test_new(5), bit_size: 8, max_bit_size: 16 }
);
}

#[test]
fn arrays_elements_are_updated() {
// fn main f0 {
Expand Down

0 comments on commit cc3c2c2

Please sign in to comment.