Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(bb): pointer_view to reference-based get_all #3495

Merged
merged 3 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions barretenberg/cpp/src/barretenberg/common/ref_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,11 @@ template <typename T, std::size_t N> class RefArray {
storage[i++] = &elem;
}
}
RefArray(std::initializer_list<T&> init)
template <typename... Ts> RefArray(T& ref, Ts&... rest)
{
if (init.size() != N) {
throw std::invalid_argument("Initializer list size does not match RefArray size");
}
std::size_t i = 0;
for (auto& elem : init) {
storage[i++] = &elem;
}
storage[0] = &ref;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't actually working before, can't have an initializer list of references for the same reason you can't have a vector of references

int i = 1;
((storage[i++] = &rest), ...);
}

T& operator[](std::size_t idx) const
Expand Down Expand Up @@ -82,6 +78,7 @@ template <typename T, std::size_t N> class RefArray {
std::size_t pos;
};

constexpr std::size_t size() const { return N; }
/**
* @brief Returns an iterator to the beginning of the RefArray.
*
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/common/ref_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ template <typename T> class RefVector {
std::size_t pos;
};

[[nodiscard]] std::size_t size() const { return storage.size(); }
std::size_t size() const { return storage.size(); }

void push_back(T& element) { storage.push_back(element); }
iterator begin() const { return iterator(this, 0); }
Expand Down
10 changes: 5 additions & 5 deletions barretenberg/cpp/src/barretenberg/common/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ template <typename B, typename T> inline void read(B& it, std::optional<T>& opt_
}

template <typename T>
concept HasPointerView = requires(T t) { t.pointer_view(); };
concept HasGetAll = requires(T t) { t.get_all(); };

// Write out a struct that defines pointer_view()
template <typename B, HasPointerView T> inline void write(B& buf, T const& value)
// Write out a struct that defines get_all()
template <typename B, HasGetAll T> inline void write(B& buf, T const& value)
{
using serialize::write;
for (auto* pointer : value.pointer_view()) {
write(buf, *pointer);
for (auto& reference : value.get_all()) {
write(buf, reference);
}
}

Expand Down
13 changes: 6 additions & 7 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_prover.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "eccvm_prover.hpp"
#include "barretenberg/commitment_schemes/claim.hpp"
#include "barretenberg/commitment_schemes/commitment_key.hpp"
#include "barretenberg/common/ref_array.hpp"
#include "barretenberg/honk/proof_system/lookup_library.hpp"
#include "barretenberg/honk/proof_system/permutation_library.hpp"
#include "barretenberg/honk/proof_system/power_polynomial.hpp"
Expand Down Expand Up @@ -345,18 +346,16 @@ template <ECCVMFlavor Flavor> void ECCVMProver_<Flavor>::execute_transcript_cons
FF batching_challenge = transcript.get_challenge("Translation:batching_challenge");

// Collect the polynomials and evaluations to be batched
const size_t NUM_UNIVARIATES = 6; // 5 transcript polynomials plus the constant hack poly
std::array<Polynomial*, NUM_UNIVARIATES> univariate_polynomials = { &key->transcript_op, &key->transcript_Px,
&key->transcript_Py, &key->transcript_z1,
&key->transcript_z2, &hack };
std::array<FF, NUM_UNIVARIATES> univariate_evaluations;
RefArray univariate_polynomials{ key->transcript_op, key->transcript_Px, key->transcript_Py,
key->transcript_z1, key->transcript_z2, hack };
std::array<FF, univariate_polynomials.size()> univariate_evaluations;

// Constuct the batched polynomial and batched evaluation
Polynomial batched_univariate{ key->circuit_size };
FF batched_evaluation{ 0 };
auto batching_scalar = FF(1);
for (auto [eval, polynomial] : zip_view(univariate_evaluations, univariate_polynomials)) {
batched_univariate.add_scaled(*polynomial, batching_scalar);
for (auto [polynomial, eval] : zip_view(univariate_polynomials, univariate_evaluations)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just convention: LHS the assignable variable

batched_univariate.add_scaled(polynomial, batching_scalar);
batched_evaluation += eval * batching_scalar;
batching_scalar *= batching_challenge;
}
Expand Down
17 changes: 6 additions & 11 deletions barretenberg/cpp/src/barretenberg/flavor/ecc_vm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
class WitnessEntities : public WireEntities<DataType>, public DerivedWitnessEntities<DataType> {
public:
DEFINE_COMPOUND_GET_ALL(WireEntities<DataType>::get_all(), DerivedWitnessEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(WireEntities<DataType>::pointer_view(),
DerivedWitnessEntities<DataType>::pointer_view())
RefVector<DataType> get_wires() { return WireEntities<DataType>::get_all(); };
// The sorted concatenations of table and witness data needed for plookup.
RefVector<DataType> get_sorted_polynomials() { return {}; };
Expand Down Expand Up @@ -268,9 +266,6 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
DEFINE_COMPOUND_GET_ALL(PrecomputedEntities<DataType>::get_all(),
WitnessEntities<DataType>::get_all(),
ShiftedEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(PrecomputedEntities<DataType>::pointer_view(),
WitnessEntities<DataType>::pointer_view(),
ShiftedEntities<DataType>::pointer_view())
// Gemini-specific getters.
RefVector<DataType> get_unshifted()
{
Expand Down Expand Up @@ -369,8 +364,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -391,8 +386,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand All @@ -419,8 +414,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
AllValues get_row(const size_t row_idx)
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand Down
10 changes: 5 additions & 5 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ProvingKey_ : public PrecomputedPolynomials, public WitnessPolynomials {
std::vector<uint32_t> recursive_proof_public_input_indices;
barretenberg::EvaluationDomain<FF> evaluation_domain;

auto precomputed_polynomials_pointer_view() { return PrecomputedPolynomials::pointer_view(); }
auto precomputed_polynomials_get_all() { return PrecomputedPolynomials::get_all(); }
ProvingKey_() = default;
ProvingKey_(const size_t circuit_size, const size_t num_public_inputs)
{
Expand All @@ -114,12 +114,12 @@ class ProvingKey_ : public PrecomputedPolynomials, public WitnessPolynomials {
this->log_circuit_size = numeric::get_msb(circuit_size);
this->num_public_inputs = num_public_inputs;
// Allocate memory for precomputed polynomials
for (auto* poly : PrecomputedPolynomials::pointer_view()) {
*poly = Polynomial(circuit_size);
for (auto& poly : PrecomputedPolynomials::get_all()) {
poly = Polynomial(circuit_size);
}
// Allocate memory for witness polynomials
for (auto* poly : WitnessPolynomials::pointer_view()) {
*poly = Polynomial(circuit_size);
for (auto& poly : WitnessPolynomials::get_all()) {
poly = Polynomial(circuit_size);
}
};
};
Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ TEST(Flavor, GetRow)
return std::vector<FF>({ FF::random_element(), FF::random_element() });
});
Flavor::ProverPolynomials prover_polynomials;
for (auto [poly, entry] : zip_view(prover_polynomials.pointer_view(), data)) {
*poly = entry;
for (auto [poly, entry] : zip_view(prover_polynomials.get_all(), data)) {
poly = entry;
}
auto row0 = prover_polynomials.get_row(0);
auto row1 = prover_polynomials.get_row(1);
Expand Down
24 changes: 1 addition & 23 deletions barretenberg/cpp/src/barretenberg/flavor/flavor_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Macros for defining the flavor classes.
// These are used to derive iterator methods along with the body of a 'flavor' class.
// DEFINE_FLAVOR_MEMBERS lets you define a flavor entity as a collection of individual members, and derive an iterator.
// while DEFINE_COMPOUND_GET_ALL and DEFINE_COMPOUND_POINTER_VIEW let you combine the iterators of substructures or base
// while DEFINE_COMPOUND_GET_ALL lets you combine the iterators of substructures or base
// classes.

#include "barretenberg/common/ref_vector.hpp"
Expand All @@ -17,17 +17,6 @@ template <typename... Refs> auto _refs_to_pointer_array(Refs&... refs)
return std::array{ &refs... };
}

// @deprecated this was less natural than the ref view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's nice that this change is so localized ✅

#define DEFINE_POINTER_VIEW(...) \
[[nodiscard]] auto pointer_view() \
{ \
return _refs_to_pointer_array(__VA_ARGS__); \
} \
[[nodiscard]] auto pointer_view() const \
{ \
return _refs_to_pointer_array(__VA_ARGS__); \
}

#define DEFINE_REF_VIEW(...) \
[[nodiscard]] auto get_all() \
{ \
Expand All @@ -47,19 +36,8 @@ template <typename... Refs> auto _refs_to_pointer_array(Refs&... refs)
*/
#define DEFINE_FLAVOR_MEMBERS(DataType, ...) \
DataType __VA_ARGS__; \
DEFINE_POINTER_VIEW(__VA_ARGS__) \
DEFINE_REF_VIEW(__VA_ARGS__)

#define DEFINE_COMPOUND_POINTER_VIEW(...) \
[[nodiscard]] auto pointer_view() \
{ \
return concatenate(__VA_ARGS__); \
} \
[[nodiscard]] auto pointer_view() const \
{ \
return concatenate(__VA_ARGS__); \
}

#define DEFINE_COMPOUND_GET_ALL(...) \
[[nodiscard]] auto get_all() \
{ \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ class AvmMiniFlavor {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -271,8 +271,8 @@ class AvmMiniFlavor {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class FibFlavor {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -179,8 +179,8 @@ class FibFlavor {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
19 changes: 6 additions & 13 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_translator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,6 @@ class GoblinTranslator {
WireToBeShiftedEntities<DataType>::get_all(),
DerivedWitnessEntities<DataType>::get_all(),
ConcatenatedRangeConstraints<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(WireNonshiftedEntities<DataType>::pointer_view(),
WireToBeShiftedEntities<DataType>::pointer_view(),
DerivedWitnessEntities<DataType>::pointer_view(),
ConcatenatedRangeConstraints<DataType>::pointer_view())

RefVector<DataType> get_wires()
{
Expand Down Expand Up @@ -654,9 +650,6 @@ class GoblinTranslator {
DEFINE_COMPOUND_GET_ALL(PrecomputedEntities<DataType>::get_all(),
WitnessEntities<DataType>::get_all(),
ShiftedEntities<DataType>::get_all())
DEFINE_COMPOUND_POINTER_VIEW(PrecomputedEntities<DataType>::pointer_view(),
WitnessEntities<DataType>::pointer_view(),
ShiftedEntities<DataType>::pointer_view())
/**
* @brief Get the polynomials that are concatenated for the permutation relation
*
Expand Down Expand Up @@ -1036,8 +1029,8 @@ class GoblinTranslator {
[[nodiscard]] AllValues get_row(size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -1064,8 +1057,8 @@ class GoblinTranslator {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -1086,8 +1079,8 @@ class GoblinTranslator {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/flavor/goblin_ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ class GoblinUltra {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down Expand Up @@ -363,8 +363,8 @@ class GoblinUltra {
[[nodiscard]] AllValues get_row(size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), this->pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/flavor/ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ class Ultra {
[[nodiscard]] AllValues get_row(const size_t row_idx) const
{
AllValues result;
for (auto [result_field, polynomial] : zip_view(result.pointer_view(), pointer_view())) {
*result_field = (*polynomial)[row_idx];
for (auto [result_field, polynomial] : zip_view(result.get_all(), get_all())) {
result_field = polynomial[row_idx];
}
return result;
}
Expand All @@ -283,8 +283,8 @@ class Ultra {
PartiallyEvaluatedMultivariates(const size_t circuit_size)
{
// Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2)
for (auto* poly : this->pointer_view()) {
*poly = Polynomial(circuit_size / 2);
for (auto& poly : this->get_all()) {
poly = Polynomial(circuit_size / 2);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void compute_permutation_grand_product(const size_t circuit_size,
for (size_t i = start; i < end; ++i) {

typename Flavor::AllValues evaluations;
for (auto [eval, poly] : zip_view(evaluations.pointer_view(), full_polynomials.pointer_view())) {
*eval = poly->size() > i ? (*poly)[i] : 0;
for (auto [eval, poly] : zip_view(evaluations.get_all(), full_polynomials.get_all())) {
eval = poly.size() > i ? poly[i] : 0;
}
numerator[i] = GrandProdRelation::template compute_permutation_numerator<Accumulator>(evaluations,
relation_parameters);
Expand Down
Loading