Skip to content

Commit

Permalink
refactor(bb): use std::span in pippenger for scalars (#8269)
Browse files Browse the repository at this point in the history
Refactoring stepping stone. Behaves identically

Next step would be to use this to allow accessing power of 2 quantities
above the std::span size() (with a different wrapper class) so that
non-powers-of-2 can be passed directly to pippenger

We recently anted to save memory on polynomials. The idea is that
instead of rounding up to a power of 2 to make pippenger fast (at cost
of memory), we will make a wrapper class that happily pretends it has
T{} (i.e. zeroes) anywhere form 0 to nearest rounded up power of 2. For
starters this just introduces a std::span, which should behave
identically
  • Loading branch information
ludamad authored Aug 29, 2024
1 parent 2b8af9e commit 2323cd5
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ int pippenger()
scalar_multiplication::pippenger_runtime_state<curve::BN254> state(NUM_POINTS);
std::chrono::steady_clock::time_point time_start = std::chrono::steady_clock::now();
g1::element result = scalar_multiplication::pippenger_unsafe<curve::BN254>(
&scalars[0], reference_string->get_monomial_points(), NUM_POINTS, state);
{ &scalars[0], /*size*/ NUM_POINTS }, reference_string->get_monomial_points(), NUM_POINTS, state);
std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
std::chrono::microseconds diff = std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_start);
std::cout << "run time: " << diff.count() << "us" << std::endl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ template <class Curve> class CommitmentKey {
ASSERT(false);
}
return scalar_multiplication::pippenger_unsafe<Curve>(
const_cast<Fr*>(polynomial.data()), srs->get_monomial_points(), degree, pippenger_runtime_state);
polynomial, srs->get_monomial_points(), degree, pippenger_runtime_state);
};

/**
Expand Down Expand Up @@ -146,7 +146,7 @@ template <class Curve> class CommitmentKey {

// Call the version of pippenger which assumes all points are distinct
return scalar_multiplication::pippenger_unsafe<Curve>(
scalars.data(), points.data(), scalars.size(), pippenger_runtime_state);
scalars, points.data(), scalars.size(), pippenger_runtime_state);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ template <typename Curve_> class IPA {
// Step 6.a (using letters, because doxygen automaticall converts the sublist counters to letters :( )
// L_i = < a_vec_lo, G_vec_hi > + inner_prod_L * aux_generator
L_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&a_vec[0], &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
{&a_vec[0], /*size*/ round_size}, &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
L_i += aux_generator * inner_prod_L;

// Step 6.b
// R_i = < a_vec_hi, G_vec_lo > + inner_prod_R * aux_generator
R_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&a_vec[round_size], &G_vec_local[0], round_size, ck->pippenger_runtime_state);
{&a_vec[round_size], /*size*/ round_size}, &G_vec_local[0], round_size, ck->pippenger_runtime_state);
R_i += aux_generator * inner_prod_R;

// Step 6.c
Expand Down Expand Up @@ -345,7 +345,7 @@ template <typename Curve_> class IPA {
// Step 5.
// Compute C₀ = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
GroupElement LR_sums = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&msm_scalars[0], &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
{&msm_scalars[0], /*size*/ pippenger_size}, &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
GroupElement C_zero = C_prime + LR_sums;

// Step 6.
Expand Down Expand Up @@ -394,7 +394,7 @@ template <typename Curve_> class IPA {
// Step 8.
// Compute G₀
Commitment G_zero = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&s_vec[0], &G_vec_local[0], poly_length, vk->pippenger_runtime_state);
{&s_vec[0], /*size*/ poly_length}, &G_vec_local[0], poly_length, vk->pippenger_runtime_state);

// Step 9.
// Receive a₀ from the prover
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ template <typename Curve>
void compute_wnaf_states(uint64_t* point_schedule,
bool* input_skew_table,
uint64_t* round_counts,
const typename Curve::ScalarField* scalars,
const std::span<const typename Curve::ScalarField> scalars,
const size_t num_initial_points)
{
using Fr = typename Curve::ScalarField;
Expand Down Expand Up @@ -857,7 +857,7 @@ typename Curve::Element evaluate_pippenger_rounds(pippenger_runtime_state<Curve>

template <typename Curve>
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases)
Expand All @@ -871,7 +871,7 @@ typename Curve::Element pippenger_internal(typename Curve::AffineElement* points
}

template <typename Curve>
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
Expand Down Expand Up @@ -910,10 +910,9 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
const auto num_slice_points = static_cast<size_t>(1ULL << slice_bits);

Element result = pippenger_internal(points, scalars, num_slice_points, state, handle_edge_cases);

if (num_slice_points != num_initial_points) {
const uint64_t leftover_points = num_initial_points - num_slice_points;
return result + pippenger(scalars + num_slice_points,
return result + pippenger(scalars.subspan(num_slice_points),
points + static_cast<size_t>(num_slice_points * 2),
static_cast<size_t>(leftover_points),
state,
Expand All @@ -938,7 +937,7 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
*
**/
template <typename Curve>
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
Expand All @@ -947,10 +946,11 @@ typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
}

template <typename Curve>
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
typename Curve::Element pippenger_without_endomorphism_basis_points(
std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
{
std::vector<typename Curve::AffineElement> G_mod(num_initial_points * 2);
bb::scalar_multiplication::generate_pippenger_point_table<Curve>(points, &G_mod[0], num_initial_points);
Expand Down Expand Up @@ -978,7 +978,7 @@ template void evaluate_addition_chains<curve::BN254>(affine_product_runtime_stat
const size_t max_bucket_bits,
bool handle_edge_cases);
template curve::BN254::Element pippenger_internal<curve::BN254>(curve::BN254::AffineElement* points,
curve::BN254::ScalarField* scalars,
std::span<const curve::BN254::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state,
bool handle_edge_cases);
Expand All @@ -992,19 +992,19 @@ template curve::BN254::AffineElement* reduce_buckets<curve::BN254>(affine_produc
bool first_round = true,
bool handle_edge_cases = false);

template curve::BN254::Element pippenger<curve::BN254>(curve::BN254::ScalarField* scalars,
template curve::BN254::Element pippenger<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_points,
pippenger_runtime_state<curve::BN254>& state,
bool handle_edge_cases = true);

template curve::BN254::Element pippenger_unsafe<curve::BN254>(curve::BN254::ScalarField* scalars,
template curve::BN254::Element pippenger_unsafe<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state);

template curve::BN254::Element pippenger_without_endomorphism_basis_points<curve::BN254>(
curve::BN254::ScalarField* scalars,
std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state);
Expand All @@ -1028,11 +1028,12 @@ template void add_affine_points_with_edge_cases<curve::Grumpkin>(curve::Grumpkin
template void evaluate_addition_chains<curve::Grumpkin>(affine_product_runtime_state<curve::Grumpkin>& state,
const size_t max_bucket_bits,
bool handle_edge_cases);
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(curve::Grumpkin::AffineElement* points,
curve::Grumpkin::ScalarField* scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases);
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(
curve::Grumpkin::AffineElement* points,
std::span<const curve::Grumpkin::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases);

template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
pippenger_runtime_state<curve::Grumpkin>& state,
Expand All @@ -1043,19 +1044,20 @@ template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
template curve::Grumpkin::AffineElement* reduce_buckets<curve::Grumpkin>(
affine_product_runtime_state<curve::Grumpkin>& state, bool first_round = true, bool handle_edge_cases = false);

template curve::Grumpkin::Element pippenger<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
template curve::Grumpkin::Element pippenger<curve::Grumpkin>(std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases = true);

template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);
template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(
std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);

template curve::Grumpkin::Element pippenger_without_endomorphism_basis_points<curve::Grumpkin>(
curve::Grumpkin::ScalarField* scalars,
std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ template <typename Curve>
void compute_wnaf_states(uint64_t* point_schedule,
bool* input_skew_table,
uint64_t* round_counts,
const typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
size_t num_initial_points);

template <typename Curve>
Expand Down Expand Up @@ -135,7 +135,7 @@ void evaluate_addition_chains(affine_product_runtime_state<Curve>& state,
bool handle_edge_cases);
template <typename Curve>
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases);
Expand All @@ -152,23 +152,24 @@ typename Curve::AffineElement* reduce_buckets(affine_product_runtime_state<Curve
bool handle_edge_cases = false);

template <typename Curve>
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases = true);

template <typename Curve>
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);

template <typename Curve>
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);
typename Curve::Element pippenger_without_endomorphism_basis_points(
std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);

// Explicit instantiation
// BN254
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ template <typename program_settings> bool VerifierBase<program_settings>::verify

g1::element P[2];

P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(&scalars[0], &elements[0], num_elements, state);
P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(
{ &scalars[0], num_elements }, &elements[0], num_elements, state);
P[1] = -(g1::element(PI_Z_OMEGA) * separator_challenge + PI_Z);

if (key->contains_recursive_proof) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ plonk::Verifier generate_verifier(std::shared_ptr<proving_key> circuit_proving_k
commitments.resize(8);

for (size_t i = 0; i < 8; ++i) {
commitments[i] = g1::affine_element(
scalar_multiplication::pippenger<curve::BN254>(poly_coefficients[i].get(),
circuit_proving_key->reference_string->get_monomial_points(),
circuit_proving_key->circuit_size,
state));
commitments[i] = g1::affine_element(scalar_multiplication::pippenger<curve::BN254>(
{ poly_coefficients[i].get(), circuit_proving_key->circuit_size },
circuit_proving_key->reference_string->get_monomial_points(),
circuit_proving_key->circuit_size,
state));
}

auto crs = std::make_shared<bb::srs::factories::FileVerifierCrs<curve::BN254>>("../srs_db/ignition");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ void work_queue::process_queue()
// Run pippenger multi-scalar multiplication.
auto runtime_state = bb::scalar_multiplication::pippenger_runtime_state<curve::BN254>(msm_size);
bb::g1::affine_element result(bb::scalar_multiplication::pippenger_unsafe<curve::BN254>(
item.mul_scalars.get(), srs_points, msm_size, runtime_state));
{ item.mul_scalars.get(), msm_size }, srs_points, msm_size, runtime_state));

transcript->add_element(item.tag, result.to_buffer());

Expand Down
Loading

0 comments on commit 2323cd5

Please sign in to comment.