diff --git a/Cargo.toml b/Cargo.toml index eb190965..26dc0b98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nova-snark" -version = "0.22.0" +version = "0.23.0" authors = ["Srinath Setty "] edition = "2021" description = "Recursive zkSNARKs without trusted setup" @@ -28,7 +28,6 @@ num-traits = "0.2" num-integer = "0.1" serde = { version = "1.0", features = ["derive"] } bincode = "1.3" -flate2 = "1.0" bitvec = "1.0" byteorder = "1.4.3" thiserror = "1.0" @@ -45,10 +44,12 @@ getrandom = { version = "0.2.0", default-features = false, features = ["js"] } [dev-dependencies] criterion = { version = "0.4", features = ["html_reports"] } rand = "0.8.4" +flate2 = "1.0" hex = "0.4.3" pprof = { version = "0.11" } cfg-if = "1.0.0" sha2 = "0.10.7" +proptest = "1.2.0" [[bench]] name = "recursive-snark" diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index 4effe262..2402217e 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -17,8 +17,12 @@ type G1 = pasta_curves::pallas::Point; type G2 = pasta_curves::vesta::Point; type EE1 = nova_snark::provider::ipa_pc::EvaluationEngine; type EE2 = nova_snark::provider::ipa_pc::EvaluationEngine; +// SNARKs without computational commitments type S1 = nova_snark::spartan::snark::RelaxedR1CSSNARK; type S2 = nova_snark::spartan::snark::RelaxedR1CSSNARK; +// SNARKs with computational commitments +type SS1 = nova_snark::spartan::ppsnark::RelaxedR1CSSNARK; +type SS2 = nova_snark::spartan::ppsnark::RelaxedR1CSSNARK; type C1 = NonTrivialTestCircuit<::Scalar>; type C2 = TrivialTestCircuit<::Scalar>; @@ -31,13 +35,13 @@ cfg_if::cfg_if! { criterion_group! { name = compressed_snark; config = Criterion::default().warm_up_time(Duration::from_millis(3000)).with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); - targets = bench_compressed_snark + targets = bench_compressed_snark, bench_compressed_snark_with_computational_commitments } } else { criterion_group! 
{ name = compressed_snark; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); - targets = bench_compressed_snark + targets = bench_compressed_snark, bench_compressed_snark_with_computational_commitments } } } @@ -61,7 +65,7 @@ fn bench_compressed_snark(c: &mut Criterion) { let c_secondary = TrivialTestCircuit::default(); // Produce public parameters - let pp = PublicParams::::setup(c_primary.clone(), c_secondary.clone()); + let pp = PublicParams::::setup(&c_primary, &c_secondary); // Produce prover and verifier keys for CompressedSNARK let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); @@ -129,6 +133,93 @@ fn bench_compressed_snark(c: &mut Criterion) { } } +fn bench_compressed_snark_with_computational_commitments(c: &mut Criterion) { + let num_samples = 10; + let num_cons_verifier_circuit_primary = 9819; + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [9819, 16384, 32768, 65536, 131072, 262144].iter() { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - num_cons_verifier_circuit_primary; + + let mut group = c.benchmark_group(format!( + "CompressedSNARK-Commitments-StepCircuitSize-{num_cons}" + )); + group + .sampling_mode(SamplingMode::Flat) + .sample_size(num_samples); + + let c_primary = NonTrivialTestCircuit::new(num_cons); + let c_secondary = TrivialTestCircuit::default(); + + // Produce public parameters + let pp = PublicParams::::setup(&c_primary, &c_secondary); + + // Produce prover and verifier keys for CompressedSNARK + let (pk, vk) = CompressedSNARK::<_, _, _, _, SS1, SS2>::setup(&pp).unwrap(); + + // produce a recursive SNARK + let num_steps = 3; + let mut recursive_snark: RecursiveSNARK = RecursiveSNARK::new( + &pp, + &c_primary, + &c_secondary, + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], + ); + + for i in 0..num_steps { + let res = recursive_snark.prove_step( + &pp, + &c_primary, + &c_secondary, + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], + ); + assert!(res.is_ok()); + + // verify the recursive snark at each step of recursion + let res = recursive_snark.verify( + &pp, + i + 1, + &[::Scalar::from(2u64)], + &[::Scalar::from(2u64)], + ); + assert!(res.is_ok()); + } + + // Bench time to produce a compressed SNARK + group.bench_function("Prove", |b| { + b.iter(|| { + assert!(CompressedSNARK::<_, _, _, _, SS1, SS2>::prove( + black_box(&pp), + black_box(&pk), + black_box(&recursive_snark) + ) + .is_ok()); + }) + }); + let res = CompressedSNARK::<_, _, _, _, SS1, SS2>::prove(&pp, &pk, &recursive_snark); + assert!(res.is_ok()); + let compressed_snark = res.unwrap(); + + // Benchmark the verification time + group.bench_function("Verify", |b| { + b.iter(|| { + assert!(black_box(&compressed_snark) + .verify( + black_box(&vk), + black_box(num_steps), + black_box(vec![::Scalar::from(2u64)]), + black_box(vec![::Scalar::from(2u64)]), + ) + .is_ok()); + }) + }); + + group.finish(); + } +} + #[derive(Clone, Debug, Default)] struct NonTrivialTestCircuit { num_cons: usize, diff --git a/benches/compute-digest.rs b/benches/compute-digest.rs index 47bdda1d..50105565 100644 --- a/benches/compute-digest.rs +++ b/benches/compute-digest.rs @@ -27,7 +27,7 @@ criterion_main!(compute_digest); fn bench_compute_digest(c: &mut Criterion) { c.bench_function("compute_digest", |b| { b.iter(|| { - PublicParams::::setup(black_box(C1::new(10)), black_box(C2::default())) + PublicParams::::setup(black_box(&C1::new(10)), black_box(&C2::default())) 
}) }); } diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index eed8d48f..5af803f8 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -56,7 +56,7 @@ fn bench_recursive_snark(c: &mut Criterion) { let c_secondary = TrivialTestCircuit::default(); // Produce public parameters - let pp = PublicParams::::setup(c_primary.clone(), c_secondary.clone()); + let pp = PublicParams::::setup(&c_primary, &c_secondary); // Bench time to produce a recursive SNARK; // we execute a certain number of warm-up steps since executing diff --git a/benches/sha256.rs b/benches/sha256.rs index f35500f8..642c6991 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -200,8 +200,8 @@ fn bench_recursive_snark(c: &mut Criterion) { group.sample_size(10); // Produce public parameters - let pp = - PublicParams::::setup(circuit_primary.clone(), TrivialTestCircuit::default()); + let ttc = TrivialTestCircuit::default(); + let pp = PublicParams::::setup(&circuit_primary, &ttc); let circuit_secondary = TrivialTestCircuit::default(); let z0_primary = vec![::Scalar::from(2u64)]; diff --git a/examples/minroot.rs b/examples/minroot.rs index 75c2d41d..dd5c8d60 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -172,7 +172,7 @@ fn main() { G2, MinRootCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, - >::setup(circuit_primary.clone(), circuit_secondary.clone()); + >::setup(&circuit_primary, &circuit_secondary); println!("PublicParams::setup, took {:?} ", start.elapsed()); println!( diff --git a/src/bellperson/r1cs.rs b/src/bellperson/r1cs.rs index 56710f76..ae3388df 100644 --- a/src/bellperson/r1cs.rs +++ b/src/bellperson/r1cs.rs @@ -28,10 +28,7 @@ pub trait NovaShape { fn r1cs_shape(&self) -> (R1CSShape, CommitmentKey); } -impl NovaWitness for SatisfyingAssignment -where - G::Scalar: PrimeField, -{ +impl NovaWitness for SatisfyingAssignment { fn r1cs_instance_and_witness( &self, shape: &R1CSShape, @@ -48,10 +45,7 @@ where } } -impl NovaShape for ShapeCS -where - G::Scalar: PrimeField, -{ +impl NovaShape for ShapeCS { fn r1cs_shape(&self) -> (R1CSShape, CommitmentKey) { let mut A: Vec<(usize, usize, G::Scalar)> = Vec::new(); let mut B: Vec<(usize, usize, G::Scalar)> = Vec::new(); diff --git a/src/bellperson/shape_cs.rs b/src/bellperson/shape_cs.rs index bb964636..a80be8c5 100644 --- a/src/bellperson/shape_cs.rs +++ b/src/bellperson/shape_cs.rs @@ -48,10 +48,7 @@ impl Ord for OrderedVariable { #[allow(clippy::upper_case_acronyms)] /// `ShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. -pub struct ShapeCS -where - G::Scalar: PrimeField + Field, -{ +pub struct ShapeCS { named_objects: HashMap, current_namespace: Vec, #[allow(clippy::type_complexity)] @@ -92,10 +89,7 @@ fn proc_lc( map } -impl ShapeCS -where - G::Scalar: PrimeField, -{ +impl ShapeCS { /// Create a new, default `ShapeCS`, pub fn new() -> Self { ShapeCS::default() @@ -216,10 +210,7 @@ where } } -impl Default for ShapeCS -where - G::Scalar: PrimeField, -{ +impl Default for ShapeCS { fn default() -> Self { let mut map = HashMap::new(); map.insert("ONE".into(), NamedObject::Var(ShapeCS::::one())); @@ -233,10 +224,7 @@ where } } -impl ConstraintSystem for ShapeCS -where - G::Scalar: PrimeField, -{ +impl ConstraintSystem for ShapeCS { type Root = Self; fn alloc(&mut self, annotation: A, _f: F) -> Result diff --git a/src/bellperson/solver.rs b/src/bellperson/solver.rs index 0eaf088c..2357724a 100644 --- a/src/bellperson/solver.rs +++ b/src/bellperson/solver.rs @@ -1,7 +1,7 @@ //! 
Support for generating R1CS witness using bellperson. use crate::traits::Group; -use ff::{Field, PrimeField}; +use ff::Field; use bellperson::{ multiexp::DensityTracker, ConstraintSystem, Index, LinearCombination, SynthesisError, Variable, @@ -9,10 +9,7 @@ use bellperson::{ /// A `ConstraintSystem` which calculates witness values for a concrete instance of an R1CS circuit. #[derive(PartialEq)] -pub struct SatisfyingAssignment -where - G::Scalar: PrimeField, -{ +pub struct SatisfyingAssignment { // Density of queries a_aux_density: DensityTracker, b_input_density: DensityTracker, @@ -29,10 +26,7 @@ where } use std::fmt; -impl fmt::Debug for SatisfyingAssignment -where - G::Scalar: PrimeField, -{ +impl fmt::Debug for SatisfyingAssignment { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt .debug_struct("SatisfyingAssignment") @@ -69,10 +63,7 @@ where } } -impl ConstraintSystem for SatisfyingAssignment -where - G::Scalar: PrimeField, -{ +impl ConstraintSystem for SatisfyingAssignment { type Root = Self; fn new() -> Self { diff --git a/src/ccs/cccs.rs b/src/ccs/cccs.rs index 8eb01ed8..6c67f621 100644 --- a/src/ccs/cccs.rs +++ b/src/ccs/cccs.rs @@ -18,7 +18,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::{Field, PrimeField}; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rayon::prelude::*; use serde::{Deserialize, Serialize}; diff --git a/src/ccs/lcccs.rs b/src/ccs/lcccs.rs index 22fb5177..96c6f01a 100644 --- a/src/ccs/lcccs.rs +++ b/src/ccs/lcccs.rs @@ -21,7 +21,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::{Field, PrimeField}; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rand_core::RngCore; use rayon::prelude::*; diff --git a/src/ccs/mod.rs b/src/ccs/mod.rs index 1b3ba931..d56a6af7 100644 --- a/src/ccs/mod.rs +++ b/src/ccs/mod.rs @@ -24,7 +24,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::Field; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rand_core::RngCore; use rayon::prelude::*; diff --git a/src/ccs/multifolding.rs b/src/ccs/multifolding.rs index 8aade340..188c5e07 100644 --- a/src/ccs/multifolding.rs +++ b/src/ccs/multifolding.rs @@ -24,7 +24,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::{Field, PrimeField}; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rand_core::RngCore; use rayon::prelude::*; diff --git a/src/ccs/util/mod.rs b/src/ccs/util/mod.rs index 23616b35..02fd1af2 100644 --- a/src/ccs/util/mod.rs +++ b/src/ccs/util/mod.rs @@ -18,7 +18,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::{Field, PrimeField}; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rayon::prelude::*; use serde::{Deserialize, Serialize}; diff --git a/src/ccs/util/virtual_poly.rs b/src/ccs/util/virtual_poly.rs index f1bcdf2a..c4822646 100644 --- a/src/ccs/util/virtual_poly.rs +++ b/src/ccs/util/virtual_poly.rs @@ -18,7 +18,6 @@ use crate::{ use bitvec::vec; use core::{cmp::max, marker::PhantomData}; use ff::{Field, PrimeField}; -use flate2::{write::ZlibEncoder, Compression}; use itertools::concat; use rand::Rng; use rand_core::RngCore; diff --git a/src/circuit.rs b/src/circuit.rs index 60744f0e..426ce18a 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -40,7 +40,7 @@ pub struct NovaAugmentedCircuitParams { } impl NovaAugmentedCircuitParams { - pub fn 
new(limb_width: usize, n_limbs: usize, is_primary_circuit: bool) -> Self { + pub const fn new(limb_width: usize, n_limbs: usize, is_primary_circuit: bool) -> Self { Self { limb_width, n_limbs, @@ -87,19 +87,19 @@ impl NovaAugmentedCircuitInputs { /// The augmented circuit F' in Nova that includes a step circuit F /// and the circuit for the verifier in Nova's non-interactive folding scheme -pub struct NovaAugmentedCircuit> { - params: NovaAugmentedCircuitParams, +pub struct NovaAugmentedCircuit<'a, G: Group, SC: StepCircuit> { + params: &'a NovaAugmentedCircuitParams, ro_consts: ROConstantsCircuit, inputs: Option>, - step_circuit: SC, // The function that is applied for each step + step_circuit: &'a SC, // The function that is applied for each step } -impl> NovaAugmentedCircuit { +impl<'a, G: Group, SC: StepCircuit> NovaAugmentedCircuit<'a, G, SC> { /// Create a new verification circuit for the input relaxed r1cs instances - pub fn new( - params: NovaAugmentedCircuitParams, + pub const fn new( + params: &'a NovaAugmentedCircuitParams, inputs: Option>, - step_circuit: SC, + step_circuit: &'a SC, ro_consts: ROConstantsCircuit, ) -> Self { Self { @@ -262,8 +262,8 @@ impl> NovaAugmentedCircuit { } } -impl> Circuit<::Base> - for NovaAugmentedCircuit +impl<'a, G: Group, SC: StepCircuit> Circuit<::Base> + for NovaAugmentedCircuit<'a, G, SC> { fn synthesize::Base>>( self, @@ -396,27 +396,19 @@ mod tests { G1: Group::Scalar>, G2: Group::Scalar>, { + let ttc1 = TrivialTestCircuit::default(); // Initialize the shape and ck for the primary - let circuit1: NovaAugmentedCircuit::Base>> = - NovaAugmentedCircuit::new( - primary_params.clone(), - None, - TrivialTestCircuit::default(), - ro_consts1.clone(), - ); + let circuit1: NovaAugmentedCircuit<'_, G2, TrivialTestCircuit<::Base>> = + NovaAugmentedCircuit::new(&primary_params, None, &ttc1, ro_consts1.clone()); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit1.synthesize(&mut cs); let (shape1, ck1) = cs.r1cs_shape(); assert_eq!(cs.num_constraints(), num_constraints_primary); + let ttc2 = TrivialTestCircuit::default(); // Initialize the shape and ck for the secondary - let circuit2: NovaAugmentedCircuit::Base>> = - NovaAugmentedCircuit::new( - secondary_params.clone(), - None, - TrivialTestCircuit::default(), - ro_consts2.clone(), - ); + let circuit2: NovaAugmentedCircuit<'_, G1, TrivialTestCircuit<::Base>> = + NovaAugmentedCircuit::new(&secondary_params, None, &ttc2, ro_consts2.clone()); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit2.synthesize(&mut cs); let (shape2, ck2) = cs.r1cs_shape(); @@ -434,13 +426,8 @@ mod tests { None, None, ); - let circuit1: NovaAugmentedCircuit::Base>> = - NovaAugmentedCircuit::new( - primary_params, - Some(inputs1), - TrivialTestCircuit::default(), - ro_consts1, - ); + let circuit1: NovaAugmentedCircuit<'_, G2, TrivialTestCircuit<::Base>> = + NovaAugmentedCircuit::new(&primary_params, Some(inputs1), &ttc1, ro_consts1); let _ = circuit1.synthesize(&mut cs1); let (inst1, witness1) = cs1.r1cs_instance_and_witness(&shape1, &ck1).unwrap(); // Make sure that this is satisfiable @@ -458,13 +445,8 @@ mod tests { Some(inst1), None, ); - let circuit2: NovaAugmentedCircuit::Base>> = - NovaAugmentedCircuit::new( - secondary_params, - Some(inputs2), - TrivialTestCircuit::default(), - ro_consts2, - ); + let circuit2: NovaAugmentedCircuit<'_, G1, TrivialTestCircuit<::Base>> = + NovaAugmentedCircuit::new(&secondary_params, Some(inputs2), &ttc2, ro_consts2); let _ = circuit2.synthesize(&mut cs2); let (inst2, witness2) = 
cs2.r1cs_instance_and_witness(&shape2, &ck2).unwrap(); // Make sure that it is satisfiable diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index e3ed1c09..f6c2d1dd 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -81,7 +81,7 @@ where } /// Returns coordinates associated with the point. - pub fn get_coordinates( + pub const fn get_coordinates( &self, ) -> ( &AllocatedNum, @@ -570,7 +570,7 @@ where G: Group, { /// Creates a new AllocatedPointNonInfinity from the specified coordinates - pub fn new(x: AllocatedNum, y: AllocatedNum) -> Self { + pub const fn new(x: AllocatedNum, y: AllocatedNum) -> Self { Self { x, y } } @@ -610,7 +610,7 @@ where } /// Returns coordinates associated with the point. - pub fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum) { + pub const fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum) { (&self.x, &self.y) } diff --git a/src/gadgets/nonnative/bignat.rs b/src/gadgets/nonnative/bignat.rs index 9db48480..eb3144ef 100644 --- a/src/gadgets/nonnative/bignat.rs +++ b/src/gadgets/nonnative/bignat.rs @@ -783,7 +783,9 @@ impl Polynomial { #[cfg(test)] mod tests { use super::*; - use bellperson::Circuit; + use bellperson::{gadgets::test::TestConstraintSystem, Circuit}; + use pasta_curves::pallas::Scalar; + use proptest::prelude::*; pub struct PolynomialMultiplier { pub a: Vec, @@ -818,4 +820,79 @@ mod tests { Ok(()) } } + + #[test] + fn test_polynomial_multiplier_circuit() { + let mut cs = TestConstraintSystem::::new(); + + let circuit = PolynomialMultiplier { + a: [1, 1, 1].iter().map(|i| Scalar::from_u128(*i)).collect(), + b: [1, 1].iter().map(|i| Scalar::from_u128(*i)).collect(), + }; + + circuit.synthesize(&mut cs).expect("synthesis failed"); + + if let Some(token) = cs.which_is_unsatisfied() { + eprintln!("Error: {} is unsatisfied", token); + } + } + + #[derive(Debug)] + pub struct BigNatBitDecompInputs { + pub n: BigInt, + } + + pub struct BigNatBitDecompParams { + pub limb_width: usize, + pub n_limbs: usize, + } + + pub struct BigNatBitDecomp { + inputs: Option, + params: BigNatBitDecompParams, + } + + impl Circuit for BigNatBitDecomp { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let n = BigNat::alloc_from_nat( + cs.namespace(|| "n"), + || Ok(self.inputs.grab()?.n.clone()), + self.params.limb_width, + self.params.n_limbs, + )?; + n.decompose(cs.namespace(|| "decomp"))?; + Ok(()) + } + } + + proptest! { + + #![proptest_config(ProptestConfig { + cases: 10, // this test is costlier as max n gets larger + .. 
ProptestConfig::default() + })] + #[test] + fn test_big_nat_can_decompose(n in any::(), limb_width in 40u8..200) { + let n = n as usize; + + let n_limbs = if n == 0 { + 1 + } else { + (n - 1) / limb_width as usize + 1 + }; + + let circuit = BigNatBitDecomp { + inputs: Some(BigNatBitDecompInputs { + n: BigInt::from(n), + }), + params: BigNatBitDecompParams { + limb_width: limb_width as usize, + n_limbs, + }, + }; + let mut cs = TestConstraintSystem::::new(); + circuit.synthesize(&mut cs).expect("synthesis failed"); + prop_assert!(cs.is_satisfied()); + } + } } diff --git a/src/gadgets/nonnative/util.rs b/src/gadgets/nonnative/util.rs index e0cceee5..486270d2 100644 --- a/src/gadgets/nonnative/util.rs +++ b/src/gadgets/nonnative/util.rs @@ -69,7 +69,7 @@ pub struct Num { } impl Num { - pub fn new(value: Option, num: LinearCombination) -> Self { + pub const fn new(value: Option, num: LinearCombination) -> Self { Self { value, num } } pub fn alloc(mut cs: CS, value: F) -> Result diff --git a/src/lib.rs b/src/lib.rs index d2df81a8..78a0e138 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,7 +88,7 @@ where C2: StepCircuit, { /// Create a new `PublicParams` - pub fn setup(c_primary: C1, c_secondary: C2) -> Self { + pub fn setup(c_primary: &C1, c_secondary: &C2) -> Self { let augmented_circuit_params_primary = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); let augmented_circuit_params_secondary = @@ -105,8 +105,8 @@ where let ro_consts_circuit_secondary: ROConstantsCircuit = ROConstantsCircuit::::new(); // Initialize ck for the primary - let circuit_primary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - augmented_circuit_params_primary.clone(), + let circuit_primary: NovaAugmentedCircuit<'_, G2, C1> = NovaAugmentedCircuit::new( + &augmented_circuit_params_primary, None, c_primary, ro_consts_circuit_primary.clone(), @@ -116,8 +116,8 @@ where let (r1cs_shape_primary, ck_primary) = cs.r1cs_shape(); // Initialize ck for the secondary - let circuit_secondary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - augmented_circuit_params_secondary.clone(), + let circuit_secondary: NovaAugmentedCircuit<'_, G1, C2> = NovaAugmentedCircuit::new( + &augmented_circuit_params_secondary, None, c_secondary, ro_consts_circuit_secondary.clone(), @@ -151,7 +151,7 @@ where } /// Returns the number of constraints in the primary and secondary circuits - pub fn num_constraints(&self) -> (usize, usize) { + pub const fn num_constraints(&self) -> (usize, usize) { ( self.r1cs_shape_primary.num_cons, self.r1cs_shape_secondary.num_cons, @@ -159,7 +159,7 @@ where } /// Returns the number of variables in the primary and secondary circuits - pub fn num_variables(&self) -> (usize, usize) { + pub const fn num_variables(&self) -> (usize, usize) { ( self.r1cs_shape_primary.num_vars, self.r1cs_shape_secondary.num_vars, @@ -221,10 +221,10 @@ where None, ); - let circuit_primary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - pp.augmented_circuit_params_primary.clone(), + let circuit_primary: NovaAugmentedCircuit<'_, G2, C1> = NovaAugmentedCircuit::new( + &pp.augmented_circuit_params_primary, Some(inputs_primary), - c_primary.clone(), + c_primary, pp.ro_consts_circuit_primary.clone(), ); let _ = circuit_primary.synthesize(&mut cs_primary); @@ -244,10 +244,10 @@ where Some(u_primary.clone()), None, ); - let circuit_secondary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - pp.augmented_circuit_params_secondary.clone(), + let circuit_secondary: NovaAugmentedCircuit<'_, G1, C2> = NovaAugmentedCircuit::new( + 
&pp.augmented_circuit_params_secondary, Some(inputs_secondary), - c_secondary.clone(), + c_secondary, pp.ro_consts_circuit_secondary.clone(), ); let _ = circuit_secondary.synthesize(&mut cs_secondary); @@ -333,10 +333,10 @@ where Some(Commitment::::decompress(&nifs_secondary.comm_T)?), ); - let circuit_primary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - pp.augmented_circuit_params_primary.clone(), + let circuit_primary: NovaAugmentedCircuit<'_, G2, C1> = NovaAugmentedCircuit::new( + &pp.augmented_circuit_params_primary, Some(inputs_primary), - c_primary.clone(), + c_primary, pp.ro_consts_circuit_primary.clone(), ); let _ = circuit_primary.synthesize(&mut cs_primary); @@ -370,10 +370,10 @@ where Some(Commitment::::decompress(&nifs_primary.comm_T)?), ); - let circuit_secondary: NovaAugmentedCircuit = NovaAugmentedCircuit::new( - pp.augmented_circuit_params_secondary.clone(), + let circuit_secondary: NovaAugmentedCircuit<'_, G1, C2> = NovaAugmentedCircuit::new( + &pp.augmented_circuit_params_secondary, Some(inputs_secondary), - c_secondary.clone(), + c_secondary, pp.ro_consts_circuit_secondary.clone(), ); let _ = circuit_secondary.synthesize(&mut cs_secondary); @@ -870,7 +870,7 @@ mod tests { T1: StepCircuit, T2: StepCircuit, { - let pp = PublicParams::::setup(circuit1, circuit2); + let pp = PublicParams::::setup(&circuit1, &circuit2); let digest_str = pp .digest @@ -934,7 +934,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, - >::setup(test_circuit1.clone(), test_circuit2.clone()); + >::setup(&test_circuit1, &test_circuit2); let num_steps = 1; @@ -990,7 +990,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, CubicCircuit<::Scalar>, - >::setup(circuit_primary.clone(), circuit_secondary.clone()); + >::setup(&circuit_primary, &circuit_secondary); let num_steps = 3; @@ -1077,7 +1077,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, CubicCircuit<::Scalar>, - >::setup(circuit_primary.clone(), circuit_secondary.clone()); + >::setup(&circuit_primary, &circuit_secondary); let num_steps = 3; @@ -1172,7 +1172,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, CubicCircuit<::Scalar>, - >::setup(circuit_primary.clone(), circuit_secondary.clone()); + >::setup(&circuit_primary, &circuit_secondary); let num_steps = 3; @@ -1344,7 +1344,7 @@ mod tests { G2, FifthRootCheckingCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, - >::setup(circuit_primary, circuit_secondary.clone()); + >::setup(&circuit_primary, &circuit_secondary); let num_steps = 3; @@ -1422,7 +1422,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, CubicCircuit<::Scalar>, - >::setup(test_circuit1.clone(), test_circuit2.clone()); + >::setup(&test_circuit1, &test_circuit2); let num_steps = 1; diff --git a/src/provider/ipa_pc.rs b/src/provider/ipa_pc.rs index fa8068bb..0ae536ab 100644 --- a/src/provider/ipa_pc.rs +++ b/src/provider/ipa_pc.rs @@ -177,7 +177,7 @@ where G: Group, CommitmentKey: CommitmentKeyExtTrait, { - fn protocol_name() -> &'static [u8] { + const fn protocol_name() -> &'static [u8] { b"IPA" } diff --git a/src/provider/pasta.rs b/src/provider/pasta.rs index 1f5284eb..471ee328 100644 --- a/src/provider/pasta.rs +++ b/src/provider/pasta.rs @@ -31,7 +31,7 @@ pub struct PallasCompressedElementWrapper { impl PallasCompressedElementWrapper { /// Wraps repr into the wrapper - pub fn new(repr: [u8; 32]) -> Self { + pub const fn new(repr: [u8; 32]) -> Self { Self { repr } } } @@ -44,7 +44,7 @@ pub struct VestaCompressedElementWrapper { impl VestaCompressedElementWrapper { /// Wraps repr into the 
wrapper - pub fn new(repr: [u8; 32]) -> Self { + pub const fn new(repr: [u8; 32]) -> Self { Self { repr } } } diff --git a/src/provider/pedersen.rs b/src/provider/pedersen.rs index bad1247e..fe00c52f 100644 --- a/src/provider/pedersen.rs +++ b/src/provider/pedersen.rs @@ -203,7 +203,9 @@ impl CommitmentEngineTrait for CommitmentEngine { } } -pub(crate) trait CommitmentKeyExtTrait { +/// A trait listing properties of a commitment key that can be managed in a divide-and-conquer fashion +pub trait CommitmentKeyExtTrait { + /// Holds the type of the commitment engine type CE: CommitmentEngineTrait; /// Splits the commitment key into two pieces at a specified point diff --git a/src/r1cs.rs b/src/r1cs.rs index 76d93ed4..105101ec 100644 --- a/src/r1cs.rs +++ b/src/r1cs.rs @@ -139,6 +139,16 @@ impl R1CSShape { }) } + // Checks regularity conditions on the R1CSShape, required in Spartan-class SNARKs + // Panics if num_cons, num_vars, or num_io are not powers of two, or if num_io > num_vars + #[inline] + pub(crate) fn check_regular_shape(&self) { + assert_eq!(self.num_cons.next_power_of_two(), self.num_cons); + assert_eq!(self.num_vars.next_power_of_two(), self.num_vars); + assert_eq!(self.num_io.next_power_of_two(), self.num_io); + assert!(self.num_io < self.num_vars); + } + pub fn multiply_vec( &self, z: &[G::Scalar], diff --git a/src/spartan/polynomial.rs b/src/spartan/polynomial.rs index 97255d65..ee8ecba5 100644 --- a/src/spartan/polynomial.rs +++ b/src/spartan/polynomial.rs @@ -31,7 +31,7 @@ impl EqPolynomial { /// Creates a new `EqPolynomial` from a vector of Scalars `r`. /// /// Each Scalar in `r` corresponds to a bit from the binary representation of an input value `e`. - pub fn new(r: Vec) -> Self { + pub const fn new(r: Vec) -> Self { EqPolynomial { r } } @@ -111,7 +111,7 @@ impl MultilinearPolynomial { } /// Returns the number of variables in the multilinear polynomial - pub fn get_num_vars(&self) -> usize { + pub const fn get_num_vars(&self) -> usize { self.num_vars } @@ -160,7 +160,7 @@ impl MultilinearPolynomial { (0..chis.len()) .into_par_iter() .map(|i| chis[i] * self.Z[i]) - .reduce(|| Scalar::ZERO, |x, y| x + y) + .sum() } /// Evaluates the polynomial with the given evaluations and point. 
diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 0011f463..57bec8fe 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -119,51 +119,34 @@ impl R1CSShapeSparkRepr { max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two() }; - let row = { - let mut r = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(r, _, _)| *r) - .collect::>(); - r.resize(N, 0usize); - r - }; + let (mut row, mut col) = (vec![0usize; N], vec![0usize; N]); - let col = { - let mut c = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(_, c, _)| *c) - .collect::>(); - c.resize(N, 0usize); - c - }; + for (i, (r, c, _)) in S.A.iter().chain(S.B.iter()).chain(S.C.iter()).enumerate() { + row[i] = *r; + col[i] = *c; + } let val_A = { - let mut val = S.A.iter().map(|(_, _, v)| *v).collect::>(); - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.A.iter().enumerate() { + val[i] = *v; + } val }; let val_B = { - // prepend zeros - let mut val = vec![G::Scalar::ZERO; S.A.len()]; - val.extend(S.B.iter().map(|(_, _, v)| *v).collect::>()); - // append zeros - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.B.iter().enumerate() { + val[S.A.len() + i] = *v; + } val }; let val_C = { - // prepend zeros - let mut val = vec![G::Scalar::ZERO; S.A.len() + S.B.len()]; - val.extend(S.C.iter().map(|(_, _, v)| *v).collect::>()); - // append zeros - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.C.iter().enumerate() { + val[S.A.len() + S.B.len() + i] = *v; + } val }; @@ -265,29 +248,30 @@ impl R1CSShapeSparkRepr { let mem_row = EqPolynomial::new(r_x_padded).evals(); let mem_col = { - let mut z = z.to_vec(); - z.resize(self.N, G::Scalar::ZERO); - z + let mut val = vec![G::Scalar::ZERO; self.N]; + for (i, v) in z.iter().enumerate() { + val[i] = *v; + } + val }; - let mut E_row = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(r, _, _)| mem_row[*r]) - .collect::>(); - - let mut E_col = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(_, c, _)| mem_col[*c]) - .collect::>(); + let (E_row, E_col) = { + let mut E_row = vec![mem_row[0]; self.N]; // we place mem_row[0] since resized row is appended with 0s + let mut E_col = vec![mem_col[0]; self.N]; - E_row.resize(self.N, mem_row[0]); // we place mem_row[0] since resized row is appended with 0s - E_col.resize(self.N, mem_col[0]); + for (i, (val_r, val_c)) in S + .A + .iter() + .chain(S.B.iter()) + .chain(S.C.iter()) + .map(|(r, c, _)| (mem_row[*r], mem_col[*c])) + .enumerate() + { + E_row[i] = val_r; + E_col[i] = val_c; + } + (E_row, E_col) + }; (mem_row, mem_col, E_row, E_col) } @@ -411,12 +395,10 @@ impl ProductSumcheckInstance { let poly_A = MultilinearPolynomial::new(EqPolynomial::new(rand_eq).evals()); let poly_B_vec = left_vec - .clone() .into_par_iter() .map(MultilinearPolynomial::new) .collect::>(); let poly_C_vec = right_vec - .clone() .into_par_iter() .map(MultilinearPolynomial::new) .collect::>(); @@ -477,43 +459,10 @@ impl SumcheckEngine for ProductSumcheckInstance { .zip(self.poly_C_vec.iter()) .zip(self.poly_D_vec.iter()) .map(|((poly_B, poly_C), poly_D)| { - let len = poly_B.len() / 2; // Make an iterator returning the contributions to the evaluations - let (eval_point_0, eval_point_2, eval_point_3) = (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], 
&poly_B[i], &poly_C[i], &poly_D[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i]; - let eval_point_2 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i]; - let eval_point_3 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - (eval_point_0, eval_point_2, eval_point_3) - }) - .reduce( - || (G::Scalar::ZERO, G::Scalar::ZERO, G::Scalar::ZERO), - |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), - ); + let (eval_point_0, eval_point_2, eval_point_3) = + SumcheckProof::::compute_eval_points_cubic(poly_A, poly_B, poly_C, poly_D, &comb_func); + vec![eval_point_0, eval_point_2, eval_point_3] }) .collect::>>() @@ -584,44 +533,10 @@ impl SumcheckEngine for OuterSumcheckInstance { poly_C_comp: &G::Scalar, poly_D_comp: &G::Scalar| -> G::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) }; - let len = poly_A.len() / 2; // Make an iterator returning the contributions to the evaluations - let (eval_point_0, eval_point_2, eval_point_3) = (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i]; - let eval_point_2 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i]; - let eval_point_3 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - (eval_point_0, eval_point_2, eval_point_3) - }) - .reduce( - || (G::Scalar::ZERO, G::Scalar::ZERO, G::Scalar::ZERO), - |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), - ); + let (eval_point_0, eval_point_2, eval_point_3) = + SumcheckProof::::compute_eval_points_cubic(poly_A, poly_B, poly_C, poly_D, &comb_func); vec![vec![eval_point_0, eval_point_2, eval_point_3]] } @@ -673,6 +588,8 @@ impl SumcheckEngine for InnerSumcheckInstance { -> G::Scalar { *poly_A_comp * *poly_B_comp * *poly_C_comp }; let len = poly_A.len() / 2; + // TODO: make this call a function in sumcheck.rs by writing an n-ary variant 
of crate::spartan::sumcheck::SumcheckProof::::compute_eval_points_cubic + // once #[feature(array_methods)] stabilizes (this n-ary variant would need array::each_ref) // Make an iterator returning the contributions to the evaluations let (eval_point_0, eval_point_2, eval_point_3) = (0..len) .into_par_iter() @@ -862,7 +779,7 @@ impl> RelaxedR1CSSNARK let mut e = claim; let mut r: Vec = Vec::new(); - let mut cubic_polys: Vec> = Vec::new(); + let mut cubic_polys: Vec> = Vec::new(); let num_rounds = mem.size().log_2(); for _i in 0..num_rounds { let mut evals: Vec> = Vec::new(); @@ -967,10 +884,7 @@ impl> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait G::Scalar { (0..M.len()) - .collect::>() - .par_iter() - .map(|&i| { + .into_par_iter() + .map(|i| { let (row, col, val) = M[i]; T_x[row] * T_y[col] * val }) - .reduce(|| G::Scalar::ZERO, |acc, x| acc + x) + .sum() }; let (T_x, T_y) = rayon::join( @@ -436,9 +432,8 @@ impl> RelaxedR1CSSNARKTrait>() - .par_iter() - .map(|&i| evaluate_with_table(M_vec[i], &T_x, &T_y)) + .into_par_iter() + .map(|i| evaluate_with_table(M_vec[i], &T_x, &T_y)) .collect() }; diff --git a/src/spartan/sumcheck.rs b/src/spartan/sumcheck.rs index 01d99c5f..fc47a56b 100644 --- a/src/spartan/sumcheck.rs +++ b/src/spartan/sumcheck.rs @@ -3,19 +3,18 @@ use super::polynomial::MultilinearPolynomial; use crate::errors::NovaError; use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait}; -use core::marker::PhantomData; -use ff::Field; +use ff::{Field, PrimeField}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound = "")] pub(crate) struct SumcheckProof { - compressed_polys: Vec>, + compressed_polys: Vec>, } impl SumcheckProof { - pub fn new(compressed_polys: Vec>) -> Self { + pub fn new(compressed_polys: Vec>) -> Self { Self { compressed_polys } } @@ -61,6 +60,34 @@ impl SumcheckProof { Ok((e, r)) } + #[inline] + pub(in crate::spartan) fn compute_eval_points_quadratic( + poly_A: &MultilinearPolynomial, + poly_B: &MultilinearPolynomial, + comb_func: &F, + ) -> (G::Scalar, G::Scalar) + where + F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, + { + let len = poly_A.len() / 2; + (0..len) + .into_par_iter() + .map(|i| { + // eval 0: bound_func is A(low) + let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; + let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; + let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point); + (eval_point_0, eval_point_2) + }) + .reduce( + || (G::Scalar::ZERO, G::Scalar::ZERO), + |a, b| (a.0 + b.0, a.1 + b.1), + ) + } + pub fn prove_quad( claim: &G::Scalar, num_rounds: usize, @@ -73,29 +100,12 @@ impl SumcheckProof { F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut r: Vec = Vec::new(); - let mut polys: Vec> = Vec::new(); + let mut polys: Vec> = Vec::new(); let mut claim_per_round = *claim; for _ in 0..num_rounds { let poly = { - let len = poly_A.len() / 2; - - // Make an iterator returning the contributions to the evaluations - let (eval_point_0, eval_point_2) = (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len 
+ i] - poly_B[i]; - let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point); - (eval_point_0, eval_point_2) - }) - .reduce( - || (G::Scalar::ZERO, G::Scalar::ZERO), - |a, b| (a.0 + b.0, a.1 + b.1), - ); + let (eval_point_0, eval_point_2) = + Self::compute_eval_points_quadratic(poly_A, poly_B, &comb_func); let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2]; UniPoly::from_evals(&evals) @@ -136,30 +146,18 @@ impl SumcheckProof { transcript: &mut G::TE, ) -> Result<(Self, Vec, (Vec, Vec)), NovaError> where - F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar, + F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut e = *claim; let mut r: Vec = Vec::new(); - let mut quad_polys: Vec> = Vec::new(); + let mut quad_polys: Vec> = Vec::new(); for _j in 0..num_rounds { let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); for (poly_A, poly_B) in poly_A_vec.iter().zip(poly_B_vec.iter()) { - let mut eval_point_0 = G::Scalar::ZERO; - let mut eval_point_2 = G::Scalar::ZERO; - - let len = poly_A.len() / 2; - for i in 0..len { - // eval 0: bound_func is A(low) - eval_point_0 += comb_func(&poly_A[i], &poly_B[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point); - } - + let (eval_point_0, eval_point_2) = + Self::compute_eval_points_quadratic(poly_A, poly_B, &comb_func); evals.push((eval_point_0, eval_point_2)); } @@ -193,6 +191,55 @@ impl SumcheckProof { Ok((SumcheckProof::new(quad_polys), r, claims_prod)) } + #[inline] + pub(in crate::spartan) fn compute_eval_points_cubic( + poly_A: &MultilinearPolynomial, + poly_B: &MultilinearPolynomial, + poly_C: &MultilinearPolynomial, + poly_D: &MultilinearPolynomial, + comb_func: &F, + ) -> (G::Scalar, G::Scalar, G::Scalar) + where + F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync, + { + let len = poly_A.len() / 2; + (0..len) + .into_par_iter() + .map(|i| { + // eval 0: bound_func is A(low) + let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; + let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; + let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; + let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i]; + let eval_point_2 = comb_func( + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + &poly_D_bound_point, + ); + + // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) + let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; + let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; + let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; + let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i]; + let eval_point_3 = comb_func( + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + &poly_D_bound_point, + ); + (eval_point_0, eval_point_2, eval_point_3) + }) + .reduce( + || (G::Scalar::ZERO, G::Scalar::ZERO, G::Scalar::ZERO), + |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), + ) + } + pub fn prove_cubic_with_additive_term( claim: &G::Scalar, num_rounds: usize, @@ -207,49 +254,14 @@ impl SumcheckProof { F: Fn(&G::Scalar, &G::Scalar, 
&G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut r: Vec = Vec::new(); - let mut polys: Vec> = Vec::new(); + let mut polys: Vec> = Vec::new(); let mut claim_per_round = *claim; for _ in 0..num_rounds { let poly = { - let len = poly_A.len() / 2; - // Make an iterator returning the contributions to the evaluations - let (eval_point_0, eval_point_2, eval_point_3) = (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i]; - let eval_point_2 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i]; - let eval_point_3 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - (eval_point_0, eval_point_2, eval_point_3) - }) - .reduce( - || (G::Scalar::ZERO, G::Scalar::ZERO, G::Scalar::ZERO), - |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), - ); + let (eval_point_0, eval_point_2, eval_point_3) = + Self::compute_eval_points_cubic(poly_A, poly_B, poly_C, poly_D, &comb_func); let evals = vec![ eval_point_0, @@ -291,25 +303,24 @@ impl SumcheckProof { // ax^2 + bx + c stored as vec![a,b,c] // ax^3 + bx^2 + cx + d stored as vec![a,b,c,d] #[derive(Debug)] -pub struct UniPoly { - coeffs: Vec, +pub struct UniPoly { + coeffs: Vec, } // ax^2 + bx + c stored as vec![a,c] // ax^3 + bx^2 + cx + d stored as vec![a,c,d] #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct CompressedUniPoly { - coeffs_except_linear_term: Vec, - _p: PhantomData, +pub struct CompressedUniPoly { + coeffs_except_linear_term: Vec, } -impl UniPoly { - pub fn from_evals(evals: &[G::Scalar]) -> Self { +impl UniPoly { + pub fn from_evals(evals: &[Scalar]) -> Self { // we only support degree-2 or degree-3 univariate polynomials assert!(evals.len() == 3 || evals.len() == 4); let coeffs = if evals.len() == 3 { // ax^2 + bx + c - let two_inv = G::Scalar::from(2).invert().unwrap(); + let two_inv = Scalar::from(2).invert().unwrap(); let c = evals[0]; let a = two_inv * (evals[2] - evals[1] - evals[1] + c); @@ -317,8 +328,8 @@ impl UniPoly { vec![c, b, a] } else { // ax^3 + bx^2 + cx + d - let two_inv = G::Scalar::from(2).invert().unwrap(); - let six_inv = G::Scalar::from(6).invert().unwrap(); + let two_inv = Scalar::from(2).invert().unwrap(); + let six_inv = Scalar::from(6).invert().unwrap(); let d = evals[0]; let a = six_inv @@ -341,18 +352,18 @@ impl UniPoly { self.coeffs.len() - 1 } - pub fn eval_at_zero(&self) -> G::Scalar { + pub fn eval_at_zero(&self) -> Scalar { self.coeffs[0] } - pub fn eval_at_one(&self) -> G::Scalar { + pub fn eval_at_one(&self) -> Scalar { (0..self.coeffs.len()) .into_par_iter() .map(|i| self.coeffs[i]) - .reduce(|| G::Scalar::ZERO, |a, b| a + b) + .sum() } - pub fn 
evaluate(&self, r: &G::Scalar) -> G::Scalar { + pub fn evaluate(&self, r: &Scalar) -> Scalar { let mut eval = self.coeffs[0]; let mut power = *r; for coeff in self.coeffs.iter().skip(1) { @@ -362,27 +373,26 @@ impl UniPoly { eval } - pub fn compress(&self) -> CompressedUniPoly { + pub fn compress(&self) -> CompressedUniPoly { let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat(); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); CompressedUniPoly { coeffs_except_linear_term, - _p: Default::default(), } } } -impl CompressedUniPoly { +impl CompressedUniPoly { // we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // linear_term = hint - 2 * constant_term - deg2 term - deg3 term - pub fn decompress(&self, hint: &G::Scalar) -> UniPoly { + pub fn decompress(&self, hint: &Scalar) -> UniPoly { let mut linear_term = *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; for i in 1..self.coeffs_except_linear_term.len() { linear_term -= self.coeffs_except_linear_term[i]; } - let mut coeffs: Vec = Vec::new(); + let mut coeffs: Vec = Vec::new(); coeffs.push(self.coeffs_except_linear_term[0]); coeffs.push(linear_term); coeffs.extend(&self.coeffs_except_linear_term[1..]); @@ -391,7 +401,7 @@ impl CompressedUniPoly { } } -impl TranscriptReprTrait for UniPoly { +impl TranscriptReprTrait for UniPoly { fn to_transcript_bytes(&self) -> Vec { let coeffs = self.compress().coeffs_except_linear_term; coeffs.as_slice().to_transcript_bytes() diff --git a/src/traits/commitment.rs b/src/traits/commitment.rs index 9b4725fc..4ac8349c 100644 --- a/src/traits/commitment.rs +++ b/src/traits/commitment.rs @@ -6,10 +6,12 @@ use crate::{ }; use core::{ fmt::Debug, - ops::{Add, AddAssign, Mul, MulAssign}, + ops::{Add, AddAssign}, }; use serde::{Deserialize, Serialize}; +use super::ScalarMul; + /// Defines basic operations on commitments pub trait CommitmentOps: Add + AddAssign @@ -31,12 +33,6 @@ impl CommitmentOpsOwned for T where { } -/// A helper trait for types implementing a multiplication of a commitment with a scalar -pub trait ScalarMul: Mul + MulAssign {} - -impl ScalarMul for T where T: Mul + MulAssign -{} - /// This trait defines the behavior of the commitment pub trait CommitmentTrait: Clone diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 5138cea8..91d8d320 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -41,8 +41,7 @@ pub trait Group: + for<'de> Deserialize<'de>; /// A type representing an element of the scalar field of the group - type Scalar: PrimeField - + PrimeFieldBits + type Scalar: PrimeFieldBits + PrimeFieldExt + Send + Sync @@ -236,11 +235,9 @@ pub trait PrimeFieldExt: PrimeField { impl> TranscriptReprTrait for &[T] { fn to_transcript_bytes(&self) -> Vec { - (0..self.len()) - .map(|i| self[i].to_transcript_bytes()) - .collect::>() - .into_iter() - .flatten() + self + .iter() + .flat_map(|t| t.to_transcript_bytes()) .collect::>() } }