Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimized interpolation #28

Merged
merged 4 commits into from
May 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 106 additions & 44 deletions poly-iop/src/sum_check/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,67 +181,129 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
fn interpolate_uni_poly<F: PrimeField>(p_i: &[F], eval_at: F) -> Result<F, PolyIOPErrors> {
let start = start_timer!(|| "sum check interpolate uni poly opt");

let mut res = F::zero();

// compute
// - prod = \prod (eval_at - j)
// - evals = [eval_at - j]
let mut evals = vec![];
let len = p_i.len();
let mut evals = vec![];
let mut prod = eval_at;
evals.push(eval_at);

// `prod = \prod_{j} (eval_at - j)`
for e in 1..len {
let tmp = eval_at - F::from(e as u64);
evals.push(tmp);
prod *= tmp;
}
let mut res = F::zero();
// we want to compute \prod (j!=i) (i-j) for a given i
//
// we start from the last step, which is
// denom[len-1] = (len-1) * (len-2) *... * 2 * 1
// the step before that is
// denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1
// and the step before that is
// denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2
//
// i.e., for any i, the one before this will be derived from
// denom[i-1] = denom[i] * (len-i) / i
//
// that is, we only need to store
// - the last denom for i = len-1, and
// - the ratio between current step and fhe last step, which is the product of
// (len-i) / i from all previous steps and we store this product as a fraction
// number to reduce field divisions.

for i in 0..len {
// res += p_i * prod / (divisor * (eval_at - j))
let divisor = get_divisor(i, len)?;
let divisor_f = {
if divisor < 0 {
-F::from((-divisor) as u128)
// We know
// - 2^61 < factorial(20) < 2^62
// - 2^122 < factorial(33) < 2^123
// so we will be able to compute the ratio
// - for len <= 20 with i64
// - for len <= 33 with i128
// - for len > 33 with BigInt
if p_i.len() <= 20 {
let last_denominator = F::from(u64_factorial(len - 1));
let mut ratio_numerator = 1i64;
let mut ratio_enumerator = 1u64;

for i in (0..len).rev() {
let ratio_numerator_f = if ratio_numerator < 0 {
-F::from((-ratio_numerator) as u64)
} else {
F::from(divisor as u128)
}
};
res += p_i[i] * prod / (divisor_f * evals[i]);
}
F::from(ratio_numerator as u64)
};

end_timer!(start);
Ok(res)
}
res += p_i[i] * prod * F::from(ratio_enumerator)
/ (last_denominator * ratio_numerator_f * evals[i]);

/// Compute \prod_{j!=i)^len (i-j). This function takes O(n^2) number of
/// primitive operations which is negligible compared to field operations.
// We know
// - factorial(20) ~ 2^61
// - factorial(33) ~ 2^123
// so we will be able to store the result for len<=20 with i64;
// for len<=33 with i128; and we do not currently support len>33.
#[inline]
fn get_divisor(i: usize, len: usize) -> Result<i128, PolyIOPErrors> {
if len <= 20 {
let mut res = 1i64;
for j in 0..len {
if j != i {
res *= i as i64 - j as i64;
// compute denom for the next step is current_denom * (len-i)/i
if i != 0 {
ratio_numerator *= -(len as i64 - i as i64);
ratio_enumerator *= i as u64;
}
}
Ok(res as i128)
} else if len <= 33 {
let mut res = 1i128;
for j in 0..len {
if j != i {
res *= i as i128 - j as i128;
} else if p_i.len() <= 33 {
let last_denominator = F::from(u128_factorial(len - 1));
let mut ratio_numerator = 1i128;
let mut ratio_enumerator = 1u128;

for i in (0..len).rev() {
let ratio_numerator_f = if ratio_numerator < 0 {
-F::from((-ratio_numerator) as u128)
} else {
F::from(ratio_numerator as u128)
};

res += p_i[i] * prod * F::from(ratio_enumerator)
/ (last_denominator * ratio_numerator_f * evals[i]);

// compute denom for the next step is current_denom * (len-i)/i
if i != 0 {
ratio_numerator *= -(len as i128 - i as i128);
ratio_enumerator *= i as u128;
}
}
Ok(res)
} else {
Err(PolyIOPErrors::InvalidParameters(
"Do not support number variable > 33".to_string(),
))
let mut denom_up = field_factorial::<F>(len - 1);
let mut denom_down = F::one();

for i in (0..len).rev() {
res += p_i[i] * prod * denom_down / (denom_up * evals[i]);

// compute denom for the next step is current_denom * (len-i)/i
if i != 0 {
denom_up *= -F::from((len - i) as u64);
denom_down *= F::from(i as u64);
}
}
}
end_timer!(start);
Ok(res)
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn field_factorial<F: PrimeField>(a: usize) -> F {
let mut res = 1u64;
for i in 1..=a {
res *= i as u64;
}
F::from(res)
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn u128_factorial(a: usize) -> u128 {
let mut res = 1u128;
for i in 1..=a {
res *= i as u128;
}
res
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn u64_factorial(a: usize) -> u64 {
let mut res = 1u64;
for i in 1..=a {
res *= i as u64;
}
res
}