Skip to content

Commit

Permalink
refactor(bb): pointer_view to reference-based get_all (#3495)
Browse files Browse the repository at this point in the history
Originally, pointer_view was a compromise as we didn't have a good way
of making an iterable structure of references. Now with RefVector, we
can move away from pointer_view

Co-authored-by: ludamad <adam@aztecprotocol.com>
  • Loading branch information
ludamad and ludamad0 authored Dec 1, 2023
1 parent d889359 commit 50d7327
Show file tree
Hide file tree
Showing 28 changed files with 152 additions and 206 deletions.
13 changes: 5 additions & 8 deletions barretenberg/cpp/src/barretenberg/common/ref_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,11 @@ template <typename T, std::size_t N> class RefArray {
storage[i++] = &elem;
}
}
RefArray(std::initializer_list<T&> init)
template <typename... Ts> RefArray(T& ref, Ts&... rest)
{
if (init.size() != N) {
throw std::invalid_argument("Initializer list size does not match RefArray size");
}
std::size_t i = 0;
for (auto& elem : init) {
storage[i++] = &elem;
}
storage[0] = &ref;
int i = 1;
((storage[i++] = &rest), ...);
}

T& operator[](std::size_t idx) const
Expand Down Expand Up @@ -82,6 +78,7 @@ template <typename T, std::size_t N> class RefArray {
std::size_t pos;
};

constexpr std::size_t size() const { return N; }
/**
* @brief Returns an iterator to the beginning of the RefArray.
*
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/common/ref_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ template <typename T> class RefVector {
std::size_t pos;
};

[[nodiscard]] std::size_t size() const { return storage.size(); }
std::size_t size() const { return storage.size(); }

void push_back(T& element) { storage.push_back(element); }
iterator begin() const { return iterator(this, 0); }
Expand Down
10 changes: 5 additions & 5 deletions barretenberg/cpp/src/barretenberg/common/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ template <typename B, typename T> inline void read(B& it, std::optional<T>& opt_
}

template <typename T>
concept HasPointerView = requires(T t) { t.pointer_view(); };
concept HasGetAll = requires(T t) { t.get_all(); };

// Write out a struct that defines pointer_view()
template <typename B, HasPointerView T> inline void write(B& buf, T const& value)
// Write out a struct that defines get_all()
template <typename B, HasGetAll T> inline void write(B& buf, T const& value)
{
using serialize::write;
for (auto* pointer : value.pointer_view()) {
write(buf, *pointer);
for (auto& reference : value.get_all()) {
write(buf, reference);
}
}

Expand Down
13 changes: 6 additions & 7 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_prover.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "eccvm_prover.hpp"
#include "barretenberg/commitment_schemes/claim.hpp"
#include "barretenberg/commitment_schemes/commitment_key.hpp"
#include "barretenberg/common/ref_array.hpp"
#include "barretenberg/honk/proof_system/lookup_library.hpp"
#include "barretenberg/honk/proof_system/permutation_library.hpp"
#include "barretenberg/honk/proof_system/power_polynomial.hpp"
Expand Down Expand Up @@ -345,18 +346,16 @@ template <ECCVMFlavor Flavor> void ECCVMProver_<Flavor>::execute_transcript_cons
FF batching_challenge = transcript.get_challenge("Translation:batching_challenge");

// Collect the polynomials and evaluations to be batched
const size_t NUM_UNIVARIATES = 6; // 5 transcript polynomials plus the constant hack poly
std::array<Polynomial*, NUM_UNIVARIATES> univariate_polynomials = { &key->transcript_op, &key->transcript_Px,
&key->transcript_Py, &key->transcript_z1,
&key->transcript_z2, &hack };
std::array<FF, NUM_UNIVARIATES> univariate_evaluations;
RefArray univariate_polynomials{ key->transcript_op, key->transcript_Px, key->transcript_Py,
key->transcript_z1, key->transcript_z2, hack };
std::array<FF, univariate_polynomials.size()> univariate_evaluations;

// Constuct the batched polynomial and batched evaluation
Polynomial batched_univariate{ key->circuit_size };
FF batched_evaluation{ 0 };
auto batching_scalar = FF(1);
for (auto [eval, polynomial] : zip_view(univariate_evaluations, univariate_polynomials)) {
batched_univariate.add_scaled(*polynomial, batching_scalar);
for (auto [polynomial, eval] : zip_view(univariate_polynomials, univariate_evaluations)) {
batched_univariate.add_scaled(polynomial, batching_scalar);
batched_evaluation += eval * batching_scalar;
batching_scalar *= batching_challenge;
}
Expand Down
17 changes: 6 additions & 11 deletions barretenberg/cpp/src/barretenberg/flavor/ecc_vm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
class WitnessEntities : public WireEntities<DataType>, public DerivedWitnessEntities<DataType> {
public:
DEFINE_COMPOUND_GET_ALL(WireEntities<DataType>::get_all(), DerivedWitnessEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(WireEntities<DataType>::pointer_view(),
DerivedWitnessEntities<DataType>::pointer_view())
RefVector<DataType> get_wires() { return WireEntities<DataType>::get_all(); };
// The sorted concatenations of table and witness data needed for plookup.
RefVector<DataType> get_sorted_polynomials() { return {}; };
Expand Down Expand Up @@ -268,9 +266,6 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
DEFINE_COMPOUND_GET_ALL(PrecomputedEntities<DataType>::get_all(),
WitnessEntities<DataType>::get_all(),
ShiftedEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(PrecomputedEntities<DataType>::pointer_view(),
WitnessEntities<DataType>::pointer_view(),
ShiftedEntities<DataType>::pointer_view())
// Gemini-specific getters.
RefVector<DataType> get_unshifted()
{
Expand Down Expand Up @@ -369,8 +364,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -391,8 +386,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand All @@ -419,8 +414,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
AllValues get_row(const size_t row_idx)
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand Down
10 changes: 5 additions & 5 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ProvingKey_ : public PrecomputedPolynomials, public WitnessPolynomials {
std::vector<uint32_t> recursive_proof_public_input_indices;
barretenberg::EvaluationDomain<FF> evaluation_domain;

auto precomputed_polynomials_pointer_view() { return PrecomputedPolynomials::pointer_view(); }
auto precomputed_polynomials_get_all() { return PrecomputedPolynomials::get_all(); }
ProvingKey_() = default;
ProvingKey_(const size_t circuit_size, const size_t num_public_inputs)
{
Expand All @@ -114,12 +114,12 @@ class ProvingKey_ : public PrecomputedPolynomials, public WitnessPolynomials {
this->log_circuit_size = numeric::get_msb(circuit_size);
this->num_public_inputs = num_public_inputs;
// Allocate memory for precomputed polynomials
for (auto* poly : PrecomputedPolynomials::pointer_view()) {
*poly = Polynomial(circuit_size);
for (auto& poly : PrecomputedPolynomials::get_all()) {
poly = Polynomial(circuit_size);
}
// Allocate memory for witness polynomials
for (auto* poly : WitnessPolynomials::pointer_view()) {
*poly = Polynomial(circuit_size);
for (auto& poly : WitnessPolynomials::get_all()) {
poly = Polynomial(circuit_size);
}
};
};
Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ TEST(Flavor, GetRow)
return std::vector<FF>({ FF::random_element(), FF::random_element() });
});
Flavor::ProverPolynomials prover_polynomials;
for (auto [poly, entry] : zip_view(prover_polynomials.pointer_view(), data)) {
*poly = entry;
for (auto [poly, entry] : zip_view(prover_polynomials.get_all(), data)) {
poly = entry;
}
auto row0 = prover_polynomials.get_row(0);
auto row1 = prover_polynomials.get_row(1);
Expand Down
24 changes: 1 addition & 23 deletions barretenberg/cpp/src/barretenberg/flavor/flavor_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Macros for defining the flavor classes.
// These are used to derive iterator methods along with the body of a 'flavor' class.
// DEFINE_FLAVOR_MEMBERS lets you define a flavor entity as a collection of individual members, and derive an iterator.
// while DEFINE_COMPOUND_GET_ALL and DEFINE_COMPOUND_POINTER_VIEW let you combine the iterators of substructures or base
// while DEFINE_COMPOUND_GET_ALL lets you combine the iterators of substructures or base
// classes.

#include "barretenberg/common/ref_vector.hpp"
Expand All @@ -17,17 +17,6 @@ template <typename... Refs> auto _refs_to_pointer_array(Refs&... refs)
return std::array{ &refs... };
}

// @deprecated this was less natural than the ref view
#define DEFINE_POINTER_VIEW(...) \
[[nodiscard]] auto pointer_view() \
{ \
return _refs_to_pointer_array(__VA_ARGS__); \
} \
[[nodiscard]] auto pointer_view() const \
{ \
return _refs_to_pointer_array(__VA_ARGS__); \
}

#define DEFINE_REF_VIEW(...) \
[[nodiscard]] auto get_all() \
{ \
Expand All @@ -47,19 +36,8 @@ template <typename... Refs> auto _refs_to_pointer_array(Refs&... refs)
*/
#define DEFINE_FLAVOR_MEMBERS(DataType, ...) \
DataType __VA_ARGS__; \
DEFINE_POINTER_VIEW(__VA_ARGS__) \
DEFINE_REF_VIEW(__VA_ARGS__)

#define DEFINE_COMPOUND_POINTER_VIEW(...) \
[[nodiscard]] auto pointer_view() \
{ \
return concatenate(__VA_ARGS__); \
} \
[[nodiscard]] auto pointer_view() const \
{ \
return concatenate(__VA_ARGS__); \
}

#define DEFINE_COMPOUND_GET_ALL(...) \
[[nodiscard]] auto get_all() \
{ \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ class AvmMiniFlavor {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -271,8 +271,8 @@ class AvmMiniFlavor {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class FibFlavor {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -179,8 +179,8 @@ class FibFlavor {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
19 changes: 6 additions & 13 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_translator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,6 @@ class GoblinTranslator {
WireToBeShiftedEntities<DataType>::get_all(),
DerivedWitnessEntities<DataType>::get_all(),
ConcatenatedRangeConstraints<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(WireNonshiftedEntities<DataType>::pointer_view(),
WireToBeShiftedEntities<DataType>::pointer_view(),
DerivedWitnessEntities<DataType>::pointer_view(),
ConcatenatedRangeConstraints<DataType>::pointer_view())

RefVector<DataType> get_wires()
{
Expand Down Expand Up @@ -654,9 +650,6 @@ class GoblinTranslator {
DEFINE_COMPOUND_GET_ALL(PrecomputedEntities<DataType>::get_all(),
WitnessEntities<DataType>::get_all(),
ShiftedEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(PrecomputedEntities<DataType>::pointer_view(),
WitnessEntities<DataType>::pointer_view(),
ShiftedEntities<DataType>::pointer_view())
/**
* @brief Get the polynomials that are concatenated for the permutation relation
*
Expand Down Expand Up @@ -1036,8 +1029,8 @@ class GoblinTranslator {
[[nodiscard]] AllValues get_row(size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -1064,8 +1057,8 @@ class GoblinTranslator {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -1086,8 +1079,8 @@ class GoblinTranslator {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ class GoblinUltra {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down Expand Up @@ -363,8 +363,8 @@ class GoblinUltra {
[[nodiscard]] AllValues get_row(size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/flavor/ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ class Ultra {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -283,8 +283,8 @@ class Ultra {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void compute_permutation_grand_product(const size_t circuit_size,
for (size_t i = start; i < end; ++i) {

typename Flavor::AllValues evaluations;
for (auto [eval, poly] : zip_view(evaluations.pointer_view(), full_polynomials.pointer_view())) {
*eval = poly->size() > i ? (*poly)[i] : 0;
for (auto [eval, poly] : zip_view(evaluations.get_all(), full_polynomials.get_all())) {
eval = poly.size() > i ? poly[i] : 0;
}
numerator[i] = GrandProdRelation::template compute_permutation_numerator<Accumulator>(evaluations,
relation_parameters);
Expand Down
Loading

0 comments on commit 50d7327

Please sign in to comment.