Skip to content

Commit

Permalink
refactor: clean out prover instance and remove instance from oink (#5314
Browse files Browse the repository at this point in the history
)

Moves sorted_polynomials, initialize_prover_polynomials,
compute_sorted_accumulator_polynomials, compute_sorted_list_accumulator,
compute_logderivative_inverse, compute_grand_product_polynomials out of
prover instance and into the proving_key. We also modify the OinkProver
to return an OinkProverOutput, which is just the relation_parameters.

These changes enable us to remove instance from the oink prover, and
only take in a proving_key.

```
--------------------------------------------------------------------------------
Benchmark                      Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------
ClientIVCBench/Full/6      23860 ms        18897 ms            1 Decider::construct_proof=1 Decider::construct_proof(t)=751.457M ECCVMComposer::compute_commitment_key=1 ECCVMComposer::compute_commitment_key(t)=3.7316M ECCVMComposer::compute_witness=1 ECCVMComposer::compute_witness(t)=129.141M ECCVMComposer::create_prover=1 ECCVMComposer::create_prover(t)=149.103M ECCVMComposer::create_proving_key=1 ECCVMComposer::create_proving_key(t)=16.0015M ECCVMProver::construct_proof=1 ECCVMProver::construct_proof(t)=1.76309G Goblin::merge=11 Goblin::merge(t)=135.854M GoblinTranslatorCircuitBuilder::constructor=1 GoblinTranslatorCircuitBuilder::constructor(t)=57.9044M GoblinTranslatorProver=1 GoblinTranslatorProver(t)=125.368M GoblinTranslatorProver::construct_proof=1 GoblinTranslatorProver::construct_proof(t)=953.077M ProtoGalaxyProver_::accumulator_update_round=10 ProtoGalaxyProver_::accumulator_update_round(t)=725.626M ProtoGalaxyProver_::combiner_quotient_round=10 ProtoGalaxyProver_::combiner_quotient_round(t)=7.21724G ProtoGalaxyProver_::perturbator_round=10 ProtoGalaxyProver_::perturbator_round(t)=1.3161G ProtoGalaxyProver_::preparation_round=10 ProtoGalaxyProver_::preparation_round(t)=4.14482G ProtogalaxyProver::fold_instances=10 ProtogalaxyProver::fold_instances(t)=13.4038G ProverInstance(Circuit&)=11 ProverInstance(Circuit&)(t)=2.02586G batch_mul_with_endomorphism=30 batch_mul_with_endomorphism(t)=565.397M commit=425 commit(t)=4.00022G compute_combiner=10 compute_combiner(t)=7.21503G compute_perturbator=9 compute_perturbator(t)=1.31577G compute_univariate=48 compute_univariate(t)=1.4239G construct_circuits=6 construct_circuits(t)=4.4678G
Benchmarking lock deleted.
client_ivc_bench.json                                              100% 3995   110.5KB/s   00:00    
function                                        ms     % sum
construct_circuits(t)                         4468    18.89%
ProverInstance(Circuit&)(t)                   2026     8.57%
ProtogalaxyProver::fold_instances(t)         13404    56.68%
Decider::construct_proof(t)                    751     3.18%
ECCVMComposer::create_prover(t)                149     0.63%
ECCVMProver::construct_proof(t)               1763     7.45%
GoblinTranslatorProver::construct_proof(t)     953     4.03%
Goblin::merge(t)                               136     0.57%

Total time accounted for: 23650ms/23860ms = 99.12%

Major contributors:
function                                        ms    % sum
commit(t)                                     4000   16.91%
compute_combiner(t)                           7215   30.51%
compute_perturbator(t)                        1316    5.56%
compute_univariate(t)                         1424    6.02%

Breakdown of ProtogalaxyProver::fold_instances:
ProtoGalaxyProver_::preparation_round(t)           4145    30.92%
ProtoGalaxyProver_::perturbator_round(t)           1316     9.82%
ProtoGalaxyProver_::combiner_quotient_round(t)     7217    53.84%
ProtoGalaxyProver_::accumulator_update_round(t)     726     5.41%
```
  • Loading branch information
lucasxia01 authored Mar 22, 2024
1 parent 86a181b commit a83368c
Show file tree
Hide file tree
Showing 16 changed files with 384 additions and 249 deletions.
1 change: 0 additions & 1 deletion barretenberg/cpp/scripts/analyze_client_ivc_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"ProtogalaxyProver::fold_instances(t)",
"Decider::construct_proof(t)",
"ECCVMComposer::create_prover(t)",
"GoblinTranslatorComposer::create_prover(t)",
"ECCVMProver::construct_proof(t)",
"GoblinTranslatorProver::construct_proof(t)",
"Goblin::merge(t)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ BB_PROFILE static void test_round_inner(State& state, GoblinUltraProver& prover,
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { prover.oink_prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { prover.oink_prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION, [&] { prover.oink_prover.execute_grand_product_computation_round(); });
// we need to get the relation_parameters and prover_polynomials from the oink_prover
prover.instance->relation_parameters = prover.oink_prover.relation_parameters;
prover.instance->prover_polynomials = GoblinUltraFlavor::ProverPolynomials(prover.instance->proving_key);
time_if_index(RELATION_CHECK, [&] { prover.execute_relation_check_rounds(); });
time_if_index(ZEROMORPH, [&] { prover.execute_zeromorph_rounds(); });
}
Expand Down
143 changes: 143 additions & 0 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "barretenberg/honk/proof_system/types/proof.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/proof_system/circuit_builder/goblin_ultra_circuit_builder.hpp"
#include "barretenberg/proof_system/library/grand_product_delta.hpp"
#include "barretenberg/proof_system/library/grand_product_library.hpp"
#include "barretenberg/relations/auxiliary_relation.hpp"
#include "barretenberg/relations/databus_lookup_relation.hpp"
#include "barretenberg/relations/delta_range_constraint_relation.hpp"
Expand Down Expand Up @@ -266,6 +268,7 @@ class GoblinUltraFlavor {

std::vector<uint32_t> memory_read_records;
std::vector<uint32_t> memory_write_records;
std::array<Polynomial, 4> sorted_polynomials;

size_t num_ecc_op_gates; // needed to determine public input offset

Expand All @@ -276,6 +279,124 @@ class GoblinUltraFlavor {
};
// The plookup wires that store plookup read data.
auto get_table_column_wires() { return RefArray{ w_l, w_r, w_o }; };

/**
* @brief Construct sorted list accumulator polynomial 's'.
*
* @details Compute s = s_1 + η*s_2 + η²*s_3 + η³*s_4 (via Horner) where s_i are the
* sorted concatenated witness/table polynomials
*
* @param key proving key
* @param sorted_list_polynomials sorted concatenated witness/table polynomials
* @param eta random challenge
* @return Polynomial
*/
void compute_sorted_list_accumulator(const FF& eta)
{
const size_t circuit_size = this->circuit_size;

auto sorted_list_accumulator = Polynomial{ circuit_size };

// Construct s via Horner, i.e. s = s_1 + η(s_2 + η(s_3 + η*s_4))
for (size_t i = 0; i < circuit_size; ++i) {
FF T0 = this->sorted_polynomials[3][i];
T0 *= eta;
T0 += this->sorted_polynomials[2][i];
T0 *= eta;
T0 += this->sorted_polynomials[1][i];
T0 *= eta;
T0 += this->sorted_polynomials[0][i];
sorted_list_accumulator[i] = T0;
}
this->sorted_accum = sorted_list_accumulator.share();
}

void compute_sorted_accumulator_polynomials(const FF& eta)
{
// Compute sorted witness-table accumulator
this->compute_sorted_list_accumulator(eta);

// Finalize fourth wire polynomial by adding lookup memory records
add_plookup_memory_records_to_wire_4(eta);
}

/**
* @brief Add plookup memory records to the fourth wire polynomial
*
* @details This operation must be performed after the first three wires have been committed to, hence the
* dependence on the `eta` challenge.
*
* @tparam Flavor
* @param eta challenge produced after commitment to first three wire polynomials
*/
void add_plookup_memory_records_to_wire_4(const FF& eta)
{
// The plookup memory record values are computed at the indicated indices as
// w4 = w3 * eta^3 + w2 * eta^2 + w1 * eta + read_write_flag;
// (See plookup_auxiliary_widget.hpp for details)
auto wires = this->get_wires();

// Compute read record values
for (const auto& gate_idx : this->memory_read_records) {
wires[3][gate_idx] += wires[2][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[1][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[0][gate_idx];
wires[3][gate_idx] *= eta;
}

// Compute write record values
for (const auto& gate_idx : this->memory_write_records) {
wires[3][gate_idx] += wires[2][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[1][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[0][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += 1;
}
}

/**
* @brief Compute the inverse polynomial used in the log derivative lookup argument
*
* @tparam Flavor
* @param beta
* @param gamma
*/
void compute_logderivative_inverse(const RelationParameters<FF>& relation_parameters)
{
auto prover_polynomials = ProverPolynomials(*this);
// Compute permutation and lookup grand product polynomials
bb::compute_logderivative_inverse<GoblinUltraFlavor, typename GoblinUltraFlavor::LogDerivLookupRelation>(
prover_polynomials, relation_parameters, this->circuit_size);
this->lookup_inverses = prover_polynomials.lookup_inverses;
}

/**
* @brief Computes public_input_delta, lookup_grand_product_delta, the z_perm and z_lookup polynomials
*
* @param relation_parameters
*/
void compute_grand_product_polynomials(RelationParameters<FF>& relation_parameters)
{
auto public_input_delta = compute_public_input_delta<GoblinUltraFlavor>(this->public_inputs,
relation_parameters.beta,
relation_parameters.gamma,
this->circuit_size,
this->pub_inputs_offset);
relation_parameters.public_input_delta = public_input_delta;
auto lookup_grand_product_delta = compute_lookup_grand_product_delta(
relation_parameters.beta, relation_parameters.gamma, this->circuit_size);
relation_parameters.lookup_grand_product_delta = lookup_grand_product_delta;

// Compute permutation and lookup grand product polynomials
auto prover_polynomials = ProverPolynomials(*this);
compute_grand_products<GoblinUltraFlavor>(*this, prover_polynomials, relation_parameters);
this->z_perm = prover_polynomials.z_perm;
this->z_lookup = prover_polynomials.z_lookup;
}
};

/**
Expand Down Expand Up @@ -331,6 +452,28 @@ class GoblinUltraFlavor {
*/
class ProverPolynomials : public AllEntities<Polynomial> {
public:
ProverPolynomials(ProvingKey& proving_key)
{
for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key.get_all())) {
ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(proving_key, key_poly));
prover_poly = key_poly.share();
}
for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key.get_to_be_shifted())) {
ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(proving_key, key_poly) + "_shift"));
prover_poly = key_poly.shifted();
}
}
ProverPolynomials(std::shared_ptr<ProvingKey>& proving_key)
{
for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key->get_all())) {
ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(*proving_key, key_poly));
prover_poly = key_poly.share();
}
for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key->get_to_be_shifted())) {
ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(*proving_key, key_poly) + "_shift"));
prover_poly = key_poly.shifted();
}
}
// Define all operations as default, except move construction/assignment
ProverPolynomials() = default;
ProverPolynomials& operator=(const ProverPolynomials&) = delete;
Expand Down
128 changes: 128 additions & 0 deletions barretenberg/cpp/src/barretenberg/flavor/ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
#include "barretenberg/polynomials/polynomial.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/proof_system/circuit_builder/ultra_circuit_builder.hpp"
#include "barretenberg/proof_system/library/grand_product_delta.hpp"
#include "barretenberg/proof_system/library/grand_product_library.hpp"
#include "barretenberg/relations/auxiliary_relation.hpp"
#include "barretenberg/relations/delta_range_constraint_relation.hpp"
#include "barretenberg/relations/elliptic_relation.hpp"
#include "barretenberg/relations/lookup_relation.hpp"
#include "barretenberg/relations/permutation_relation.hpp"
#include "barretenberg/relations/relation_parameters.hpp"
#include "barretenberg/relations/ultra_arithmetic_relation.hpp"
#include "barretenberg/transcript/transcript.hpp"

Expand Down Expand Up @@ -272,6 +275,7 @@ class UltraFlavor {

std::vector<uint32_t> memory_read_records;
std::vector<uint32_t> memory_write_records;
std::array<Polynomial, 4> sorted_polynomials;

auto get_to_be_shifted()
{
Expand All @@ -280,6 +284,108 @@ class UltraFlavor {
};
// The plookup wires that store plookup read data.
auto get_table_column_wires() { return RefArray{ w_l, w_r, w_o }; };

/**
* @brief Construct sorted list accumulator polynomial 's'.
*
* @details Compute s = s_1 + η*s_2 + η²*s_3 + η³*s_4 (via Horner) where s_i are the
* sorted concatenated witness/table polynomials
*
* @param key proving key
* @param sorted_list_polynomials sorted concatenated witness/table polynomials
* @param eta random challenge
* @return Polynomial
*/
void compute_sorted_list_accumulator(const FF& eta)
{
const size_t circuit_size = this->circuit_size;

auto sorted_list_accumulator = Polynomial{ circuit_size };

// Construct s via Horner, i.e. s = s_1 + η(s_2 + η(s_3 + η*s_4))
for (size_t i = 0; i < circuit_size; ++i) {
FF T0 = this->sorted_polynomials[3][i];
T0 *= eta;
T0 += this->sorted_polynomials[2][i];
T0 *= eta;
T0 += this->sorted_polynomials[1][i];
T0 *= eta;
T0 += this->sorted_polynomials[0][i];
sorted_list_accumulator[i] = T0;
}
this->sorted_accum = sorted_list_accumulator.share();
}

void compute_sorted_accumulator_polynomials(const FF& eta)
{
// Compute sorted witness-table accumulator
this->compute_sorted_list_accumulator(eta);

// Finalize fourth wire polynomial by adding lookup memory records
add_plookup_memory_records_to_wire_4(eta);
}

/**
* @brief Add plookup memory records to the fourth wire polynomial
*
* @details This operation must be performed after the first three wires have been committed to, hence the
* dependence on the `eta` challenge.
*
* @tparam Flavor
* @param eta challenge produced after commitment to first three wire polynomials
*/
void add_plookup_memory_records_to_wire_4(const FF& eta)
{
// The plookup memory record values are computed at the indicated indices as
// w4 = w3 * eta^3 + w2 * eta^2 + w1 * eta + read_write_flag;
// (See plookup_auxiliary_widget.hpp for details)
auto wires = this->get_wires();

// Compute read record values
for (const auto& gate_idx : this->memory_read_records) {
wires[3][gate_idx] += wires[2][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[1][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[0][gate_idx];
wires[3][gate_idx] *= eta;
}

// Compute write record values
for (const auto& gate_idx : this->memory_write_records) {
wires[3][gate_idx] += wires[2][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[1][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += wires[0][gate_idx];
wires[3][gate_idx] *= eta;
wires[3][gate_idx] += 1;
}
}

/**
* @brief Computes public_input_delta, lookup_grand_product_delta, the z_perm and z_lookup polynomials
*
* @param relation_parameters
*/
void compute_grand_product_polynomials(RelationParameters<FF>& relation_parameters)
{
auto public_input_delta = compute_public_input_delta<UltraFlavor>(this->public_inputs,
relation_parameters.beta,
relation_parameters.gamma,
this->circuit_size,
this->pub_inputs_offset);
relation_parameters.public_input_delta = public_input_delta;
auto lookup_grand_product_delta = compute_lookup_grand_product_delta(
relation_parameters.beta, relation_parameters.gamma, this->circuit_size);
relation_parameters.lookup_grand_product_delta = lookup_grand_product_delta;

// Compute permutation and lookup grand product polynomials
auto prover_polynomials = ProverPolynomials(*this);
compute_grand_products<UltraFlavor>(*this, prover_polynomials, relation_parameters);
this->z_perm = prover_polynomials.z_perm;
this->z_lookup = prover_polynomials.z_lookup;
}
};

/**
Expand Down Expand Up @@ -307,6 +413,28 @@ class UltraFlavor {
*/
class ProverPolynomials : public AllEntities<Polynomial> {
public:
ProverPolynomials(ProvingKey& proving_key)
{
for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key.get_all())) {
ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(proving_key, key_poly));
prover_poly = key_poly.share();
}
for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key.get_to_be_shifted())) {
ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(proving_key, key_poly) + "_shift"));
prover_poly = key_poly.shifted();
}
}
ProverPolynomials(std::shared_ptr<ProvingKey>& proving_key)
{
for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key->get_all())) {
ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(*proving_key, key_poly));
prover_poly = key_poly.share();
}
for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key->get_to_be_shifted())) {
ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(*proving_key, key_poly) + "_shift"));
prover_poly = key_poly.shifted();
}
}
// Define all operations as default, except move construction/assignment
ProverPolynomials() = default;
ProverPolynomials& operator=(const ProverPolynomials&) = delete;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void compute_grand_product(const size_t circuit_size,
}

template <typename Flavor>
void compute_grand_products(std::shared_ptr<typename Flavor::ProvingKey>& key,
void compute_grand_products(const typename Flavor::ProvingKey& key,
typename Flavor::ProverPolynomials& full_polynomials,
bb::RelationParameters<typename Flavor::FF>& relation_parameters)
{
Expand All @@ -157,10 +157,10 @@ void compute_grand_products(std::shared_ptr<typename Flavor::ProvingKey>& key,
// For example, for UltraPermutationRelation, this will be `full_polynomials.z_perm`
// For example, for LookupRelation, this will be `full_polynomials.z_lookup`
bb::Polynomial<FF>& full_polynomial = GrandProdRelation::get_grand_product_polynomial(full_polynomials);
auto& key_polynomial = GrandProdRelation::get_grand_product_polynomial(*key);
auto& key_polynomial = GrandProdRelation::get_grand_product_polynomial(key);
full_polynomial = key_polynomial.share();

compute_grand_product<Flavor, GrandProdRelation>(key->circuit_size, full_polynomials, relation_parameters);
compute_grand_product<Flavor, GrandProdRelation>(key.circuit_size, full_polynomials, relation_parameters);
bb::Polynomial<FF>& full_polynomial_shift =
GrandProdRelation::get_shifted_grand_product_polynomial(full_polynomials);
full_polynomial_shift = key_polynomial.shifted();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,17 @@ template <typename Flavor> class ProtoGalaxyTests : public testing::Test {
construct_circuit(builder);

auto instance = std::make_shared<ProverInstance>(builder);
instance->initialize_prover_polynomials();

auto eta = FF::random_element();
auto beta = FF::random_element();
auto gamma = FF::random_element();
instance->compute_sorted_accumulator_polynomials(eta);
instance->relation_parameters.eta = FF::random_element();
instance->relation_parameters.beta = FF::random_element();
instance->relation_parameters.gamma = FF::random_element();

instance->proving_key->compute_sorted_accumulator_polynomials(instance->relation_parameters.eta);
if constexpr (IsGoblinFlavor<Flavor>) {
instance->compute_logderivative_inverse(beta, gamma);
instance->proving_key->compute_logderivative_inverse(instance->relation_parameters);
}
instance->compute_grand_product_polynomials(beta, gamma);
instance->proving_key->compute_grand_product_polynomials(instance->relation_parameters);
instance->prover_polynomials = ProverPolynomials(instance->proving_key);

for (auto& alpha : instance->alphas) {
alpha = FF::random_element();
Expand Down
Loading

0 comments on commit a83368c

Please sign in to comment.