Skip to content

Commit

Permalink
feat(avm): support skippable relations (#7750)
Browse files Browse the repository at this point in the history
BB supports "skippable relations". If you define `bool Relation::skip(row)` and it returns true, sumcheck will not waste time checking that the accumulations actually sum 0.

In this PR I add a way to define this in PIL and generate code to create `Relation::skip` (see poseidon2.pil).

* Adds fuzzing test for skippable relations.
* Makes poseidon2 relation skippable iff `sel_poseidon_perm = 0`.

----

Skippable relations are great to reclaim time for gadgets, which may have lots of columns (and subrelations) but may seldom be used. In the example below I applied this to poseidon2 which has 200+ columns and 150+ subrelations. On a token transfer, it almost completely reclaimed sumcheck time to how things were before the introduction of poseidon2 constraints.

AFTER (1 cpu)

```
proving time minus check circuit = 44s

prove/all_ms: 79209
prove/check_circuit_ms: 35276
prove/create_composer_ms: 0
prove/create_prover_ms: 3921
prove/create_verifier_ms: 0
prove/execute_log_derivative_inverse_commitments_round_ms: 8499
prove/execute_log_derivative_inverse_round_ms: 5975
prove/execute_pcs_rounds_ms: 1927
>>> prove/execute_relation_check_rounds_ms: 21324
prove/execute_wire_commitments_round_ms: 554
prove/gen_trace_ms: 1579
```

BEFORE (1 cpu)

```
proving time minus check circuit = 85s

prove/all_ms: 120465
prove/check_circuit_ms: 35002
prove/create_composer_ms: 0
prove/create_prover_ms: 2138
prove/create_verifier_ms: 0
prove/execute_log_derivative_inverse_commitments_round_ms: 8568
prove/execute_log_derivative_inverse_round_ms: 5840
prove/execute_pcs_rounds_ms: 2150
>>> prove/execute_relation_check_rounds_ms: 64416
prove/execute_wire_commitments_round_ms: 551
prove/gen_trace_ms: 1602
```

BEFORE POSEIDON2

```
proving time minus check circuit = 38s

prove/all_ms: 51309
prove/check_circuit_ms: 13069
prove/create_composer_ms: 0
prove/create_prover_ms: 1294
prove/create_verifier_ms: 0
prove/execute_log_derivative_inverse_commitments_round_ms: 8537
prove/execute_log_derivative_inverse_round_ms: 2704
prove/execute_pcs_rounds_ms: 1673
>>> prove/execute_relation_check_rounds_ms: 22331
prove/execute_wire_commitments_round_ms: 493
prove/gen_trace_ms: 1079
```
  • Loading branch information
fcarreiro authored Aug 3, 2024
1 parent 609f044 commit 89d7b37
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 41 deletions.
5 changes: 4 additions & 1 deletion barretenberg/cpp/pil/avm/gadgets/poseidon2.pil
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
include "poseidon2_params.pil";

namespace poseidon2(256);

pol commit clk;

// Selector for poseidon2 operation
pol commit sel_poseidon_perm;
// Selector is boolean
sel_poseidon_perm * (1 - sel_poseidon_perm) = 0;

// No relations will be checked if this identity is satisfied.
#[skippable_if]
sel_poseidon_perm = 0;

// The initial mem address for inputs or output
pol commit input_addr;
pol commit output_addr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,10 @@ bool AvmCircuitBuilder::check_circuit() const
for (auto& r : result) {
r = 0;
}
constexpr size_t NUM_SUBRELATIONS = result.size();

for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(result, polys.get_row(r), {}, 1);
for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) {
for (size_t j = 0; j < result.size(); ++j) {
if (result[j] != 0) {
signal_error(format("Relation ",
Relation::NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ template <typename FF_> class poseidon2Impl {
8, 7, 7, 7, 8, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 3, 3, 3
};

template <typename AllEntities> inline static bool skip(const AllEntities& in)
{
const auto& new_term = in;
return (new_term.poseidon2_sel_poseidon_perm).is_zero();
}

template <typename ContainerOverSubrelations, typename AllEntities>
void static accumulate(ContainerOverSubrelations& evals,
const AllEntities& new_term,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "barretenberg/common/thread.hpp"
#include "barretenberg/vm/avm/generated/circuit_builder.hpp"
#include "barretenberg/vm/avm/generated/flavor.hpp"
#include "barretenberg/vm/avm/generated/full_row.hpp"

#include <gtest/gtest.h>
#include <memory>
#include <vector>

namespace tests_avm {

using namespace bb;
using namespace bb::Avm_vm;

TEST(AvmSkippableTests, shouldSkipCorrectly)
{
using FF = AvmFlavor::FF;
constexpr size_t TRACE_SIZE = 1 << 15;

std::vector<AvmFullRow<FF>> trace(TRACE_SIZE);
std::cerr << "Generating trace of size " << TRACE_SIZE << "..." << std::endl;
// This is the most time consuming part of this test!
// In particular, the generation of random fields.
bb::parallel_for(trace.size(), [&](size_t i) {
// The first row needs to be zeroes otherwise shifting doesn't work.
if (i == 0) {
return;
}
AvmFullRow<FF>& row = trace[i];

// Fill the row with random values.
auto as_vector = row.as_vector();
const auto as_vector_size = as_vector.size();
for (size_t j = 0; j < as_vector_size; j++) {
// FF::random_element(); is so slow! Using std::rand() instead.
const_cast<FF&>(as_vector[j]) = FF(std::rand());
}

// Set the conditions for skippable to return true.
row.poseidon2_sel_poseidon_perm = 0;
});
std::cerr << "Done generating trace..." << std::endl;

// We build the polynomials needed to run "sumcheck".
AvmCircuitBuilder cb;
cb.set_trace(std::move(trace));
auto polys = cb.compute_polynomials();
std::cerr << "Done computing polynomials..." << std::endl;

// For each skippable relation we will check:
// 1. That Relation::skippable returns true (i.e., we correctly set the conditions)
// 2. That the sumcheck result is zero (i.e., it was ok to skip the relation)
for (size_t ri = 1; ri < TRACE_SIZE; ++ri) {
auto row = polys.get_row(ri);

bb::constexpr_for<0, std::tuple_size_v<AvmFlavor::Relations>, 1>([&]<size_t i>() {
using Relation = std::tuple_element_t<i, AvmFlavor::Relations>;

// We only want to test skippable relations.
if constexpr (isSkippable<Relation, AvmFullRow<FF>>) {
typename Relation::SumcheckArrayOfValuesOverSubrelations result;
for (auto& r : result) {
r = 0;
}

// We set the conditions up there.
auto skip = Relation::skip(row);
EXPECT_TRUE(skip) << "Relation " << Relation::NAME << " was expected to be skippable at row " << ri
<< ".";

Relation::accumulate(result, row, {}, 1);

// If the relation is skippable, the result should be zero.
for (size_t j = 0; j < result.size(); ++j) {
if (result[j] != 0) {
EXPECT_EQ(result[j], 0)
<< "Relation " << Relation::NAME << " subrelation " << j << " was expected to be zero.";
GTEST_SKIP();
}
}
}
});
}
}

} // namespace tests_avm
79 changes: 44 additions & 35 deletions bb-pilcom/bb-pil-backend/src/relation_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ pub struct RelationOutput {

/// Each created bb Identity is passed around with its degree so as needs to be manually
/// provided for sumcheck
type BBIdentity = (DegreeType, String);
#[derive(Debug)]
pub struct BBIdentity {
pub degree: DegreeType,
pub identity: String,
pub label: Option<String>,
}

pub trait RelationBuilder {
/// Create Relations
Expand Down Expand Up @@ -61,8 +66,8 @@ pub trait RelationBuilder {
root_name: &str,
name: &str,
identities: &[BBIdentity],
skippable_if: &Option<BBIdentity>,
all_cols: &[String],
labels: &HashMap<usize, String>,
);
}

Expand All @@ -85,9 +90,9 @@ impl RelationBuilder for BBFiles {
for (relation_name, analyzed_idents) in grouped_relations.iter() {
let IdentitiesOutput {
identities,
skippable_if,
collected_cols,
collected_shifts,
expression_labels,
} = create_identities(file_name, analyzed_idents);

// Aggregate all shifted polys
Expand All @@ -97,8 +102,8 @@ impl RelationBuilder for BBFiles {
file_name,
relation_name,
&identities,
&skippable_if,
&collected_cols,
&expression_labels,
);
}

Expand All @@ -116,25 +121,28 @@ impl RelationBuilder for BBFiles {
root_name: &str,
name: &str,
identities: &[BBIdentity],
skippable_if: &Option<BBIdentity>,
all_cols: &[String],
labels: &HashMap<usize, String>,
) {
let mut handlebars = Handlebars::new();
let degrees: Vec<_> = identities.iter().map(|(d, _)| d + 1).collect();
let sorted_labels = labels
let degrees: Vec<_> = identities.iter().map(|id| id.degree + 1).collect();
let sorted_labels = identities
.iter()
.sorted_by_key(|(idx, _)| *idx)
.collect::<Vec<_>>();
.enumerate()
.filter(|(_, id)| id.label.is_some())
.map(|(idx, id)| (idx, id.label.clone().unwrap()))
.collect_vec();

let data = &json!({
"root_name": root_name,
"name": name,
"identities": identities.iter().map(|(d, id)| {
"identities": identities.iter().map(|id| {
json!({
"degree": d,
"identity": id,
"degree": id.degree,
"identity": id.identity,
})
}).collect::<Vec<_>>(),
}).collect_vec(),
"skippable_if": skippable_if.as_ref().map(|id| id.identity.clone()),
"degrees": degrees,
"all_cols": all_cols,
"labels": sorted_labels,
Expand Down Expand Up @@ -196,13 +204,17 @@ fn create_identity<T: FieldElement>(
expression: &SelectedExpressions<Expression<T>>,
collected_cols: &mut HashSet<String>,
collected_public_identities: &mut HashSet<String>,
label: &Option<String>,
) -> Option<BBIdentity> {
// We want to read the types of operators and then create the appropiate code

if let Some(expr) = &expression.selector {
let x = craft_expression(expr, collected_cols, collected_public_identities);
log::trace!("expression {:?}", x);
Some(x)
let (degree, id) = craft_expression(expr, collected_cols, collected_public_identities);
log::trace!("expression {:?}, {:?}", degree, id);
Some(BBIdentity {
degree: degree,
identity: id,
label: label.clone(),
})
} else {
None
}
Expand All @@ -213,7 +225,7 @@ fn craft_expression<T: FieldElement>(
// TODO: maybe make state?
collected_cols: &mut HashSet<String>,
collected_public_identities: &mut HashSet<String>,
) -> BBIdentity {
) -> (u64, String) {
let var_name = match expr {
Expression::Number(n) => {
let number: BigUint = n.to_arbitrary_integer();
Expand Down Expand Up @@ -319,9 +331,9 @@ fn craft_expression<T: FieldElement>(

pub struct IdentitiesOutput {
identities: Vec<BBIdentity>,
skippable_if: Option<BBIdentity>,
collected_cols: Vec<String>,
collected_shifts: Vec<String>,
expression_labels: HashMap<usize, String>,
}

pub(crate) fn create_identities<F: FieldElement>(
Expand All @@ -336,28 +348,25 @@ pub(crate) fn create_identities<F: FieldElement>(
.collect::<Vec<_>>();

let mut identities = Vec::new();
let mut expression_labels: HashMap<usize, String> = HashMap::new(); // Each relation can be given a label, this label can be assigned here
let mut skippable_if_identity = None;
let mut collected_cols: HashSet<String> = HashSet::new();
let mut collected_public_identities: HashSet<String> = HashSet::new();

// Collect labels for each identity
// TODO: shite
for (i, id) in ids.iter().enumerate() {
if let Some(label) = &id.attribute {
expression_labels.insert(i, label.clone());
}
}

let expressions = ids.iter().map(|id| id.left.clone()).collect::<Vec<_>>();
for (i, expression) in expressions.iter().enumerate() {
// TODO: collected pattern is shit
let mut identity = create_identity(
expression,
for (i, expression) in ids.iter().enumerate() {
let identity = create_identity(
&expression.left,
&mut collected_cols,
&mut collected_public_identities,
&expression.attribute,
)
.unwrap();
identities.push(identity);

if identity.label.clone().is_some_and(|l| l == "skippable_if") {
assert!(skippable_if_identity.is_none());
skippable_if_identity = Some(identity);
} else {
identities.push(identity);
}
}

// Print a warning to the user about usage of public identities
Expand Down Expand Up @@ -386,8 +395,8 @@ pub(crate) fn create_identities<F: FieldElement>(

IdentitiesOutput {
identities,
skippable_if: skippable_if_identity,
collected_cols,
collected_shifts,
expression_labels,
}
}
3 changes: 1 addition & 2 deletions bb-pilcom/bb-pil-backend/templates/circuit_builder.cpp.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ bool {{name}}CircuitBuilder::check_circuit() const {
for (auto& r : result) {
r = 0;
}
constexpr size_t NUM_SUBRELATIONS = result.size();

for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(result, polys.get_row(r), {}, 1);
for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) {
for (size_t j = 0; j < result.size(); ++j) {
if (result[j] != 0) {
signal_error(format("Relation ",
Relation::NAME,
Expand Down
10 changes: 9 additions & 1 deletion bb-pilcom/bb-pil-backend/templates/relation.hpp.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ template <typename FF_> class {{name}}Impl {
static constexpr std::array<size_t, {{len degrees}}> SUBRELATION_PARTIAL_LENGTHS = {
{{#each degrees as |degree|}}{{degree}}{{#unless @last}},{{/unless}}{{/each}}
};


{{#if skippable_if}}
template <typename AllEntities> inline static bool skip(const AllEntities& in)
{
const auto& new_term = in;
return ({{skippable_if}}).is_zero();
}
{{/if}}

template <typename ContainerOverSubrelations, typename AllEntities>
void static accumulate(
ContainerOverSubrelations& evals,
Expand Down

0 comments on commit 89d7b37

Please sign in to comment.