Skip to content

Commit

Permalink
chore: Oink takes directly populates an instance (#8170)
Browse files Browse the repository at this point in the history
Oink can be thought of as an "instance completer", i.e. when it is done
running all of the data that comprises an instance has been created. Up
until now the model was to pass oink a reference to a proving key. It
would "complete" the proving key in place by populating some witness
polynomials then explicitly return the rest of the data comprising an
instance (relation_parameters etc.) in a custom struct like
`OinkOutput`. The data from this output would then be std::move'd into
an instance existing in the external scope.

This PR simplifies this model by simply passing oink an instance
(ProverInstance or VerifierInstance) which is "completed" in place
throughout oink. IMO this is cleaner and clearer than the half-and-half
approach of completing the proving key in place and explicitly returning
other data. It also removes a ton of boilerplate for moving data in and
out of an instance. I don't love the "input parameter treated as output
parameter approach" but unless we refactor Honk/PG to construct
proving_key instead of an instance, I think this is preferred. (In that
case oink could take a proving_key and return a completed instance).
  • Loading branch information
ledwards2225 authored Aug 23, 2024
1 parent 7f95ee7 commit 6e46b45
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ BB_PROFILE static void test_round_inner(State& state, MegaProver& prover, size_t
BB_REPORT_OP_COUNT_BENCH_CANCEL();
}
};
OinkProver<MegaFlavor> oink_prover(prover.instance->proving_key, prover.transcript);
OinkProver<MegaFlavor> oink_prover(prover.instance, prover.transcript);
time_if_index(PREAMBLE, [&] { oink_prover.execute_preamble_round(); });
time_if_index(WIRE_COMMITMENTS, [&] { oink_prover.execute_wire_commitments_round(); });
time_if_index(SORTED_LIST_ACCUMULATOR, [&] { oink_prover.execute_sorted_list_accumulator_round(); });
time_if_index(LOG_DERIVATIVE_INVERSE, [&] { oink_prover.execute_log_derivative_inverse_round(); });
time_if_index(GRAND_PRODUCT_COMPUTATION, [&] { oink_prover.execute_grand_product_computation_round(); });
time_if_index(GENERATE_ALPHAS, [&] { prover.instance->alphas = oink_prover.generate_alphas_round(); });
// we need to get the relation_parameters and prover_polynomials from the oink_prover
prover.instance->relation_parameters = oink_prover.relation_parameters;

prover.generate_gate_challenges();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,9 @@ void ProtoGalaxyProver_<ProverInstances>::finalise_and_send_instance(std::shared
const std::string& domain_separator)
{
ZoneScopedN("ProtoGalaxyProver::finalise_and_send_instance");
OinkProver<Flavor> oink_prover(instance->proving_key, transcript, domain_separator + '_');
OinkProver<Flavor> oink_prover(instance, transcript, domain_separator + '_');

auto [proving_key, relation_params, alphas] = oink_prover.prove();
instance->proving_key = std::move(proving_key);
instance->relation_parameters = std::move(relation_params);
instance->alphas = std::move(alphas);
oink_prover.prove();
}

template <class ProverInstances> void ProtoGalaxyProver_<ProverInstances>::prepare_for_folding()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@ template <class VerifierInstances>
void ProtoGalaxyVerifier_<VerifierInstances>::receive_and_finalise_instance(const std::shared_ptr<Instance>& inst,
const std::string& domain_separator)
{
auto& key = inst->verification_key;
OinkVerifier<Flavor> oink_verifier{ key, transcript, domain_separator + '_' };
auto [relation_parameters, witness_commitments, public_inputs, alphas] = oink_verifier.verify();
inst->relation_parameters = std::move(relation_parameters);
inst->witness_commitments = std::move(witness_commitments);
inst->public_inputs = std::move(public_inputs);
inst->alphas = std::move(alphas);
OinkVerifier<Flavor> oink_verifier{ inst, transcript, domain_separator + '_' };
oink_verifier.verify();
}

template <class VerifierInstances>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace bb::stdlib::recursion::honk {

template <typename Flavor>
OinkRecursiveVerifier_<Flavor>::OinkRecursiveVerifier_(Builder* builder,
const std::shared_ptr<VerificationKey>& vkey,
const std::shared_ptr<Instance>& instance,
std::shared_ptr<Transcript> transcript,
std::string domain_separator)
: key(vkey)
: instance(instance)
, builder(builder)
, transcript(transcript)
, domain_separator(std::move(domain_separator))
Expand All @@ -22,12 +22,10 @@ OinkRecursiveVerifier_<Flavor>::OinkRecursiveVerifier_(Builder* builder,
* @brief This function constructs a recursive verifier circuit for a native Ultra Honk proof of a given flavor.
* @return Output aggregation object
*/
template <typename Flavor> OinkRecursiveVerifier_<Flavor>::Output OinkRecursiveVerifier_<Flavor>::verify()
template <typename Flavor> void OinkRecursiveVerifier_<Flavor>::verify()
{
using CommitmentLabels = typename Flavor::CommitmentLabels;
using RelationParams = ::bb::RelationParameters<FF>;

RelationParams relation_parameters;
WitnessCommitments commitments;
CommitmentLabels labels;

Expand All @@ -42,7 +40,7 @@ template <typename Flavor> OinkRecursiveVerifier_<Flavor>::Output OinkRecursiveV
// ASSERT(static_cast<uint32_t>(pub_inputs_offset.get_value()) == key->pub_inputs_offset);

std::vector<FF> public_inputs;
for (size_t i = 0; i < key->num_public_inputs; ++i) {
for (size_t i = 0; i < instance->verification_key->num_public_inputs; ++i) {
public_inputs.emplace_back(
transcript->template receive_from_prover<FF>(domain_separator + "public_input_" + std::to_string(i)));
}
Expand Down Expand Up @@ -90,9 +88,7 @@ template <typename Flavor> OinkRecursiveVerifier_<Flavor>::Output OinkRecursiveV
}

const FF public_input_delta = compute_public_input_delta<Flavor>(
public_inputs, beta, gamma, circuit_size, static_cast<uint32_t>(key->pub_inputs_offset));

relation_parameters = RelationParameters<FF>{ eta, eta_two, eta_three, beta, gamma, public_input_delta };
public_inputs, beta, gamma, circuit_size, static_cast<uint32_t>(instance->verification_key->pub_inputs_offset));

// Get commitment to permutation and lookup grand products
commitments.z_perm = transcript->template receive_from_prover<Commitment>(domain_separator + labels.z_perm);
Expand All @@ -102,10 +98,10 @@ template <typename Flavor> OinkRecursiveVerifier_<Flavor>::Output OinkRecursiveV
alphas[idx] = transcript->template get_challenge<FF>(domain_separator + "alpha_" + std::to_string(idx));
}

return { .relation_parameters = relation_parameters,
.commitments = std::move(commitments),
.public_inputs = public_inputs,
.alphas = alphas };
instance->relation_parameters = RelationParameters<FF>{ eta, eta_two, eta_three, beta, gamma, public_input_delta };
instance->witness_commitments = std::move(commitments);
instance->public_inputs = std::move(public_inputs);
instance->alphas = std::move(alphas);
}

template class OinkRecursiveVerifier_<bb::UltraRecursiveFlavor_<UltraCircuitBuilder>>;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include "barretenberg/stdlib/protogalaxy_verifier/recursive_verifier_instance.hpp"
#include "barretenberg/stdlib/transcript/transcript.hpp"
#include "barretenberg/stdlib_circuit_builders/mega_recursive_flavor.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp"
Expand All @@ -10,27 +11,21 @@ template <typename Flavor> class OinkRecursiveVerifier_ {
using FF = typename Flavor::FF;
using Commitment = typename Flavor::Commitment;
using GroupElement = typename Flavor::GroupElement;
using Instance = RecursiveVerifierInstance_<Flavor>;
using VerificationKey = typename Flavor::VerificationKey;
using Builder = typename Flavor::CircuitBuilder;
using RelationSeparator = typename Flavor::RelationSeparator;
using Transcript = bb::BaseTranscript<bb::stdlib::recursion::honk::StdlibTranscriptParams<Builder>>;
using WitnessCommitments = typename Flavor::WitnessCommitments;

struct Output {
bb::RelationParameters<typename Flavor::FF> relation_parameters;
WitnessCommitments commitments;
std::vector<typename Flavor::FF> public_inputs;
typename Flavor::RelationSeparator alphas;
};

explicit OinkRecursiveVerifier_(Builder* builder,
const std::shared_ptr<VerificationKey>& vkey,
const std::shared_ptr<Instance>& instance,
std::shared_ptr<Transcript> transcript,
std::string domain_separator = "");

Output verify();
void verify();

std::shared_ptr<VerificationKey> key;
std::shared_ptr<Instance> instance;
Builder* builder;
std::shared_ptr<Transcript> transcript;
std::string domain_separator; // used in PG to distinguish between instances in transcript
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ UltraRecursiveVerifier_<Flavor>::AggregationObject UltraRecursiveVerifier_<Flavo
using Transcript = typename Flavor::Transcript;

transcript = std::make_shared<Transcript>(proof);
OinkVerifier oink_verifier{ builder, key, transcript };
auto [relation_parameters, witness_commitments, public_inputs, alphas] = oink_verifier.verify();
auto instance = std::make_shared<Instance>(builder, key);
OinkVerifier oink_verifier{ builder, instance, transcript };
oink_verifier.verify();

VerifierCommitments commitments{ key, witness_commitments };
VerifierCommitments commitments{ key, instance->witness_commitments };

auto gate_challenges = std::vector<FF>(CONST_PROOF_SIZE_LOG_N);
for (size_t idx = 0; idx < CONST_PROOF_SIZE_LOG_N; idx++) {
Expand All @@ -66,7 +67,7 @@ UltraRecursiveVerifier_<Flavor>::AggregationObject UltraRecursiveVerifier_<Flavo
for (size_t j = 0; j < 2; j++) {
std::array<FF, 4> bigfield_limbs;
for (size_t k = 0; k < 4; k++) {
bigfield_limbs[k] = public_inputs[key->recursive_proof_public_input_indices[idx]];
bigfield_limbs[k] = instance->public_inputs[key->recursive_proof_public_input_indices[idx]];
idx++;
}
base_field_vals[j] =
Expand All @@ -88,7 +89,7 @@ UltraRecursiveVerifier_<Flavor>::AggregationObject UltraRecursiveVerifier_<Flavo
auto sumcheck = Sumcheck(log_circuit_size, transcript);

auto [multivariate_challenge, claimed_evaluations, sumcheck_verified] =
sumcheck.verify(relation_parameters, alphas, gate_challenges);
sumcheck.verify(instance->relation_parameters, instance->alphas, gate_challenges);

// Execute ZeroMorph to produce an opening claim subsequently verified by a univariate PCS
auto opening_claim = ZeroMorph::verify(key->circuit_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ template <typename Flavor> class UltraRecursiveVerifier_ {
using FF = typename Flavor::FF;
using Commitment = typename Flavor::Commitment;
using GroupElement = typename Flavor::GroupElement;
using Instance = RecursiveVerifierInstance_<Flavor>;
using VerificationKey = typename Flavor::VerificationKey;
using NativeVerificationKey = typename Flavor::NativeVerificationKey;
using VerifierCommitmentKey = typename Flavor::VerifierCommitmentKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
const std::shared_ptr<Instance>& inst, std::string& domain_separator)
{
domain_separator = domain_separator + "_";
OinkVerifier oink_verifier{ builder, inst->verification_key, transcript, domain_separator };
auto [relation_parameters, witness_commitments, public_inputs, alphas] = oink_verifier.verify();
inst->relation_parameters = std::move(relation_parameters);
inst->witness_commitments = std::move(witness_commitments);
inst->public_inputs = std::move(public_inputs);
inst->alphas = std::move(alphas);
OinkVerifier oink_verifier{ builder, inst, transcript, domain_separator };
oink_verifier.verify();
}

// TODO(https://github.com/AztecProtocol/barretenberg/issues/795): The rounds prior to actual verifying are common
Expand Down
63 changes: 29 additions & 34 deletions barretenberg/cpp/src/barretenberg/ultra_honk/oink_prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace bb {
* @tparam Flavor
* @return OinkProverOutput<Flavor>
*/
template <IsUltraFlavor Flavor> OinkProverOutput<Flavor> OinkProver<Flavor>::prove()
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::prove()
{
{
ZoneScopedN("execute_preamble_round");
Expand Down Expand Up @@ -41,13 +41,7 @@ template <IsUltraFlavor Flavor> OinkProverOutput<Flavor> OinkProver<Flavor>::pro
}

// Generate relation separators alphas for sumcheck/combiner computation
RelationSeparator alphas = generate_alphas_round();

return OinkProverOutput<Flavor>{
.proving_key = std::move(proving_key),
.relation_parameters = std::move(relation_parameters),
.alphas = std::move(alphas),
};
instance->alphas = generate_alphas_round();
}

/**
Expand All @@ -56,17 +50,17 @@ template <IsUltraFlavor Flavor> OinkProverOutput<Flavor> OinkProver<Flavor>::pro
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_preamble_round()
{
const auto circuit_size = static_cast<uint32_t>(proving_key.circuit_size);
const auto num_public_inputs = static_cast<uint32_t>(proving_key.num_public_inputs);
const auto circuit_size = static_cast<uint32_t>(instance->proving_key.circuit_size);
const auto num_public_inputs = static_cast<uint32_t>(instance->proving_key.num_public_inputs);
transcript->send_to_verifier(domain_separator + "circuit_size", circuit_size);
transcript->send_to_verifier(domain_separator + "public_input_size", num_public_inputs);
transcript->send_to_verifier(domain_separator + "pub_inputs_offset",
static_cast<uint32_t>(proving_key.pub_inputs_offset));
static_cast<uint32_t>(instance->proving_key.pub_inputs_offset));

ASSERT(proving_key.num_public_inputs == proving_key.public_inputs.size());
ASSERT(instance->proving_key.num_public_inputs == instance->proving_key.public_inputs.size());

for (size_t i = 0; i < proving_key.num_public_inputs; ++i) {
auto public_input_i = proving_key.public_inputs[i];
for (size_t i = 0; i < instance->proving_key.num_public_inputs; ++i) {
auto public_input_i = instance->proving_key.public_inputs[i];
transcript->send_to_verifier(domain_separator + "public_input_" + std::to_string(i), public_input_i);
}
}
Expand All @@ -82,9 +76,9 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_wire_commitment
// We only commit to the fourth wire polynomial after adding memory recordss
{
BB_OP_COUNT_TIME_NAME("COMMIT::wires");
witness_commitments.w_l = commitment_key->commit(proving_key.polynomials.w_l);
witness_commitments.w_r = commitment_key->commit(proving_key.polynomials.w_r);
witness_commitments.w_o = commitment_key->commit(proving_key.polynomials.w_o);
witness_commitments.w_l = commitment_key->commit(instance->proving_key.polynomials.w_l);
witness_commitments.w_r = commitment_key->commit(instance->proving_key.polynomials.w_r);
witness_commitments.w_o = commitment_key->commit(instance->proving_key.polynomials.w_o);
}

auto wire_comms = witness_commitments.get_wires();
Expand All @@ -97,7 +91,7 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_wire_commitment

// Commit to Goblin ECC op wires
for (auto [commitment, polynomial, label] : zip_view(witness_commitments.get_ecc_op_wires(),
proving_key.polynomials.get_ecc_op_wires(),
instance->proving_key.polynomials.get_ecc_op_wires(),
commitment_labels.get_ecc_op_wires())) {
{
BB_OP_COUNT_TIME_NAME("COMMIT::ecc_op_wires");
Expand All @@ -108,7 +102,7 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_wire_commitment

// Commit to DataBus related polynomials
for (auto [commitment, polynomial, label] : zip_view(witness_commitments.get_databus_entities(),
proving_key.polynomials.get_databus_entities(),
instance->proving_key.polynomials.get_databus_entities(),
commitment_labels.get_databus_entities())) {
{
BB_OP_COUNT_TIME_NAME("COMMIT::databus");
Expand All @@ -128,22 +122,23 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_sorted_list_acc
// Get eta challenges
auto [eta, eta_two, eta_three] = transcript->template get_challenges<FF>(
domain_separator + "eta", domain_separator + "eta_two", domain_separator + "eta_three");
relation_parameters.eta = eta;
relation_parameters.eta_two = eta_two;
relation_parameters.eta_three = eta_three;
instance->relation_parameters.eta = eta;
instance->relation_parameters.eta_two = eta_two;
instance->relation_parameters.eta_three = eta_three;

proving_key.add_ram_rom_memory_records_to_wire_4(
relation_parameters.eta, relation_parameters.eta_two, relation_parameters.eta_three);
instance->proving_key.add_ram_rom_memory_records_to_wire_4(eta, eta_two, eta_three);

// Commit to lookup argument polynomials and the finalized (i.e. with memory records) fourth wire polynomial
{
BB_OP_COUNT_TIME_NAME("COMMIT::lookup_counts_tags");
witness_commitments.lookup_read_counts = commitment_key->commit(proving_key.polynomials.lookup_read_counts);
witness_commitments.lookup_read_tags = commitment_key->commit(proving_key.polynomials.lookup_read_tags);
witness_commitments.lookup_read_counts =
commitment_key->commit(instance->proving_key.polynomials.lookup_read_counts);
witness_commitments.lookup_read_tags =
commitment_key->commit(instance->proving_key.polynomials.lookup_read_tags);
}
{
BB_OP_COUNT_TIME_NAME("COMMIT::wires");
witness_commitments.w_4 = commitment_key->commit(proving_key.polynomials.w_4);
witness_commitments.w_4 = commitment_key->commit(instance->proving_key.polynomials.w_4);
}

transcript->send_to_verifier(domain_separator + commitment_labels.lookup_read_counts,
Expand All @@ -160,23 +155,23 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_sorted_list_acc
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_log_derivative_inverse_round()
{
auto [beta, gamma] = transcript->template get_challenges<FF>(domain_separator + "beta", domain_separator + "gamma");
relation_parameters.beta = beta;
relation_parameters.gamma = gamma;
instance->relation_parameters.beta = beta;
instance->relation_parameters.gamma = gamma;

// Compute the inverses used in log-derivative lookup relations
proving_key.compute_logderivative_inverses(relation_parameters);
instance->proving_key.compute_logderivative_inverses(instance->relation_parameters);

{
BB_OP_COUNT_TIME_NAME("COMMIT::lookup_inverses");
witness_commitments.lookup_inverses = commitment_key->commit(proving_key.polynomials.lookup_inverses);
witness_commitments.lookup_inverses = commitment_key->commit(instance->proving_key.polynomials.lookup_inverses);
}
transcript->send_to_verifier(domain_separator + commitment_labels.lookup_inverses,
witness_commitments.lookup_inverses);

// If Mega, commit to the databus inverse polynomials and send
if constexpr (IsGoblinFlavor<Flavor>) {
for (auto [commitment, polynomial, label] : zip_view(witness_commitments.get_databus_inverses(),
proving_key.polynomials.get_databus_inverses(),
instance->proving_key.polynomials.get_databus_inverses(),
commitment_labels.get_databus_inverses())) {
{
BB_OP_COUNT_TIME_NAME("COMMIT::databus_inverses");
Expand All @@ -193,11 +188,11 @@ template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_log_derivative_
*/
template <IsUltraFlavor Flavor> void OinkProver<Flavor>::execute_grand_product_computation_round()
{
proving_key.compute_grand_product_polynomials(relation_parameters);
instance->proving_key.compute_grand_product_polynomials(instance->relation_parameters);

{
BB_OP_COUNT_TIME_NAME("COMMIT::z_perm");
witness_commitments.z_perm = commitment_key->commit(proving_key.polynomials.z_perm);
witness_commitments.z_perm = commitment_key->commit(instance->proving_key.polynomials.z_perm);
}
transcript->send_to_verifier(domain_separator + commitment_labels.z_perm, witness_commitments.z_perm);
}
Expand Down
Loading

0 comments on commit 6e46b45

Please sign in to comment.