Skip to content

Commit

Permalink
Batch all (#89)
Browse files Browse the repository at this point in the history
- use sumcheck to batch open PCS
- split Prod and witness into two batches
- benchmark code
  • Loading branch information
zhenfeizhang authored Oct 14, 2022
1 parent baaa06b commit 719f595
Show file tree
Hide file tree
Showing 56 changed files with 1,349 additions and 2,510 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
members = [
"arithmetic",
"hyperplonk",
"poly-iop",
"subroutines",
"transcript",
"util"
]
8 changes: 4 additions & 4 deletions arithmetic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ mod virtual_polynomial;

pub use errors::ArithErrors;
pub use multilinear_polynomial::{
evaluate_no_par, evaluate_opt, fix_first_variable, fix_variables, identity_permutation_mle,
merge_polynomials, random_mle_list, random_permutation_mle, random_zero_mle_list,
DenseMultilinearExtension,
evaluate_no_par, evaluate_opt, fix_last_variables, fix_last_variables_no_par, fix_variables,
identity_permutation_mle, merge_polynomials, random_mle_list, random_permutation_mle,
random_zero_mle_list, DenseMultilinearExtension,
};
pub use univariate_polynomial::{build_l, get_uni_domain};
pub use util::{bit_decompose, gen_eval_point, get_batched_nv, get_index};
pub use virtual_polynomial::{build_eq_x_r, VPAuxInfo, VirtualPolynomial};
pub use virtual_polynomial::{build_eq_x_r, build_eq_x_r_vec, VPAuxInfo, VirtualPolynomial};
93 changes: 74 additions & 19 deletions arithmetic/src/multilinear_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,36 +123,24 @@ pub fn fix_variables<F: Field>(
DenseMultilinearExtension::<F>::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
}

pub fn fix_first_variable<F: Field>(
poly: &DenseMultilinearExtension<F>,
partial_point: &F,
) -> DenseMultilinearExtension<F> {
assert!(poly.num_vars != 0, "invalid size of partial point");

let nv = poly.num_vars;
let res = fix_one_variable_helper(&poly.evaluations, nv, partial_point);
DenseMultilinearExtension::<F>::from_evaluations_slice(nv - 1, &res)
}

fn fix_one_variable_helper<F: Field>(data: &[F], nv: usize, point: &F) -> Vec<F> {
let mut res = vec![F::zero(); 1 << (nv - 1)];
let one_minus_p = F::one() - point;

// evaluate single variable of partial point from left to right
#[cfg(not(feature = "parallel"))]
for b in 0..(1 << (nv - 1)) {
res[b] = data[b << 1] * one_minus_p + data[(b << 1) + 1] * point;
for i in 0..(1 << (nv - 1)) {
res[i] = data[i] + (data[(i << 1) + 1] - data[i << 1]) * point;
}

#[cfg(feature = "parallel")]
if nv >= 13 {
// on my computer we parallelization doesn't help till nv >= 13
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i << 1] * one_minus_p + data[(i << 1) + 1] * point;
*x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
});
} else {
for b in 0..(1 << (nv - 1)) {
res[b] = data[b << 1] * one_minus_p + data[(b << 1) + 1] * point;
for i in 0..(1 << (nv - 1)) {
res[i] = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
}
}

Expand All @@ -178,9 +166,8 @@ fn fix_variables_no_par<F: Field>(
// evaluate single variable of partial point from left to right
for i in 1..dim + 1 {
let r = partial_point[i - 1];
let one_minus_r = F::one() - r;
for b in 0..(1 << (nv - i)) {
poly[b] = poly[b << 1] * one_minus_r + poly[(b << 1) + 1] * r;
poly[b] = poly[b << 1] + (poly[(b << 1) + 1] - poly[b << 1]) * r;
}
}
DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
Expand Down Expand Up @@ -210,3 +197,71 @@ pub fn merge_polynomials<F: PrimeField>(
merged_nv, scalars,
)))
}

pub fn fix_last_variables_no_par<F: PrimeField>(
poly: &DenseMultilinearExtension<F>,
partial_point: &[F],
) -> DenseMultilinearExtension<F> {
let mut res = fix_last_variable_no_par(poly, partial_point.last().unwrap());
for p in partial_point.iter().rev().skip(1) {
res = fix_last_variable_no_par(&res, p);
}
res
}

fn fix_last_variable_no_par<F: PrimeField>(
poly: &DenseMultilinearExtension<F>,
partial_point: &F,
) -> DenseMultilinearExtension<F> {
let nv = poly.num_vars();
let half_len = 1 << (nv - 1);
let mut res = vec![F::zero(); half_len];
for (i, e) in res.iter_mut().enumerate().take(half_len) {
*e = poly.evaluations[i]
+ *partial_point * (poly.evaluations[i + half_len] - poly.evaluations[i]);
}
DenseMultilinearExtension::from_evaluations_vec(nv - 1, res)
}
pub fn fix_last_variables<F: PrimeField>(
poly: &DenseMultilinearExtension<F>,
partial_point: &[F],
) -> DenseMultilinearExtension<F> {
assert!(
partial_point.len() <= poly.num_vars,
"invalid size of partial point"
);
let nv = poly.num_vars;
let mut poly = poly.evaluations.to_vec();
let dim = partial_point.len();
// evaluate single variable of partial point from left to right
for (i, point) in partial_point.iter().rev().enumerate().take(dim) {
poly = fix_last_variable_helper(&poly, nv - i, point);
}

DenseMultilinearExtension::<F>::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
}

fn fix_last_variable_helper<F: Field>(data: &[F], nv: usize, point: &F) -> Vec<F> {
let half_len = 1 << (nv - 1);
let mut res = vec![F::zero(); half_len];

// evaluate single variable of partial point from left to right
#[cfg(not(feature = "parallel"))]
for b in 0..half_len {
res[b] = data[b] + (data[b + half_len] - data[b]) * point;
}

#[cfg(feature = "parallel")]
if nv >= 13 {
// on my computer we parallelization doesn't help till nv >= 13
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i] + (data[i + half_len] - data[i]) * point;
});
} else {
for b in 0..(1 << (nv - 1)) {
res[b] = data[b] + (data[b + half_len] - data[b]) * point;
}
}

res
}
58 changes: 40 additions & 18 deletions arithmetic/src/virtual_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use ark_std::{
rand::{Rng, RngCore},
start_timer,
};
use rayon::prelude::*;
use std::{cmp::max, collections::HashMap, marker::PhantomData, ops::Add, rc::Rc};

#[rustfmt::skip]
Expand Down Expand Up @@ -324,16 +325,29 @@ impl<F: PrimeField> VirtualPolynomial<F> {
}
}

// This function build the eq(x, r) polynomial for any given r.
//
// Evaluate
// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
// over r, which is
// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
/// This function build the eq(x, r) polynomial for any given r.
///
/// Evaluate
/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
/// over r, which is
/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
pub fn build_eq_x_r<F: PrimeField>(
r: &[F],
) -> Result<Rc<DenseMultilinearExtension<F>>, ArithErrors> {
let start = start_timer!(|| "zero check build eq_x_r");
let evals = build_eq_x_r_vec(r)?;
let mle = DenseMultilinearExtension::from_evaluations_vec(r.len(), evals);

Ok(Rc::new(mle))
}
/// This function build the eq(x, r) polynomial for any given r, and output the
/// evaluation of eq(x, r) in its vector form.
///
/// Evaluate
/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
/// over r, which is
/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
pub fn build_eq_x_r_vec<F: PrimeField>(r: &[F]) -> Result<Vec<F>, ArithErrors> {
let start = start_timer!(|| format!("build eq_x_r of size {}", r.len()));

// we build eq(x,r) from its evaluations
// we want to evaluate eq(x,r) over x \in {0, 1}^num_vars
Expand All @@ -349,11 +363,8 @@ pub fn build_eq_x_r<F: PrimeField>(
let mut eval = Vec::new();
build_eq_x_r_helper(r, &mut eval)?;

let mle = DenseMultilinearExtension::from_evaluations_vec(r.len(), eval);

let res = Rc::new(mle);
end_timer!(start);
Ok(res)
Ok(eval)
}

/// A helper function to build eq(x, r) recursively.
Expand All @@ -373,13 +384,24 @@ fn build_eq_x_r_helper<F: PrimeField>(r: &[F], buf: &mut Vec<F>) -> Result<(), A
// for the current step we will need
// if x_0 = 0: (1-r0) * [b_1, ..., b_k]
// if x_0 = 1: r0 * [b_1, ..., b_k]

let mut res = vec![];
for &b_i in buf.iter() {
let tmp = r[0] * b_i;
res.push(b_i - tmp);
res.push(tmp);
}
// let mut res = vec![];
// for &b_i in buf.iter() {
// let tmp = r[0] * b_i;
// res.push(b_i - tmp);
// res.push(tmp);
// }
// *buf = res;

let mut res = vec![F::zero(); buf.len() << 1];
res.par_iter_mut().enumerate().for_each(|(i, val)| {
let bi = buf[i >> 1];
let tmp = r[0] * bi;
if i & 1 == 0 {
*val = bi - tmp;
} else {
*val = tmp;
}
});
*buf = res;
}

Expand Down
51 changes: 51 additions & 0 deletions bench_results/plot_component
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
filename = 'pie_chart.txt'
set terminal postscript eps enhanced color font "18"
set size square
set output "components.eps"

rowi = 0
rowf = 7

# obtain sum(column(2)) from rows `rowi` to `rowf`
set datafile separator ','
stats filename u 2 every ::rowi::rowf noout prefix "A"

# rowf should not be greater than length of file
rowf = (rowf-rowi > A_records - 1 ? A_records + rowi - 1 : rowf)

angle(x)=x*360/A_sum
percentage(x)=x*100/A_sum

# circumference dimensions for pie-chart
centerX=0
centerY=0
radius=1

# label positions
yposmin = 0.0
yposmax = 0.95*radius
xpos = -0.8*radius
ypos(i) = -2.2*radius + yposmax - i*(yposmax-yposmin)/(1.0*rowf-rowi)

#-------------------------------------------------------------------
# now we can configure the canvas
set style fill solid 1 # filled pie-chart
unset key # no automatic labels
unset tics # remove tics
unset border # remove borders; if some label is missing, comment to see what is happening

set size ratio -1 # equal scale length
set xrange [-radius:3*radius] # [-1:2] leaves space for labels
set yrange [-3*radius:radius] # [-1:1]

#-------------------------------------------------------------------
pos = 0 # init angle
colour = 0 # init colour

# 1st line: plot pie-chart
# 2nd line: draw colored boxes at (xpos):(ypos)
# 3rd line: place labels at (xpos+offset):(ypos)
plot filename u (centerX):(centerY):(radius):(pos):(pos=pos+angle($2)):(colour=colour+1) every ::rowi::rowf w circle lc var,\
for [i=0:rowf-rowi] '+' u (xpos):(ypos(i)) w p pt 5 ps 4 lc i+1,\
for [i=0:rowf-rowi] filename u (xpos):(ypos(i)):(sprintf('%04.1f%% %s', percentage($2), stringcolumn(1))) every ::i+rowi::i+rowi w labels left offset 3,0

31 changes: 31 additions & 0 deletions bench_results/plot_high_degree
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
set terminal postscript eps enhanced color font "18"
filename = '64threads_growing_degree.txt'
set output "grow_degree.eps"

# set font "32"

set key left top
set grid
# set logscale y
# set logscale x


set title font ",64"
set key font ",18"
set xtics font ",20"
set ytics font ",20"
set xlabel font ",20"
set ylabel font ",20"

# set key title "IOP proving time"
set key title font ", 20"
# set key title "2^{15} constraints"
set xlabel "degree d"
set ylabel 'time (us)'
# set yrange []
# set xrange [500000:1100000]
# set xtics (0, 1,2,4,8,16,32)
plot filename using 1:2 w lp t "q_Lw_1 + q_Rw_2 + q_Mw_1^{d-1}w_2 + q_C = 0",


reset
59 changes: 59 additions & 0 deletions bench_results/plot_iop
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
set terminal postscript eps enhanced color font "18"
sumcheck = 'iop/sum_check.txt'
zerocheck = 'iop/zero_check.txt'
permcheck = 'iop/perm_check.txt'
prodcheck = 'iop/prod_check.txt'

set output "iop_prover.eps"

set font "64"

set key left
set grid
set logscale y

set title font ",64"
set key font ",18"
set xtics font ",20"
set ytics font ",20"
set xlabel font ",20"
set ylabel font ",20"

set key title "IOP proving time"
set key title font ", 20"
set xlabel "\#variables"
set ylabel 'time (ms)'
# set xtics (4,8,16,32,64)
plot sumcheck using 1:2 w lp t "Sum Check",\
zerocheck using 1:2 w lp t "Zero Check",\
prodcheck using 1:2 w lp t "Prod Check",\
permcheck using 1:2 w lp t "Perm Check",
reset


# set terminal postscript eps enhanced color
# sumcheck = 'iop/sum_check.txt'
# zerocheck = 'iop/zero_check.txt'
# permcheck = 'iop/perm_check.txt'
# prodcheck = 'iop/prod_check.txt'

# set output "iop_verifier.eps"

# set font "32"

# set key left
# set grid
# set logscale y

# set title font ",10"
# set key title "IOP verifier time"
# set xlabel "\#variables"
# set ylabel 'log time (us)'
# # set xtics (4,8,16,32,64)
# plot sumcheck using 1:3 w lp t "Sum Check",\
# zerocheck using 1:3 w lp t "Zero Check",\
# prodcheck using 1:3 w lp t "Prod Check",\
# permcheck using 1:3 w lp t "Perm Check",
# reset


Loading

0 comments on commit 719f595

Please sign in to comment.