Skip to content

Commit

Permalink
feat: generalize protogalaxy to multiple instances (#5510)
Browse files Browse the repository at this point in the history
Closes AztecProtocol/barretenberg#764.

This PR aims to generalize protogalaxy to multiple instances. In particular, we care about k=2 (1 accumulator and 2 instances) and k=3 and and their performance relative to folding k=1 instance.

We achieve the following numbers:
```
--------------------------------------------------------------------------
Benchmark                                Time             CPU   Iterations
--------------------------------------------------------------------------
fold_k<UltraFlavor, 1>/16             1039 ms          908 ms            1
fold_k<UltraFlavor, 2>/16             1744 ms         1562 ms            1
fold_k<UltraFlavor, 3>/16             2755 ms         2484 ms            1
fold_k<GoblinUltraFlavor, 1>/16       1431 ms         1231 ms            1
fold_k<GoblinUltraFlavor, 2>/16       2387 ms         2084 ms            1
fold_k<GoblinUltraFlavor, 3>/16       3734 ms         3291 ms            1
```

and client IVC benchmark stays the same:
```
--------------------------------------------------------------------------------    
Benchmark                      Time             CPU   Iterations 
--------------------------------------------------------------------------------    
ClientIVCBench/Full/6      23140 ms        17976 ms            1    
Benchmarking lock deleted.                                                          
client_ivc_bench.json                             100% 3561   106.5KB/s   00:00     
function                                        ms     % sum                        
construct_circuits(t)                         4522    19.75%                        
ProverInstance(Circuit&)(t)                   2060     9.00%
ProtogalaxyProver::fold_instances(t)         12545    54.81%
Decider::construct_proof(t)                    734     3.21%
ECCVMProver(CircuitBuilder&)(t)                158     0.69%
ECCVMProver::construct_proof(t)               1768     7.73%
GoblinTranslatorProver::construct_proof(t)     959     4.19%
Goblin::merge(t)                               145     0.63%

Total time accounted for: 22890ms/23140ms = 98.92%

Major contributors:
function                                        ms    % sum
commit(t)                                     4283   18.71%
compute_combiner(t)                           5702   24.91%
compute_perturbator(t)                        1250    5.46%
compute_univariate(t)                         1386    6.05%

Breakdown of ProtogalaxyProver::fold_instances:
ProtoGalaxyProver_::preparation_round(t)           5297    42.22%
ProtoGalaxyProver_::perturbator_round(t)           1250     9.97%
ProtoGalaxyProver_::combiner_quotient_round(t)     5704    45.47%
ProtoGalaxyProver_::accumulator_update_round(t)     294     2.35%
```
  • Loading branch information
lucasxia01 authored Apr 9, 2024
1 parent c496a10 commit f038b70
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 98 deletions.
15 changes: 12 additions & 3 deletions barretenberg/cpp/scripts/analyze_client_ivc_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>8}")
for key in to_keep:
time_ms = bench[key]/1e6
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

# Validate that kept times account for most of the total measured time.
Expand All @@ -45,7 +48,10 @@
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>7}")
for key in ['commit(t)', 'compute_combiner(t)', 'compute_perturbator(t)', 'compute_univariate(t)']:
time_ms = bench[key]/1e6
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

print('\nBreakdown of ProtogalaxyProver::fold_instances:')
Expand All @@ -57,7 +63,10 @@
]
max_label_length = max(len(label) for label in protogalaxy_round_labels)
for key in protogalaxy_round_labels:
time_ms = bench[key]/1e6
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
total_time_ms = bench["ProtogalaxyProver::fold_instances(t)"]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/total_time_ms:>8.2%}")

Expand Down
61 changes: 61 additions & 0 deletions barretenberg/cpp/scripts/analyze_protogalaxy_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import json
from pathlib import Path

PREFIX = Path("build-op-count-time")
PROTOGALAXY_BENCH_JSON = Path("protogalaxy_bench.json")
BENCHMARK = "fold_k<GoblinUltraFlavor, 3>/16"

# Single out an independent set of functions accounting for most of BENCHMARK's real_time
to_keep = [
"ProtogalaxyProver::fold_instances(t)",
]
with open(PREFIX/PROTOGALAXY_BENCH_JSON, "r") as read_file:
read_result = json.load(read_file)
for _bench in read_result["benchmarks"]:
print(_bench)
if _bench["name"] == BENCHMARK:
bench = _bench
bench_components = dict(filter(lambda x: x[0] in to_keep, bench.items()))

# For each kept time, get the proportion over all kept times.
sum_of_kept_times_ms = sum(float(time)
for _, time in bench_components.items())/1e6
max_label_length = max(len(label) for label in to_keep)
column = {"function": "function", "ms": "ms", "%": "% sum"}
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>8}")
for key in to_keep:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

# Validate that kept times account for most of the total measured time.
total_time_ms = bench["real_time"]
totals = '\nTotal time accounted for: {:.0f}ms/{:.0f}ms = {:.2%}'
totals = totals.format(
sum_of_kept_times_ms, total_time_ms, sum_of_kept_times_ms/total_time_ms)
print(totals)

print("\nMajor contributors:")
print(
f"{column['function']:<{max_label_length}}{column['ms']:>8} {column['%']:>7}")
for key in ['commit(t)', 'compute_combiner(t)', 'compute_perturbator(t)', 'compute_univariate(t)']:
if key not in bench:
time_ms = 0
else:
time_ms = bench[key]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/sum_of_kept_times_ms:>8.2%}")

print('\nBreakdown of ProtogalaxyProver::fold_instances:')
protogalaxy_round_labels = [
"ProtoGalaxyProver_::preparation_round(t)",
"ProtoGalaxyProver_::perturbator_round(t)",
"ProtoGalaxyProver_::combiner_quotient_round(t)",
"ProtoGalaxyProver_::accumulator_update_round(t)"
]
max_label_length = max(len(label) for label in protogalaxy_round_labels)
for key in protogalaxy_round_labels:
time_ms = bench[key]/1e6
total_time_ms = bench["ProtogalaxyProver::fold_instances(t)"]/1e6
print(f"{key:<{max_label_length}}{time_ms:>8.0f} {time_ms/total_time_ms:>8.2%}")


25 changes: 25 additions & 0 deletions barretenberg/cpp/scripts/benchmark_protogalaxy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -eu

TARGET="protogalaxy_bench"
FILTER="/16$"
BUILD_DIR=build-op-count-time

# Move above script dir.
cd $(dirname $0)/..

# Measure the benchmarks with ops time counting
./scripts/benchmark_remote.sh protogalaxy_bench\
"./protogalaxy_bench --benchmark_filter=$FILTER\
--benchmark_out=$TARGET.json\
--benchmark_out_format=json"\
op-count-time\
build-op-count-time

# Retrieve output from benching instance
cd $BUILD_DIR
scp $BB_SSH_KEY $BB_SSH_INSTANCE:$BB_SSH_CPP_PATH/build/$TARGET.json .

# Analyze the results
cd ../
python3 ./scripts/analyze_protogalaxy_bench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <benchmark/benchmark.h>

#include "barretenberg/common/op_count_google_bench.hpp"
#include "barretenberg/protogalaxy/protogalaxy_prover.hpp"
#include "barretenberg/stdlib_circuit_builders/mock_circuits.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
Expand All @@ -11,11 +12,11 @@ using namespace benchmark;
namespace bb {

// Fold one instance into an accumulator.
template <typename Flavor> void fold_one(State& state) noexcept
template <typename Flavor, size_t k> void fold_k(State& state) noexcept
{
using ProverInstance = ProverInstance_<Flavor>;
using Instance = ProverInstance;
using Instances = ProverInstances_<Flavor, 2>;
using Instances = ProverInstances_<Flavor, k + 1>;
using ProtoGalaxyProver = ProtoGalaxyProver_<Instances>;
using Builder = typename Flavor::CircuitBuilder;

Expand All @@ -28,19 +29,29 @@ template <typename Flavor> void fold_one(State& state) noexcept
MockCircuits::construct_arithmetic_circuit(builder, log2_num_gates);
return std::make_shared<ProverInstance>(builder);
};
std::vector<std::shared_ptr<Instance>> instances;
// TODO(https://github.com/AztecProtocol/barretenberg/issues/938): Parallelize this loop
for (size_t i = 0; i < k + 1; ++i) {
instances.emplace_back(construct_instance());
}

std::shared_ptr<Instance> instance_1 = construct_instance();
std::shared_ptr<Instance> instance_2 = construct_instance();

ProtoGalaxyProver folding_prover({ instance_1, instance_2 });
ProtoGalaxyProver folding_prover(instances);

for (auto _ : state) {
BB_REPORT_OP_COUNT_IN_BENCH(state);
auto proof = folding_prover.fold_instances();
}
}

BENCHMARK(fold_one<UltraFlavor>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_one<GoblinUltraFlavor>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<UltraFlavor, 1>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 1>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

BENCHMARK(fold_k<UltraFlavor, 2>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 2>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

BENCHMARK(fold_k<UltraFlavor, 3>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);
BENCHMARK(fold_k<GoblinUltraFlavor, 3>)->/* vary the circuit size */ DenseRange(14, 20)->Unit(kMillisecond);

} // namespace bb

BENCHMARK_MAIN();
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ void _bench_round(::benchmark::State& state, void (*F)(ProtoGalaxyProver_<Prover
return std::make_shared<ProverInstance>(builder);
};

// TODO(https://github.com/AztecProtocol/barretenberg/issues/938): Parallelize this loop, also extend to more than
// k=1
std::shared_ptr<ProverInstance> prover_instance_1 = construct_instance();
std::shared_ptr<ProverInstance> prover_instance_2 = construct_instance();

Expand Down
83 changes: 82 additions & 1 deletion barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,94 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate

std::copy(evaluations.begin(), evaluations.end(), result.evaluations.begin());

static constexpr Fr inverse_two = Fr(2).invert();
if constexpr (LENGTH == 2) {
Fr delta = value_at(1) - value_at(0);
static_assert(EXTENDED_LENGTH != 0);
for (size_t idx = domain_start; idx < EXTENDED_DOMAIN_END - 1; idx++) {
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + delta;
}
return result;
} else if constexpr (LENGTH == 3) {
// Based off https://hackmd.io/@aztec-network/SyR45cmOq?type=view
// The technique used here is the same as the length == 3 case below.
Fr a = (value_at(2) + value_at(0)) * inverse_two - value_at(1);
Fr b = value_at(1) - a - value_at(0);
Fr a2 = a + a;
Fr a_mul = a2;
for (size_t i = 0; i < domain_end - 2; i++) {
a_mul += a2;
}
Fr extra = a_mul + a + b;
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + extra;
extra += a2;
}
return result;
} else if constexpr (LENGTH == 4) {
static constexpr Fr inverse_six = Fr(6).invert(); // computed at compile time for efficiency

// To compute a barycentric extension, we can compute the coefficients of the univariate.
// We have the evaluation of the polynomial at the domain (which is assumed to be 0, 1, 2, 3).
// Therefore, we have the 4 linear equations from plugging into f(x) = ax^3 + bx^2 + cx + d:
// a*0 + b*0 + c*0 + d = f(0)
// a*1 + b*1 + c*1 + d = f(1)
// a*2^3 + b*2^2 + c*2 + d = f(2)
// a*3^3 + b*3^2 + c*3 + d = f(3)
// These equations can be rewritten as a matrix equation M * [a, b, c, d] = [f(0), f(1), f(2), f(3)], where
// M is:
// 0, 0, 0, 1
// 1, 1, 1, 1
// 2^3, 2^2, 2, 1
// 3^3, 3^2, 3, 1
// We can invert this matrix in order to compute a, b, c, d:
// -1/6, 1/2, -1/2, 1/6
// 1, -5/2, 2, -1/2
// -11/6, 3, -3/2, 1/3
// 1, 0, 0, 0
// To compute these values, we can multiply everything by 6 and multiply by inverse_six at the end for each
// coefficient The resulting computation here does 18 field adds, 6 subtracts, 3 muls to compute a, b, c,
// and d.
Fr zero_times_3 = value_at(0) + value_at(0) + value_at(0);
Fr zero_times_6 = zero_times_3 + zero_times_3;
Fr zero_times_12 = zero_times_6 + zero_times_6;
Fr one_times_3 = value_at(1) + value_at(1) + value_at(1);
Fr one_times_6 = one_times_3 + one_times_3;
Fr two_times_3 = value_at(2) + value_at(2) + value_at(2);
Fr three_times_2 = value_at(3) + value_at(3);
Fr three_times_3 = three_times_2 + value_at(3);

Fr one_minus_two_times_3 = one_times_3 - two_times_3;
Fr one_minus_two_times_6 = one_minus_two_times_3 + one_minus_two_times_3;
Fr one_minus_two_times_12 = one_minus_two_times_6 + one_minus_two_times_6;
Fr a = (one_minus_two_times_3 + value_at(3) - value_at(0)) * inverse_six; // compute a in 1 muls and 4 adds
Fr b = (zero_times_6 - one_minus_two_times_12 - one_times_3 - three_times_3) * inverse_six;
Fr c = (value_at(0) - zero_times_12 + one_minus_two_times_12 + one_times_6 + two_times_3 + three_times_2) *
inverse_six;

// Then, outside of the a, b, c, d computation, we need to do some extra precomputation
// This work is 3 field muls, 8 adds
Fr a_plus_b = a + b;
Fr a_plus_b_times_2 = a_plus_b + a_plus_b;
size_t start_idx_sqr = (domain_end - 1) * (domain_end - 1);
size_t idx_sqr_three = start_idx_sqr + start_idx_sqr + start_idx_sqr;
Fr idx_sqr_three_times_a = Fr(idx_sqr_three) * a;
Fr x_a_term = Fr(6 * (domain_end - 1)) * a;
Fr three_a = a + a + a;
Fr six_a = three_a + three_a;

Fr three_a_plus_two_b = a_plus_b_times_2 + a;
Fr linear_term = Fr(domain_end - 1) * three_a_plus_two_b + (a_plus_b + c);
// For each new evaluation, we do only 6 field additions and 0 muls.
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + idx_sqr_three_times_a + linear_term;

idx_sqr_three_times_a += x_a_term + three_a;
x_a_term += six_a;

linear_term += three_a_plus_two_b;
}
return result;
} else {
for (size_t k = domain_end; k != EXTENDED_DOMAIN_END; ++k) {
result.value_at(k) = 0;
Expand Down
Loading

0 comments on commit f038b70

Please sign in to comment.