Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize verifier eq #102

Merged
merged 5 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion arithmetic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ pub use multilinear_polynomial::{
};
pub use univariate_polynomial::{build_l, get_uni_domain};
pub use util::{bit_decompose, gen_eval_point, get_batched_nv, get_index};
pub use virtual_polynomial::{build_eq_x_r, build_eq_x_r_vec, VPAuxInfo, VirtualPolynomial};
pub use virtual_polynomial::{
build_eq_x_r, build_eq_x_r_vec, eq_eval, VPAuxInfo, VirtualPolynomial,
};
17 changes: 17 additions & 0 deletions arithmetic/src/virtual_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,23 @@ impl<F: PrimeField> VirtualPolynomial<F> {
}
}

/// Evaluate eq polynomial.
pub fn eq_eval<F: PrimeField>(x: &[F], y: &[F]) -> Result<F, ArithErrors> {
if x.len() != y.len() {
return Err(ArithErrors::InvalidParameters(
"x and y have different length".to_string(),
));
}
let start = start_timer!(|| "eq_eval");
let mut res = F::one();
for (&xi, &yi) in x.iter().zip(y.iter()) {
let xi_yi = xi * yi;
res *= xi_yi + xi_yi - xi - yi + F::one();
}
end_timer!(start);
Ok(res)
}

/// This function build the eq(x, r) polynomial for any given r.
///
/// Evaluate
Expand Down
1 change: 0 additions & 1 deletion hyperplonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ rayon = { version = "1.5.2", default-features = false, optional = true }

[dev-dependencies]
ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] }

# Benchmarks
[[bench]]
name = "hyperplonk-benches"
Expand Down
17 changes: 11 additions & 6 deletions hyperplonk/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
};

pub struct MockCircuit<F: PrimeField> {
pub public_inputs: Vec<F>,
pub witnesses: Vec<WitnessColumn<F>>,
pub index: HyperPlonkIndex<F>,
}
Expand Down Expand Up @@ -85,10 +86,11 @@ impl<F: PrimeField> MockCircuit<F> {
witnesses[i].append(cur_witness[i]);
}
}
let public_inputs = witnesses[0].0[0..4].to_vec();

let params = HyperPlonkParams {
num_constraints,
num_pub_input: num_constraints,
num_pub_input: public_inputs.len(),
gate_func: gate.clone(),
};

Expand All @@ -99,7 +101,11 @@ impl<F: PrimeField> MockCircuit<F> {
selectors,
};

Self { witnesses, index }
Self {
public_inputs,
witnesses,
index,
}
}

pub fn is_satisfied(&self) -> bool {
Expand Down Expand Up @@ -144,7 +150,7 @@ mod test {

const SUPPORTED_SIZE: usize = 20;
const MIN_NUM_VARS: usize = 8;
const MAX_NUM_VARS: usize = 15;
const MAX_NUM_VARS: usize = 19;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set it to 15 to decrease test time

const CUSTOM_DEGREE: [usize; 6] = [1, 2, 4, 8, 16, 32];

#[test]
Expand Down Expand Up @@ -177,7 +183,6 @@ mod test {
assert!(circuit.is_satisfied());

let index = circuit.index;

// generate pk and vks
let (pk, vk) =
<PolyIOP<Fr> as HyperPlonkSNARK<Bls12_381, MultilinearKzgPCS<Bls12_381>>>::preprocess(
Expand All @@ -187,14 +192,14 @@ mod test {
let proof =
<PolyIOP<Fr> as HyperPlonkSNARK<Bls12_381, MultilinearKzgPCS<Bls12_381>>>::prove(
&pk,
&circuit.witnesses[0].0,
&circuit.public_inputs,
&circuit.witnesses,
)?;

let verify =
<PolyIOP<Fr> as HyperPlonkSNARK<Bls12_381, MultilinearKzgPCS<Bls12_381>>>::verify(
&vk,
&circuit.witnesses[0].0,
&circuit.public_inputs,
&proof,
)?;
assert!(verify);
Expand Down
21 changes: 13 additions & 8 deletions hyperplonk/src/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,11 @@ where
// - 4.4. public input consistency checks
// - pi_poly(r_pi) where r_pi is sampled from transcript
let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?;
let tmp_point = [vec![E::Fr::zero(); num_vars - ell], r_pi].concat();
pcs_acc.insert_poly_and_points(&witness_polys[0], &witness_commits[0], &tmp_point);
// padded with zeros
let r_pi_padded = [r_pi, vec![E::Fr::zero(); num_vars - ell]].concat();
// Evaluate witness_poly[0] at r_pi||0s which is equal to public_input evaluated
// at r_pi. Assumes that public_input is a power of 2
pcs_acc.insert_poly_and_points(&witness_polys[0], &witness_commits[0], &r_pi_padded);
end_timer!(step);

// =======================================================================
Expand Down Expand Up @@ -515,7 +518,7 @@ where
// =======================================================================
// 3. Verify the opening against the commitment
// =======================================================================
let step = start_timer!(|| "verify commitments");
let step = start_timer!(|| "assemble commitments");

// generate evaluation points and commitments
let mut comms = vec![];
Expand All @@ -535,7 +538,6 @@ where
points.push(perm_check_point_0.clone());
points.push(perm_check_point_1.clone());
points.push(prod_final_query_point);

// frac(x)'s points
comms.push(proof.perm_check_proof.frac_comm);
comms.push(proof.perm_check_proof.frac_comm);
Expand Down Expand Up @@ -575,21 +577,24 @@ where
// - 4.4. public input consistency checks
// - pi_poly(r_pi) where r_pi is sampled from transcript
let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?;
let tmp_point = [vec![E::Fr::zero(); num_vars - ell], r_pi].concat();

// check public evaluation
let pi_poly = DenseMultilinearExtension::from_evaluations_slice(ell as usize, pub_input);
let expect_pi_eval = evaluate_opt(&pi_poly, &tmp_point[..]);
let expect_pi_eval = evaluate_opt(&pi_poly, &r_pi[..]);
if expect_pi_eval != *pi_eval {
return Err(HyperPlonkErrors::InvalidProver(format!(
"Public input eval mismatch: got {}, expect {}",
pi_eval, expect_pi_eval,
)));
}
comms.push(proof.witness_commits[0]);
points.push(tmp_point);
let r_pi_padded = [r_pi, vec![E::Fr::zero(); num_vars - ell]].concat();

comms.push(proof.witness_commits[0]);
points.push(r_pi_padded);
assert_eq!(comms.len(), proof.batch_openings.f_i_eval_at_point_i.len());

end_timer!(step);
let step = start_timer!(|| "PCS batch verify");
// check proof
let res = PCS::batch_verify(
&vk.pcs_param,
Expand Down
7 changes: 7 additions & 0 deletions hyperplonk/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ pub(crate) fn prover_sanity_check<F: PrimeField>(
params.num_pub_input
)));
}
if !pub_input.len().is_power_of_two() {
return Err(HyperPlonkErrors::InvalidProver(format!(
"Public input length is not power of two: got {}",
pub_input.len(),
)));
}

// witnesses length
for (i, w) in witnesses.iter().enumerate() {
if w.0.len() != params.num_constraints {
Expand Down
2 changes: 2 additions & 0 deletions subroutines/src/poly_iop/sum_check/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
/// This implementation is linear in number of inputs in terms of field
/// operations. It also has a quadratic term in primitive operations which is
/// negligible compared to field operations.
/// TODO: The quadratic term can be removed by precomputing the lagrange
/// coefficients.
fn interpolate_uni_poly<F: PrimeField>(p_i: &[F], eval_at: F) -> Result<F, PolyIOPErrors> {
let start = start_timer!(|| "sum check interpolate uni poly opt");

Expand Down
10 changes: 3 additions & 7 deletions subroutines/src/poly_iop/zero_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
use std::fmt::Debug;

use crate::poly_iop::{errors::PolyIOPErrors, sum_check::SumCheck, PolyIOP};
use arithmetic::build_eq_x_r;
use arithmetic::eq_eval;
use ark_ff::PrimeField;
use ark_poly::MultilinearExtension;
use ark_std::{end_timer, start_timer};
use transcript::IOPTranscript;

Expand Down Expand Up @@ -103,11 +102,8 @@ impl<F: PrimeField> ZeroCheck<F> for PolyIOP<F> {

// expected_eval = sumcheck.expect_eval/eq(v, r)
// where v = sum_check_sub_claim.point
let eq_x_r = build_eq_x_r(&r)?;
let expected_evaluation = sum_subclaim.expected_evaluation
/ eq_x_r.evaluate(&sum_subclaim.point).ok_or_else(|| {
PolyIOPErrors::InvalidParameters("evaluation dimension does not match".to_string())
})?;
let eq_x_r_eval = eq_eval(&sum_subclaim.point, &r)?;
let expected_evaluation = sum_subclaim.expected_evaluation / eq_x_r_eval;

end_timer!(start);
Ok(ZeroCheckSubClaim {
Expand Down