Skip to content

Commit

Permalink
feat(acir_gen): Width aware ACIR gen addition (#5493)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4629 

## Summary\*

This PR updates how we add vars in ACIR gen to account for an expression
width. This is under a compile time flag `--bounded-codegen`.

If the sum of two expressions is going to go over the specified
expression width we automatically create witnesses for either the lhs,
the rhs, or both, before then generating a sum expression.

This new bounded codegen could provide an easy way for developers to
tell whether their program can be optimized using `as_witness`.
Reference the additional context section below for some numbers.

## Additional Context


There are some limitations in this approach as it is pretty naive in how
it decides to when to generate a new witness. It doesn't look ahead at
all other than for whether the two AcirVar's being added are going to
create an expression over the specified ACIR width.

For example we can see the following gate counts for
`poseidonsponge_x5_254`:
```
No width awareness:
+-----------------------+----------+----------------------+--------------+
| Package               | Function | Expression Width     | ACIR Opcodes |
+-----------------------+----------+----------------------+--------------+
| poseidonsponge_x5_254 | main     | Bounded { width: 4 } | 3096         |
+-----------------------+----------+----------------------+--------------+

No width awareness w/ as_witness (this is the currently optimized poseidon we have in the stdlib):
+-----------------------+----------+----------------------+--------------+
| Package               | Function | Expression Width     | ACIR Opcodes |
+-----------------------+----------+----------------------+--------------+
| poseidonsponge_x5_254 | main     | Bounded { width: 4 } | 1302         |
+-----------------------+----------+----------------------+--------------+

Width awareness:
+-----------------------+----------+----------------------+--------------+
| Package               | Function | Expression Width     | ACIR Opcodes |
+-----------------------+----------+----------------------+--------------+
| poseidonsponge_x5_254 | main     | Bounded { width: 4 } | 2114         |
+-----------------------+----------+----------------------+--------------+

Width awareness w/ as_witness:
+-----------------------+----------+----------------------+--------------+
| Package               | Function | Expression Width     | ACIR Opcodes |
+-----------------------+----------+----------------------+--------------+
| poseidonsponge_x5_254 | main     | Bounded { width: 4 } | 1792         |
+-----------------------+----------+----------------------+--------------+
```
From the above we can see that we actually have a degradation when using
the addition strategy used in this PR with a hand optimized program
using `as_witness`. Although this PR still gives an improvement in the
default.

Another example is the following program:
```rust
fn main(x: Field, y: pub Field) {
    let state = [x, y];
    let state = oh_no_not_again(state);

    // This assert will fail if we execute
    assert(state[0] + state[1] == 0);
}

fn oh_no_not_again(mut state: [Field; 2]) -> [Field; 2] {
    for _ in 0..200 {
        state[0] = state[0] * state[0] + state[1];
        state[1] += state[0];
    }
    state
}
```
Without any width awareness we get 1150 ACIR gates. With this PR we will
get 399 gates. If we substitute `oh_no_not_again` for the following:
```rust
fn oh_no_not_again_as_witness(mut state: [Field; 2]) -> [Field; 2] {
    for i in 0..200 {
        state[0] = state[0] * state[0] + state[1];
        std::as_witness(state[0]);
        state[1] += state[0];
        if (i & 1 == 1) {
            std::as_witness(state[1]);
        }
    }
    state
}
```
We will get 301 gates if the method above is called instead of
`oh_no_not_again`.

## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
vezenovm and TomAFrench authored Jul 24, 2024
1 parent cec6390 commit 85fa592
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 103 deletions.
54 changes: 54 additions & 0 deletions acvm-repo/acir/src/native_types/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,60 @@ impl<F: AcirField> Expression<F> {

Expression { mul_terms, linear_combinations, q_c }
}

/// Determine the width of this expression.
/// The width meaning the number of unique witnesses needed for this expression.
pub fn width(&self) -> usize {
let mut width = 0;

for mul_term in &self.mul_terms {
// The coefficient should be non-zero, as this method is ran after the compiler removes all zero coefficient terms
assert_ne!(mul_term.0, F::zero());

let mut found_x = false;
let mut found_y = false;

for term in self.linear_combinations.iter() {
let witness = &term.1;
let x = &mul_term.1;
let y = &mul_term.2;
if witness == x {
found_x = true;
};
if witness == y {
found_y = true;
};
if found_x & found_y {
break;
}
}

// If the multiplication is a squaring then we must assign the two witnesses to separate wires and so we
// can never get a zero contribution to the width.
let multiplication_is_squaring = mul_term.1 == mul_term.2;

let mul_term_width_contribution = if !multiplication_is_squaring && (found_x & found_y)
{
// Both witnesses involved in the multiplication exist elsewhere in the expression.
// They both do not contribute to the width of the expression as this would be double-counting
// due to their appearance in the linear terms.
0
} else if found_x || found_y {
// One of the witnesses involved in the multiplication exists elsewhere in the expression.
// The multiplication then only contributes 1 new witness to the width.
1
} else {
// Worst case scenario, the multiplication is using completely unique witnesses so has a contribution of 2.
2
};

width += mul_term_width_contribution;
}

width += self.linear_combinations.len();

width
}
}

impl<F: AcirField> From<F> for Expression<F> {
Expand Down
65 changes: 1 addition & 64 deletions acvm-repo/acvm/src/compiler/transformers/csat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,71 +415,8 @@ fn fits_in_one_identity<F: AcirField>(expr: &Expression<F>, width: usize) -> boo
if expr.mul_terms.len() > 1 {
return false;
};
// A Polynomial with more terms than fan-in cannot fit within a single opcode
if expr.linear_combinations.len() > width {
return false;
}

// A polynomial with no mul term and a fan-in that fits inside of the width can fit into a single opcode
if expr.mul_terms.is_empty() {
return true;
}

// A polynomial with width-2 fan-in terms and a single non-zero mul term can fit into one opcode
// Example: Axy + Dz . Notice, that the mul term places a constraint on the first two terms, but not the last term
// XXX: This would change if our arithmetic polynomial equation was changed to Axyz for example, but for now it is not.
if expr.linear_combinations.len() <= (width - 2) {
return true;
}

// We now know that we have a single mul term. We also know that the mul term must match up with at least one of the other terms
// A polynomial whose mul terms are non zero which do not match up with two terms in the fan-in cannot fit into one opcode
// An example of this is: Axy + Bx + Cy + ...
// Notice how the bivariate monomial xy has two univariate monomials with their respective coefficients
// XXX: note that if x or y is zero, then we could apply a further optimization, but this would be done in another algorithm.
// It would be the same as when we have zero coefficients - Can only work if wire is constrained to be zero publicly
let mul_term = &expr.mul_terms[0];

// The coefficient should be non-zero, as this method is ran after the compiler removes all zero coefficient terms
assert_ne!(mul_term.0, F::zero());

let mut found_x = false;
let mut found_y = false;

for term in expr.linear_combinations.iter() {
let witness = &term.1;
let x = &mul_term.1;
let y = &mul_term.2;
if witness == x {
found_x = true;
};
if witness == y {
found_y = true;
};
if found_x & found_y {
break;
}
}

// If the multiplication is a squaring then we must assign the two witnesses to separate wires and so we
// can never get a zero contribution to the width.
let multiplication_is_squaring = mul_term.1 == mul_term.2;

let mul_term_width_contribution = if !multiplication_is_squaring && (found_x & found_y) {
// Both witnesses involved in the multiplication exist elsewhere in the expression.
// They both do not contribute to the width of the expression as this would be double-counting
// due to their appearance in the linear terms.
0
} else if found_x || found_y {
// One of the witnesses involved in the multiplication exists elsewhere in the expression.
// The multiplication then only contributes 1 new witness to the width.
1
} else {
// Worst case scenario, the multiplication is using completely unique witnesses so has a contribution of 2.
2
};

mul_term_width_contribution + expr.linear_combinations.len() <= width
expr.width() <= width
}

#[cfg(test)]
Expand Down
17 changes: 17 additions & 0 deletions compiler/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ pub struct CompileOptions {
#[arg(long, value_parser = parse_expression_width)]
pub expression_width: Option<ExpressionWidth>,

/// Generate ACIR with the target backend expression width.
/// The default is to generate ACIR without a bound and split expressions after code generation.
/// Activating this flag can sometimes provide optimizations for certain programs.
#[arg(long, default_value = "false")]
pub bounded_codegen: bool,

/// Force a full recompilation.
#[arg(long = "force")]
pub force_compile: bool,
Expand Down Expand Up @@ -512,6 +518,12 @@ fn compile_contract_inner(
}
}

/// Default expression width used for Noir compilation.
/// The ACVM native type `ExpressionWidth` has its own default which should always be unbounded,
/// while we can sometimes expect the compilation target width to change.
/// Thus, we set it separately here rather than trying to alter the default derivation of the type.
pub const DEFAULT_EXPRESSION_WIDTH: ExpressionWidth = ExpressionWidth::Bounded { width: 4 };

/// Compile the current crate using `main_function` as the entrypoint.
///
/// This function assumes [`check_crate`] is called beforehand.
Expand Down Expand Up @@ -550,6 +562,11 @@ pub fn compile_no_check(
enable_brillig_logging: options.show_brillig,
force_brillig_output: options.force_brillig,
print_codegen_timings: options.benchmark_codegen,
expression_width: if options.bounded_codegen {
options.expression_width.unwrap_or(DEFAULT_EXPRESSION_WIDTH)
} else {
ExpressionWidth::default()
},
};

let SsaProgramArtifact { program, debug, warnings, names, error_types, .. } =
Expand Down
33 changes: 19 additions & 14 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ pub mod ir;
mod opt;
pub mod ssa_gen;

pub struct SsaEvaluatorOptions {
/// Emit debug information for the intermediate SSA IR
pub enable_ssa_logging: bool,

pub enable_brillig_logging: bool,

/// Force Brillig output (for step debugging)
pub force_brillig_output: bool,

/// Pretty print benchmark times of each code generation pass
pub print_codegen_timings: bool,

/// Width of expressions to be used for ACIR
pub expression_width: ExpressionWidth,
}

pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec<SsaReport>);

/// Optimize the given program by converting it into SSA
Expand Down Expand Up @@ -99,7 +115,9 @@ pub(crate) fn optimize_into_acir(

drop(ssa_gen_span_guard);

let artifacts = time("SSA to ACIR", options.print_codegen_timings, || ssa.into_acir(&brillig))?;
let artifacts = time("SSA to ACIR", options.print_codegen_timings, || {
ssa.into_acir(&brillig, options.expression_width)
})?;
Ok(ArtifactsAndWarnings(artifacts, ssa_level_warnings))
}

Expand Down Expand Up @@ -160,19 +178,6 @@ impl SsaProgramArtifact {
}
}

pub struct SsaEvaluatorOptions {
/// Emit debug information for the intermediate SSA IR
pub enable_ssa_logging: bool,

pub enable_brillig_logging: bool,

/// Force Brillig output (for step debugging)
pub force_brillig_output: bool,

/// Pretty print benchmark times of each code generation pass
pub print_codegen_timings: bool,
}

/// Compiles the [`Program`] into [`ACIR``][acvm::acir::circuit::Program].
///
/// The output ACIR is backend-agnostic and so must go through a transformation pass before usage in proof generation.
Expand Down
83 changes: 81 additions & 2 deletions compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::ssa::ir::types::Type as SsaType;
use crate::ssa::ir::{instruction::Endian, types::NumericType};
use acvm::acir::circuit::brillig::{BrilligInputs, BrilligOutputs};
use acvm::acir::circuit::opcodes::{BlockId, BlockType, MemOp};
use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, Opcode};
use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode};
use acvm::blackbox_solver;
use acvm::brillig_vm::{MemoryValue, VMStatus, VM};
use acvm::{
Expand All @@ -24,6 +24,7 @@ use acvm::{
use fxhash::FxHashMap as HashMap;
use iter_extended::{try_vecmap, vecmap};
use num_bigint::BigUint;
use std::cmp::Ordering;
use std::{borrow::Cow, hash::Hash};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -124,9 +125,15 @@ pub(crate) struct AcirContext<F: AcirField> {

/// The BigIntContext, used to generate identifiers for BigIntegers
big_int_ctx: BigIntContext,

expression_width: ExpressionWidth,
}

impl<F: AcirField> AcirContext<F> {
pub(crate) fn set_expression_width(&mut self, expression_width: ExpressionWidth) {
self.expression_width = expression_width;
}

pub(crate) fn current_witness_index(&self) -> Witness {
self.acir_ir.current_witness_index()
}
Expand Down Expand Up @@ -584,6 +591,7 @@ impl<F: AcirField> AcirContext<F> {
pub(crate) fn mul_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, RuntimeError> {
let lhs_data = self.vars[&lhs].clone();
let rhs_data = self.vars[&rhs].clone();

let result = match (lhs_data, rhs_data) {
// (x * 1) == (1 * x) == x
(AcirVarData::Const(constant), _) if constant.is_one() => rhs,
Expand Down Expand Up @@ -655,6 +663,7 @@ impl<F: AcirField> AcirContext<F> {
self.mul_var(lhs, rhs)?
}
};

Ok(result)
}

Expand All @@ -670,9 +679,62 @@ impl<F: AcirField> AcirContext<F> {
pub(crate) fn add_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, RuntimeError> {
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;

let sum_expr = &lhs_expr + &rhs_expr;
if fits_in_one_identity(&sum_expr, self.expression_width) {
let sum_var = self.add_data(AcirVarData::from(sum_expr));

return Ok(sum_var);
}

let sum_expr = match lhs_expr.width().cmp(&rhs_expr.width()) {
Ordering::Greater => {
let lhs_witness_var = self.get_or_create_witness_var(lhs)?;
let lhs_witness_expr = self.var_to_expression(lhs_witness_var)?;

let new_sum_expr = &lhs_witness_expr + &rhs_expr;
if fits_in_one_identity(&new_sum_expr, self.expression_width) {
new_sum_expr
} else {
let rhs_witness_var = self.get_or_create_witness_var(rhs)?;
let rhs_witness_expr = self.var_to_expression(rhs_witness_var)?;

&lhs_expr + &rhs_witness_expr
}
}
Ordering::Less => {
let rhs_witness_var = self.get_or_create_witness_var(rhs)?;
let rhs_witness_expr = self.var_to_expression(rhs_witness_var)?;

let new_sum_expr = &lhs_expr + &rhs_witness_expr;
if fits_in_one_identity(&new_sum_expr, self.expression_width) {
new_sum_expr
} else {
let lhs_witness_var = self.get_or_create_witness_var(lhs)?;
let lhs_witness_expr = self.var_to_expression(lhs_witness_var)?;

Ok(self.add_data(AcirVarData::from(sum_expr)))
&lhs_witness_expr + &rhs_expr
}
}
Ordering::Equal => {
let lhs_witness_var = self.get_or_create_witness_var(lhs)?;
let lhs_witness_expr = self.var_to_expression(lhs_witness_var)?;

let new_sum_expr = &lhs_witness_expr + &rhs_expr;
if fits_in_one_identity(&new_sum_expr, self.expression_width) {
new_sum_expr
} else {
let rhs_witness_var = self.get_or_create_witness_var(rhs)?;
let rhs_witness_expr = self.var_to_expression(rhs_witness_var)?;

&lhs_witness_expr + &rhs_witness_expr
}
}
};

let sum_var = self.add_data(AcirVarData::from(sum_expr));

Ok(sum_var)
}

/// Adds a new Variable to context whose value will
Expand Down Expand Up @@ -1990,6 +2052,23 @@ impl<F: AcirField> From<Expression<F>> for AcirVarData<F> {
}
}

/// Checks if this expression can fit into one arithmetic identity
fn fits_in_one_identity<F: AcirField>(expr: &Expression<F>, width: ExpressionWidth) -> bool {
let width = match &width {
ExpressionWidth::Unbounded => {
return true;
}
ExpressionWidth::Bounded { width } => *width,
};

// A Polynomial with more than one mul term cannot fit into one opcode
if expr.mul_terms.len() > 1 {
return false;
};

expr.width() <= width
}

/// A Reference to an `AcirVarData`
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct AcirVar(usize);
Expand Down
Loading

0 comments on commit 85fa592

Please sign in to comment.