Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce the number of rounds in batching #93

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 48 additions & 48 deletions subroutines/src/pcs/multilinear_kzg/batching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ use crate::{
poly_iop::{prelude::SumCheck, PolyIOP},
IOPProof,
};
use arithmetic::{
build_eq_x_r_vec, fix_last_variables, DenseMultilinearExtension, VPAuxInfo, VirtualPolynomial,
};
use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, VirtualPolynomial};
use ark_ec::{AffineCurve, PairingEngine, ProjectiveCurve};
use ark_std::{end_timer, log2, start_timer, One, Zero};
use std::{marker::PhantomData, rc::Rc};
Expand All @@ -38,10 +36,11 @@ where
/// Steps:
/// 1. get challenge point t from transcript
/// 2. build eq(t,i) for i in [0..k]
/// 3. build \tilde g(i, b) = eq(t, i) * f_i(b)
/// 4. compute \tilde eq
/// 5. run sumcheck on \tilde eq * \tilde g(i, b)
/// 6. build g'(a2) where (a1, a2) is the sumcheck's point
/// 3. build \tilde g_i(b) = eq(t, i) * f_i(b)
/// 4. compute \tilde eq_i(b) = eq(b, point_i)
/// 5. run sumcheck on \sum_i=1..k \tilde eq_i * \tilde g_i
/// 6. build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is
/// the sumcheck's point 7. open g'(X) at point (a2)
pub(crate) fn multi_open_internal<E, PCS>(
prover_param: &PCS::ProverParam,
polynomials: &[PCS::Polynomial],
Expand All @@ -64,48 +63,46 @@ where
let num_var = polynomials[0].num_vars;
let k = polynomials.len();
let ell = log2(k) as usize;
let merged_num_var = num_var + ell;

// challenge point t
let t = transcript.get_and_append_challenge_vectors("t".as_ref(), ell)?;

// eq(t, i) for i in [0..k]
let eq_t_i_list = build_eq_x_r_vec(t.as_ref())?;

// \tilde g(i, b) = eq(t, i) * f_i(b)
// \tilde g_i(b) = eq(t, i) * f_i(b)
let timer = start_timer!(|| format!("compute tilde g for {} points", points.len()));
let mut tilde_g_eval = vec![E::Fr::zero(); 1 << (ell + num_var)];
let block_size = 1 << num_var;
let mut tilde_gs = vec![];
for (index, f_i) in polynomials.iter().enumerate() {
let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var];
for (j, &f_i_eval) in f_i.iter().enumerate() {
tilde_g_eval[index * block_size + j] = f_i_eval * eq_t_i_list[index];
tilde_g_eval[j] = f_i_eval * eq_t_i_list[index];
}
tilde_gs.push(Rc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var,
tilde_g_eval,
)));
}
let tilde_g = Rc::new(DenseMultilinearExtension::from_evaluations_vec(
merged_num_var,
tilde_g_eval,
));
end_timer!(timer);

let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
let mut tilde_eq_eval = vec![E::Fr::zero(); 1 << (ell + num_var)];
for (index, point) in points.iter().enumerate() {
let mut tilde_eqs = vec![];
for point in points.iter() {
let eq_b_zi = build_eq_x_r_vec(point)?;
let start = index * block_size;
tilde_eq_eval[start..start + block_size].copy_from_slice(eq_b_zi.as_slice());
tilde_eqs.push(Rc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var, eq_b_zi,
)));
}
let tilde_eq = Rc::new(DenseMultilinearExtension::from_evaluations_vec(
merged_num_var,
tilde_eq_eval,
));
end_timer!(timer);

// built the virtual polynomial for SumCheck
let timer = start_timer!(|| format!("sum check prove of {} variables", num_var + ell));
let timer = start_timer!(|| format!("sum check prove of {} variables", num_var));

let step = start_timer!(|| "add mle");
let mut sum_check_vp = VirtualPolynomial::new(num_var + ell);
sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?;
let mut sum_check_vp = VirtualPolynomial::new(num_var);
for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) {
sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?;
}
end_timer!(step);

let proof = match <PolyIOP<E::Fr> as SumCheck<E::Fr>>::prove(&sum_check_vp, transcript) {
Expand All @@ -120,15 +117,23 @@ where

end_timer!(timer);

// (a1, a2) := sumcheck's point
let step = start_timer!(|| "open at a2");
let a1 = &proof.point[num_var..];
// a2 := sumcheck's point
let a2 = &proof.point[..num_var];
end_timer!(step);

// build g'(a2)
// build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the
// sumcheck's point \tilde eq_i(a2) = eq(a2, point_i)
let step = start_timer!(|| "evaluate at a2");
let g_prime = Rc::new(fix_last_variables(&tilde_g, a1));
let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var];
for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) {
let eq_i_a2 = eq_eval(a2, point)?;
for (j, &tilde_g_eval) in tilde_g.iter().enumerate() {
g_prime_evals[j] += tilde_g_eval * eq_i_a2;
}
}
let g_prime = Rc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var,
g_prime_evals,
));
end_timer!(step);

let step = start_timer!(|| "pcs open");
Expand All @@ -150,8 +155,8 @@ where
/// Steps:
/// 1. get challenge point t from transcript
/// 2. build g' commitment
/// 3. ensure \sum_i eq(t, <i>) * f_i_evals matches the sum via SumCheck
/// verification 4. verify commitment
/// 3. ensure \sum_i eq(a2, point_i) * eq(t, <i>) * f_i_evals matches the sum
/// via SumCheck verification 4. verify commitment
pub(crate) fn batch_verify_internal<E, PCS>(
verifier_param: &PCS::VerifierParam,
f_i_commitments: &[Commitment<E>],
Expand All @@ -175,34 +180,33 @@ where

let k = f_i_commitments.len();
let ell = log2(k) as usize;
let num_var = proof.sum_check_proof.point.len() - ell;
let num_var = proof.sum_check_proof.point.len();

// challenge point t
let t = transcript.get_and_append_challenge_vectors("t".as_ref(), ell)?;

// sum check point (a1, a2)
let a1 = &proof.sum_check_proof.point[num_var..];
// sum check point (a2)
let a2 = &proof.sum_check_proof.point[..num_var];

// build g' commitment
let eq_a1_list = build_eq_x_r_vec(a1)?;
let eq_t_list = build_eq_x_r_vec(t.as_ref())?;

let mut g_prime_commit = E::G1Affine::zero().into_projective();
for i in 0..k {
let tmp = eq_a1_list[i] * eq_t_list[i];

for (i, point) in points.iter().enumerate() {
let eq_i_a2 = eq_eval(a2, point)?;
let tmp = eq_i_a2 * eq_t_list[i];
g_prime_commit += &f_i_commitments[i].0.mul(tmp);
}

// ensure \sum_i eq(t, <i>) * f_i_evals matches the sum via SumCheck
// verification
let mut sum = E::Fr::zero();
for (i, &e) in eq_t_list.iter().enumerate().take(k) {
sum += e * proof.f_i_eval_at_point_i[i];
}
let aux_info = VPAuxInfo {
max_degree: 2,
num_variables: num_var + ell,
num_variables: num_var,
phantom: PhantomData,
};
let subclaim = match <PolyIOP<E::Fr> as SumCheck<E::Fr>>::verify(
Expand All @@ -219,11 +223,7 @@ where
));
},
};
let mut eq_tilde_eval = E::Fr::zero();
for (point, &coef) in points.iter().zip(eq_a1_list.iter()) {
eq_tilde_eval += coef * eq_eval(a2, point)?;
}
let tilde_g_eval = subclaim.expected_evaluation / eq_tilde_eval;
let tilde_g_eval = subclaim.expected_evaluation;

// verify commitment
let res = PCS::verify(
Expand Down