Skip to content

Commit

Permalink
refactor: more zip_with instances (microsoft#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
huitseeker committed Dec 8, 2023
1 parent 0e08caf commit 6f154da
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 38 deletions.
7 changes: 2 additions & 5 deletions src/spartan/batched_ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,8 @@ where
),
|poly_ABC, poly_E, poly_Mz, poly_tau, eval_Mz, u| {
let [poly_Az, poly_Bz, poly_Cz] = poly_ABC;
let poly_uCz_E = poly_Cz
.par_iter()
.zip_eq(poly_E.par_iter())
.map(|(cz, e)| *u * cz + e)
.collect();
let poly_uCz_E =
zip_with!((poly_Cz.par_iter(), poly_E.par_iter()), |cz, e| *u * cz + e).collect();
OuterSumcheckInstance::new(
poly_tau.clone(),
poly_Az.clone(),
Expand Down
60 changes: 27 additions & 33 deletions src/spartan/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,12 @@ impl<E: Engine> SumcheckProof<E> {
//
// claim = ∑ᵢ coeffᵢ⋅2^{n-nᵢ}⋅cᵢ
let claim = zip_with!(
(
zip_with!((claims.iter(), num_rounds.iter()), |claim, num_rounds| {
let scaling_factor = 1 << (num_rounds_max - num_rounds);
E::Scalar::from(scaling_factor as u64) * claim
}),
coeffs.iter()
),
|scaled_claim, coeff| scaled_claim * coeff
(claims.iter(), num_rounds.iter(), coeffs.iter()),
|claim, num_rounds, coeff| {
let scaling_factor = 1 << (num_rounds_max - num_rounds);
let scaled_claim = E::Scalar::from(scaling_factor as u64) * claim;
scaled_claim * coeff
}
)
.sum();

Expand Down Expand Up @@ -215,15 +213,14 @@ impl<E: Engine> SumcheckProof<E> {
}

let num_rounds_max = *num_rounds.iter().max().unwrap();
let mut e = claims
.iter()
.zip_eq(num_rounds)
.map(|(claim, num_rounds)| {
E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim
})
.zip_eq(coeffs)
.map(|(claim, c)| claim * c)
.sum();
let mut e = zip_with!(
(claims.iter(), num_rounds, coeffs),
|claim, num_rounds, coeff| {
let scaled_claim = E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim;
scaled_claim * coeff
}
)
.sum();
let mut r: Vec<E::Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();

Expand Down Expand Up @@ -293,13 +290,11 @@ impl<E: Engine> SumcheckProof<E> {
.map(|poly| poly[0])
.collect::<Vec<_>>();

let eval_expected = poly_A_final
.iter()
.zip_eq(poly_B_final.iter())
.map(|(eA, eB)| comb_func(eA, eB))
.zip_eq(coeffs.iter())
.map(|(e, c)| e * c)
.sum::<E::Scalar>();
let eval_expected = zip_with!(
(poly_A_final.iter(), poly_B_final.iter(), coeffs.iter()),
|eA, eB, coeff| comb_func(eA, eB) * coeff
)
.sum::<E::Scalar>();
assert_eq!(e, eval_expected);

let claims_prod = (poly_A_final, poly_B_final);
Expand Down Expand Up @@ -532,15 +527,14 @@ impl<E: Engine> SumcheckProof<E> {

let mut r: Vec<E::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();
let mut claim_per_round = claims
.iter()
.zip_eq(num_rounds)
.map(|(claim, num_rounds)| {
E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim
})
.zip_eq(coeffs.iter())
.map(|(claim, c)| claim * c)
.sum();
let mut claim_per_round = zip_with!(
(claims.iter(), num_rounds, coeffs.iter()),
|claim, num_rounds, coeff| {
let scaled_claim = E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim;
claim * coeff
}
)
.sum();

for current_round in 0..num_rounds_max {
let remaining_rounds = num_rounds_max - current_round;
Expand Down

0 comments on commit 6f154da

Please sign in to comment.