Skip to content

Commit

Permalink
feat: Efficient ZM quotient computation (#3016)
Browse files Browse the repository at this point in the history
Implements the efficient algorithm detailed in the ZM paper for
computing the multilinear quotients $q_k$ that are fundamental to the ZM
protocol. This replaces the original naive (and inefficient)
implementation. This work does not address parallelism in all possible
locations.

Co-authored-by: TohruKohrita <tohru@aztecprotocol.com>
  • Loading branch information
ledwards2225 and TohruKohrita authored Oct 27, 2023
1 parent 84f8db2 commit ebda5fc
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions barretenberg/cpp/src/barretenberg/honk/pcs/zeromorph/zeromorph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,21 @@ template <typename Curve> class ZeroMorphProver_ {

public:
/**
* @brief Compute multivariate quotients q_k(X_0, ..., X_{k-1}) for f(X_0, ..., X_{d-1})
* @details Given multilinear polynomial f = f(X_0, ..., X_{d-1}) for which f(u) = v, compute q_k such that:
* @brief Compute multivariate quotients q_k(X_0, ..., X_{k-1}) for f(X_0, ..., X_{n-1})
* @details Starting from the coefficients of f, compute q_k inductively from k = n - 1, to k = 0.
* f needs to be updated at each step.
*
* f(X_0, ..., X_{d-1}) - v = \sum_{k=0}^{d-1} (X_k - u_k)q_k(X_0, ..., X_{k-1})
* First, compute q_{n-1} of size N/2 by
* q_{n-1}[l] = f[N/2 + l ] - f[l].
*
* The polynomials q_k can be computed explicitly as the difference of the partial evaluation of f in the last
* (n - k) variables at, respectively, u'' = (u_k + 1, u_{k+1}, ..., u_{n-1}) and u' = (u_k, ..., u_{n-1}). I.e.
* Update f by f[l] <- f[l] + u_{n-1} * q_{n-1}[l]; f now has size N/2.
* Compute q_{n-2} of size N/(2^2) by
* q_{n-2}[l] = f[N/2^2 + l] - f[l].
*
* q_k(X_0, ..., X_{k-1}) = f(X_0,...,X_{k-1}, u'') - f(X_0,...,X_{k-1}, u')
* Update f by f[l] <- f[l] + u_{n-2} * q_{n-2}[l]; f now has size N/(2^2).
* Compute q_{n-3} of size N/(2^3) by
* q_{n-3}[l] = f[N/2^3 + l] - f[l]. Repeat similarly until you reach q_0.
*
* @note In practice, 2^d is equal to the circuit size N
*
* TODO(#739): This method has been designed for clarity at the expense of efficiency. Implement the more efficient
* algorithm detailed in the latest versions of the ZeroMorph paper.
* @param polynomial Multilinear polynomial f(X_0, ..., X_{d-1})
* @param u_challenge Multivariate challenge u = (u_0, ..., u_{d-1})
* @return std::vector<Polynomial> The quotients q_k
Expand All @@ -68,26 +69,36 @@ template <typename Curve> class ZeroMorphProver_ {
quotients.emplace_back(Polynomial(size)); // degree 2^k - 1
}

// Compute the q_k in reverse order, i.e. q_{n-1}, ..., q_0
for (size_t k = 0; k < log_N; ++k) {
// Define partial evaluation point u' = (u_k, ..., u_{n-1})
auto evaluation_point_size = static_cast<std::ptrdiff_t>(k + 1);
std::vector<FF> u_partial(u_challenge.end() - evaluation_point_size, u_challenge.end());
// Compute the coefficients of q_{n-1}
size_t size_q = 1 << (log_N - 1);
Polynomial q = Polynomial(size_q);
for (size_t l = 0; l < size_q; ++l) {
q[l] = polynomial[size_q + l] - polynomial[l];
}

// Compute f' = f(X_0,...,X_{k-1}, u')
auto f_1 = polynomial.partial_evaluate_mle(u_partial);
quotients[log_N - 1] = q;

// Increment first element to get altered partial evaluation point u'' = (u_k + 1, u_{k+1}, ..., u_{n-1})
u_partial[0] += 1;
std::vector<FF> f_k;
f_k.resize(size_q);

// Compute f'' = f(X_0,...,X_{k-1}, u'')
auto f_2 = polynomial.partial_evaluate_mle(u_partial);
std::vector<FF> g(polynomial.data().get(), polynomial.data().get() + size_q);

// Compute q_k = f''(X_0,...,X_{k-1}) - f'(X_0,...,X_{k-1})
auto q_k = f_2;
q_k -= f_1;
// Compute q_k in reverse order from k= n-2, i.e. q_{n-2}, ..., q_0
for (size_t k = 1; k < log_N; ++k) {
// Compute f_k
for (size_t l = 0; l < size_q; ++l) {
f_k[l] = g[l] + u_challenge[log_N - k] * q[l];
}

size_q = size_q / 2;
q = Polynomial(size_q);

for (size_t l = 0; l < size_q; ++l) {
q[l] = f_k[size_q + l] - f_k[l];
}

quotients[log_N - k - 1] = q_k;
quotients[log_N - k - 1] = q;
g = f_k;
}

return quotients;
Expand Down

0 comments on commit ebda5fc

Please sign in to comment.