diff --git a/keccak-air/Cargo.toml b/keccak-air/Cargo.toml index 2804df4c2..bfcf786ce 100644 --- a/keccak-air/Cargo.toml +++ b/keccak-air/Cargo.toml @@ -9,6 +9,7 @@ p3-air = { path = "../air" } p3-field = { path = "../field" } p3-matrix = { path = "../matrix" } p3-maybe-rayon = { path = "../maybe-rayon" } +p3-uni-stark = { path = "../uni-stark" } p3-util = { path = "../util" } tracing = "0.1.37" @@ -29,7 +30,6 @@ p3-poseidon = { path = "../poseidon" } p3-poseidon2 = { path = "../poseidon2" } p3-sha256 = { path = "../sha256", features = ["asm"] } p3-symmetric = { path = "../symmetric" } -p3-uni-stark = { path = "../uni-stark" } rand = "0.8.5" tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } diff --git a/keccak-air/src/air.rs b/keccak-air/src/air.rs index 2425207cc..a3905f617 100644 --- a/keccak-air/src/air.rs +++ b/keccak-air/src/air.rs @@ -3,6 +3,7 @@ use core::borrow::Borrow; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::AbstractField; use p3_matrix::Matrix; +use p3_uni_stark::MultiStageAir; use crate::columns::{KeccakCols, NUM_KECCAK_COLS}; use crate::constants::rc_value_bit; @@ -20,6 +21,8 @@ impl BaseAir for KeccakAir { } } +impl MultiStageAir for KeccakAir {} + impl Air for KeccakAir { #[inline] fn eval(&self, builder: &mut AB) { diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 7584f8a85..d033ef4d2 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -1,5 +1,6 @@ use alloc::vec::Vec; +use itertools::Itertools; use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::Field; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; @@ -7,17 +8,21 @@ use p3_matrix::stack::VerticalPair; use p3_matrix::Matrix; use tracing::instrument; +use crate::traits::MultistageAirBuilder; + #[instrument(name = "check constraints", skip_all)] pub(crate) fn check_constraints( air: &A, preprocessed: &RowMajorMatrix, - main: &RowMajorMatrix, - public_values: &Vec, + traces_by_stage: Vec<&RowMajorMatrix>, + public_values_by_stage: &Vec<&Vec>, + challenges: Vec<&Vec>, ) where F: Field, A: for<'a> Air>, { - let height = main.height(); + let num_stages = traces_by_stage.len(); + let height = traces_by_stage[0].height(); (0..height).for_each(|i| { let i_next = (i + 1) % height; @@ -29,18 +34,30 @@ pub(crate) fn check_constraints( RowMajorMatrixView::new_row(&*next_preprocessed), ); - let local = main.row_slice(i); - let next = main.row_slice(i_next); - let main = VerticalPair::new( - RowMajorMatrixView::new_row(&*local), - RowMajorMatrixView::new_row(&*next), - ); + let stages_local_next = traces_by_stage + .iter() + .map(|trace| { + let stage_local = trace.row_slice(i); + let stage_next = trace.row_slice(i_next); + (stage_local, stage_next) + }) + .collect_vec(); + + let traces_by_stage = (0..num_stages) + .map(|stage| { + VerticalPair::new( + RowMajorMatrixView::new_row(&*stages_local_next[stage].0), + RowMajorMatrixView::new_row(&*stages_local_next[stage].1), + ) + }) + .collect(); let mut builder = DebugConstraintBuilder { row_index: i, + challenges: challenges.clone(), preprocessed, - main, - public_values, + traces_by_stage, + public_values_by_stage, is_first_row: F::from_bool(i == 0), is_last_row: F::from_bool(i == height - 1), is_transition: F::from_bool(i != height - 1), @@ -56,8 +73,9 @@ pub(crate) fn check_constraints( pub struct DebugConstraintBuilder<'a, F: Field> { row_index: usize, preprocessed: VerticalPair, RowMajorMatrixView<'a, F>>, - main: VerticalPair, RowMajorMatrixView<'a, F>>, - public_values: &'a [F], + challenges: Vec<&'a Vec>, + traces_by_stage: Vec, RowMajorMatrixView<'a, F>>>, + public_values_by_stage: &'a [&'a Vec], is_first_row: F, is_last_row: F, is_transition: F, @@ -89,7 +107,7 @@ where } fn main(&self) -> Self::M { - self.main + self.traces_by_stage[0] } fn assert_zero>(&mut self, x: I) { @@ -115,8 +133,8 @@ where impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> { type PublicVar = Self::F; - fn public_values(&self) -> &[Self::F] { - self.public_values + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) } } @@ -125,3 +143,19 @@ impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { self.preprocessed } } + +impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> { + type Challenge = Self::Expr; + + fn stage_public_values(&self, stage: usize) -> &[Self::F] { + self.public_values_by_stage[stage] + } + + fn stage_trace(&self, stage: usize) -> Self::M { + self.traces_by_stage[stage] + } + + fn stage_challenges(&self, stage: usize) -> &[Self::Expr] { + self.challenges[stage] + } +} diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index ea1563254..fb280ee37 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -5,13 +5,15 @@ use p3_field::AbstractField; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; +use crate::traits::MultistageAirBuilder; use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; #[derive(Debug)] pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { - pub main: RowMajorMatrix>, + pub challenges: Vec>>, + pub traces_by_stage: Vec>>, pub preprocessed: RowMajorMatrix>, - pub public_values: &'a Vec>, + pub public_values_by_stage: &'a Vec>>, pub is_first_row: PackedVal, pub is_last_row: PackedVal, pub is_transition: PackedVal, @@ -23,9 +25,10 @@ type ViewPair<'a, T> = VerticalPair, RowMajorMatrixVie #[derive(Debug)] pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { - pub main: ViewPair<'a, SC::Challenge>, + pub challenges: Vec>>, + pub traces_by_stage: Vec>, pub preprocessed: ViewPair<'a, SC::Challenge>, - pub public_values: &'a Vec>, + pub public_values_by_stage: Vec<&'a Vec>>, pub is_first_row: SC::Challenge, pub is_last_row: SC::Challenge, pub is_transition: SC::Challenge, @@ -40,7 +43,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { type M = RowMajorMatrix>; fn main(&self) -> Self::M { - self.main.clone() + self.traces_by_stage[0].clone() } fn is_first_row(&self) -> Self::Expr { @@ -67,10 +70,25 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { } impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> { - type PublicVar = Self::F; + type PublicVar = Val; - fn public_values(&self) -> &[Self::F] { - self.public_values + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) + } +} + +impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for ProverConstraintFolder<'a, SC> { + type Challenge = Val; + + fn stage_trace(&self, stage: usize) -> ::M { + self.traces_by_stage[stage].clone() + } + + fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { + &self.challenges[stage] + } + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + &self.public_values_by_stage[stage] } } @@ -87,7 +105,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> type M = ViewPair<'a, SC::Challenge>; fn main(&self) -> Self::M { - self.main + self.traces_by_stage[0] } fn is_first_row(&self) -> Self::Expr { @@ -114,10 +132,25 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> } impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> { - type PublicVar = Self::F; + type PublicVar = Val; - fn public_values(&self) -> &[Self::F] { - self.public_values + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) + } +} + +impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for VerifierConstraintFolder<'a, SC> { + type Challenge = Val; + + fn stage_trace(&self, stage: usize) -> ::M { + self.traces_by_stage[stage] + } + + fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { + &self.challenges[stage] + } + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + self.public_values_by_stage[stage] } } diff --git a/uni-stark/src/lib.rs b/uni-stark/src/lib.rs index 488c8f266..6a0d401e6 100644 --- a/uni-stark/src/lib.rs +++ b/uni-stark/src/lib.rs @@ -11,6 +11,7 @@ mod prover; mod symbolic_builder; mod symbolic_expression; mod symbolic_variable; +mod traits; mod verifier; mod zerofier_coset; @@ -26,5 +27,6 @@ pub use prover::*; pub use symbolic_builder::*; pub use symbolic_expression::*; pub use symbolic_variable::*; +pub use traits::*; pub use verifier::*; pub use zerofier_coset::*; diff --git a/uni-stark/src/proof.rs b/uni-stark/src/proof.rs index 79b8b96a5..ce5681042 100644 --- a/uni-stark/src/proof.rs +++ b/uni-stark/src/proof.rs @@ -1,9 +1,10 @@ use alloc::vec::Vec; use p3_commit::Pcs; +use p3_matrix::dense::RowMajorMatrix; use serde::{Deserialize, Serialize}; -use crate::StarkGenericConfig; +use crate::{StarkGenericConfig, Val}; type Com = <::Pcs as Pcs< ::Challenge, @@ -29,7 +30,7 @@ pub struct Proof { #[derive(Debug, Serialize, Deserialize)] pub struct Commitments { - pub(crate) trace: Com, + pub(crate) traces_by_stage: Vec, pub(crate) quotient_chunks: Com, } @@ -37,8 +38,8 @@ pub struct Commitments { pub struct OpenedValues { pub(crate) preprocessed_local: Vec, pub(crate) preprocessed_next: Vec, - pub(crate) trace_local: Vec, - pub(crate) trace_next: Vec, + pub(crate) traces_by_stage_local: Vec>, + pub(crate) traces_by_stage_next: Vec>, pub(crate) quotient_chunks: Vec>, } @@ -52,3 +53,12 @@ pub struct StarkProvingKey { pub struct StarkVerifyingKey { pub preprocessed_commit: Com, } + +pub struct ProcessedStage { + pub(crate) commitment: Com, + pub(crate) prover_data: PcsProverData, + pub(crate) challenge_values: Vec>, + pub(crate) public_values: Vec>, + #[cfg(debug_assertions)] + pub(crate) trace: RowMajorMatrix>, +} diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 5c70568c1..06d5c6ca8 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -1,6 +1,7 @@ +use alloc::borrow::ToOwned; use alloc::vec; use alloc::vec::Vec; -use core::iter; +use core::iter::{self, once}; use itertools::{izip, Itertools}; use p3_air::Air; @@ -14,11 +15,20 @@ use p3_util::log2_strict_usize; use tracing::{info_span, instrument}; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; +use crate::traits::MultiStageAir; use crate::{ - Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, Proof, ProverConstraintFolder, - StarkGenericConfig, StarkProvingKey, Val, + Commitments, Domain, OpenedValues, PackedChallenge, PackedVal, ProcessedStage, Proof, + ProverConstraintFolder, StarkGenericConfig, StarkProvingKey, Val, }; +struct UnusedCallback; + +impl NextStageTraceCallback for UnusedCallback { + fn compute_stage(&self, _: u32, _: &[Val]) -> CallbackResult> { + unreachable!() + } +} + #[instrument(skip_all)] #[allow(clippy::multiple_bound_locations)] // cfg not supported in where clauses? pub fn prove< @@ -29,14 +39,23 @@ pub fn prove< config: &SC, air: &A, challenger: &mut SC::Challenger, - trace: RowMajorMatrix>, + main_trace: RowMajorMatrix>, public_values: &Vec>, ) -> Proof where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { - prove_with_key(config, None, air, challenger, trace, public_values) + prove_with_key( + config, + None, + air, + challenger, + main_trace, + &UnusedCallback, + public_values, + ) } #[instrument(skip_all)] @@ -45,39 +64,32 @@ pub fn prove_with_key< SC, #[cfg(debug_assertions)] A: for<'a> Air>>, #[cfg(not(debug_assertions))] A, + C, >( config: &SC, proving_key: Option<&StarkProvingKey>, air: &A, challenger: &mut SC::Challenger, - trace: RowMajorMatrix>, - public_values: &Vec>, + stage_0_trace: RowMajorMatrix>, + next_stage_trace_callback: &C, + #[allow(clippy::ptr_arg)] + // we do not use `&[Val]` in order to keep the same API + stage_0_public_values: &Vec>, ) -> Proof where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, + C: NextStageTraceCallback, { - #[cfg(debug_assertions)] - crate::check_constraints::check_constraints( - air, - &air.preprocessed_trace() - .unwrap_or(RowMajorMatrix::new(vec![], 0)), - &trace, - public_values, - ); - - let degree = trace.height(); + let degree = stage_0_trace.height(); let log_degree = log2_strict_usize(degree); - let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); - let quotient_degree = 1 << log_quotient_degree; + let stage_count = >>::stage_count(air); let pcs = config.pcs(); let trace_domain = pcs.natural_domain_for_degree(degree); - let (trace_commit, trace_data) = - info_span!("commit to trace data").in_scope(|| pcs.commit(vec![(trace_domain, trace)])); - // Observe the instance. challenger.observe(Val::::from_canonical_usize(log_degree)); // TODO: Might be best practice to include other instance data here; see verifier comment. @@ -85,8 +97,82 @@ where if let Some(proving_key) = proving_key { challenger.observe(proving_key.preprocessed_commit.clone()) }; - challenger.observe(trace_commit.clone()); - challenger.observe_slice(public_values); + + let mut state: ProverState = ProverState::new(pcs, trace_domain, challenger); + let mut stage = Stage { + trace: stage_0_trace, + challenge_count: >>::stage_challenge_count(air, 0), + public_values: stage_0_public_values.to_owned(), + }; + + assert!(stage_count >= 1); + // generate all stages starting from the second one based on the witgen callback + for stage_id in 1..stage_count { + state = state.run_stage(stage); + // get the challenges drawn at the end of the previous stage + let local_challenges = &state.processed_stages.last().unwrap().challenge_values; + let CallbackResult { + trace, + public_values, + challenges, + } = next_stage_trace_callback.compute_stage(stage_id as u32, local_challenges); + // replace the challenges of the last stage with the ones received + state.processed_stages.last_mut().unwrap().challenge_values = challenges; + // go to the next stage + stage = Stage { + trace, + challenge_count: >>::stage_challenge_count( + air, + stage_id as u32, + ), + public_values, + }; + } + + // run the last stage + state = state.run_stage(stage); + + // sanity check that the last stage did not create any challenges + assert!(state + .processed_stages + .last() + .unwrap() + .challenge_values + .is_empty()); + // sanity check that we processed as many stages as expected + assert_eq!(state.processed_stages.len(), stage_count); + + // with the witness complete, check the constraints + #[cfg(debug_assertions)] + crate::check_constraints::check_constraints( + air, + &air.preprocessed_trace() + .unwrap_or(RowMajorMatrix::new(Default::default(), 0)), + state.processed_stages.iter().map(|s| &s.trace).collect(), + &state + .processed_stages + .iter() + .map(|s| &s.public_values) + .collect(), + state + .processed_stages + .iter() + .map(|s| &s.challenge_values) + .collect(), + ); + + let log_quotient_degree = get_log_quotient_degree::, A>( + air, + &state + .processed_stages + .iter() + .map(|s| s.public_values.len()) + .collect::>(), + ); + let quotient_degree = 1 << log_quotient_degree; + + let challenger = &mut state.challenger; + let alpha: SC::Challenge = challenger.sample_ext_element(); let quotient_domain = @@ -96,15 +182,32 @@ where pcs.get_evaluations_on_domain(&proving_key.preprocessed_data, 0, quotient_domain) }); - let trace_on_quotient_domain = pcs.get_evaluations_on_domain(&trace_data, 0, quotient_domain); + let traces_on_quotient_domain = state + .processed_stages + .iter() + .map(|s| pcs.get_evaluations_on_domain(&s.prover_data, 0, quotient_domain)) + .collect(); + + let challenges = state + .processed_stages + .iter() + .map(|stage| stage.challenge_values.clone()) + .collect(); + + let public_values_by_stage = state + .processed_stages + .iter() + .map(|stage| stage.public_values.clone()) + .collect(); let quotient_values = quotient_values( air, - public_values, + &public_values_by_stage, trace_domain, quotient_domain, preprocessed_on_quotient_domain, - trace_on_quotient_domain, + traces_on_quotient_domain, + challenges, alpha, ); let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base(); @@ -116,7 +219,11 @@ where challenger.observe(quotient_commit.clone()); let commitments = Commitments { - trace: trace_commit, + traces_by_stage: state + .processed_stages + .iter() + .map(|s| s.commitment.clone()) + .collect(), quotient_chunks: quotient_commit, }; @@ -132,14 +239,20 @@ where }) .into_iter(), ) - .chain([ - (&trace_data, vec![vec![zeta, zeta_next]]), - ( - "ient_data, - // open every chunk at zeta - (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), - ), - ]) + .chain( + state + .processed_stages + .iter() + .map(|processed_stage| { + (&processed_stage.prover_data, vec![vec![zeta, zeta_next]]) + }) + .collect_vec(), + ) + .chain(once(( + "ient_data, + // open every chunk at zeta + (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), + ))) .collect_vec(), challenger, ); @@ -155,12 +268,17 @@ where (vec![], vec![]) }; - // get values for the trace - let value = opened_values.next().unwrap(); - assert_eq!(value.len(), 1); - assert_eq!(value[0].len(), 2); - let trace_local = value[0][0].clone(); - let trace_next = value[0][1].clone(); + // get values for the traces + let (traces_by_stage_local, traces_by_stage_next): (Vec<_>, Vec<_>) = state + .processed_stages + .iter() + .map(|_| { + let value = opened_values.next().unwrap(); + assert_eq!(value.len(), 1); + assert_eq!(value[0].len(), 2); + (value[0][0].clone(), value[0][1].clone()) + }) + .unzip(); // get values for the quotient let value = opened_values.next().unwrap(); @@ -168,8 +286,8 @@ where let quotient_chunks = value.iter().map(|v| v[0].clone()).collect_vec(); let opened_values = OpenedValues { - trace_local, - trace_next, + traces_by_stage_local, + traces_by_stage_next, preprocessed_local, preprocessed_next, quotient_chunks, @@ -182,19 +300,21 @@ where } } +#[allow(clippy::too_many_arguments)] #[instrument(name = "compute quotient polynomial", skip_all)] -fn quotient_values( +fn quotient_values<'a, SC, A, Mat>( air: &A, - public_values: &Vec>, + public_values_by_stage: &'a Vec>>, trace_domain: Domain, quotient_domain: Domain, preprocessed_on_quotient_domain: Option, - trace_on_quotient_domain: Mat, + traces_on_quotient_domain: Vec, + challenges: Vec>>, alpha: SC::Challenge, ) -> Vec where SC: StarkGenericConfig, - A: for<'a> Air>, + A: Air>, Mat: Matrix> + Sync, { let quotient_size = quotient_domain.size(); @@ -202,7 +322,6 @@ where .as_ref() .map(Matrix::width) .unwrap_or_default(); - let width = trace_on_quotient_domain.width(); let mut sels = trace_domain.selectors_on_coset(quotient_domain); let qdb = log2_strict_usize(quotient_domain.size()) - log2_strict_usize(trace_domain.size()); @@ -241,19 +360,27 @@ where preprocessed_width, ); - let main = RowMajorMatrix::new( - iter::empty() - .chain(trace_on_quotient_domain.vertically_packed_row(i_start)) - .chain(trace_on_quotient_domain.vertically_packed_row(i_start + next_step)) - .collect_vec(), - width, - ); + let traces_by_stage = traces_on_quotient_domain + .iter() + .map(|trace_on_quotient_domain| { + RowMajorMatrix::new( + iter::empty() + .chain(trace_on_quotient_domain.vertically_packed_row(i_start)) + .chain( + trace_on_quotient_domain.vertically_packed_row(i_start + next_step), + ) + .collect_vec(), + trace_on_quotient_domain.width(), + ) + }) + .collect(); let accumulator = PackedChallenge::::zero(); let mut folder = ProverConstraintFolder { + challenges: challenges.clone(), + traces_by_stage, preprocessed, - main, - public_values, + public_values_by_stage, is_first_row, is_last_row, is_transition, @@ -275,3 +402,85 @@ where }) .collect() } + +pub struct ProverState<'a, SC: StarkGenericConfig> { + pub(crate) processed_stages: Vec>, + pub(crate) challenger: &'a mut SC::Challenger, + pub(crate) pcs: &'a ::Pcs, + pub(crate) trace_domain: Domain, +} + +impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> { + pub(crate) fn new( + pcs: &'a ::Pcs, + trace_domain: Domain, + challenger: &'a mut ::Challenger, + ) -> Self { + Self { + processed_stages: Default::default(), + challenger, + pcs, + trace_domain, + } + } + + pub(crate) fn run_stage(mut self, stage: Stage) -> Self { + #[cfg(debug_assertions)] + let trace = stage.trace.clone(); + + // commit to the trace for this stage + let (commitment, prover_data) = info_span!("commit to stage {stage} data") + .in_scope(|| self.pcs.commit(vec![(self.trace_domain, stage.trace)])); + + self.challenger.observe(commitment.clone()); + // observe the public inputs for this stage + self.challenger.observe_slice(&stage.public_values); + + let challenge_values = (0..stage.challenge_count) + .map(|_| self.challenger.sample()) + .collect(); + + self.processed_stages.push(ProcessedStage { + public_values: stage.public_values, + prover_data, + commitment, + challenge_values, + #[cfg(debug_assertions)] + trace, + }); + self + } +} + +pub struct Stage { + /// the witness for this stage + pub(crate) trace: RowMajorMatrix>, + /// the number of challenges to be drawn at the end of this stage + pub(crate) challenge_count: usize, + /// the public values for this stage + pub(crate) public_values: Vec>, +} + +pub struct CallbackResult { + /// the trace for this stage + pub(crate) trace: RowMajorMatrix, + /// the values of the public inputs of this stage + pub(crate) public_values: Vec, + /// the values of the challenges drawn at the previous stage + pub(crate) challenges: Vec, +} + +impl CallbackResult { + pub fn new(trace: RowMajorMatrix, public_values: Vec, challenges: Vec) -> Self { + Self { + trace, + public_values, + challenges, + } + } +} + +pub trait NextStageTraceCallback { + /// Computes the stage number `trace_stage` based on `challenges` drawn at the end of stage `trace_stage - 1` + fn compute_stage(&self, stage: u32, challenges: &[Val]) -> CallbackResult>; +} diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 813f60032..6516b48e5 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; @@ -9,16 +9,17 @@ use tracing::instrument; use crate::symbolic_expression::SymbolicExpression; use crate::symbolic_variable::SymbolicVariable; +use crate::traits::{MultiStageAir, MultistageAirBuilder}; use crate::Entry; #[instrument(name = "infer log of constraint degree", skip_all)] -pub fn get_log_quotient_degree(air: &A, num_public_values: usize) -> usize +pub fn get_log_quotient_degree(air: &A, public_values_counts: &[usize]) -> usize where F: Field, - A: Air>, + A: MultiStageAir>, { // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. - let constraint_degree = get_max_constraint_degree(air, num_public_values).max(2); + let constraint_degree = get_max_constraint_degree(air, public_values_counts).max(2); // The quotient's actual degree is approximately (max_constraint_degree - 1) n, // where subtracting 1 comes from division by the zerofier. @@ -27,12 +28,12 @@ where } #[instrument(name = "infer constraint degree", skip_all, level = "debug")] -pub fn get_max_constraint_degree(air: &A, num_public_values: usize) -> usize +pub fn get_max_constraint_degree(air: &A, public_values_counts: &[usize]) -> usize where F: Field, - A: Air>, + A: MultiStageAir>, { - get_symbolic_constraints(air, num_public_values) + get_symbolic_constraints(air, public_values_counts) .iter() .map(|c| c.degree_multiple()) .max() @@ -42,14 +43,24 @@ where #[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")] pub fn get_symbolic_constraints( air: &A, - num_public_values: usize, + public_values_counts: &[usize], ) -> Vec> where F: Field, - A: Air>, + A: MultiStageAir>, { - let mut builder = - SymbolicAirBuilder::new(air.preprocessed_width(), air.width(), num_public_values); + let widths: Vec<_> = (0..air.stage_count()) + .map(|i| air.stage_trace_width(i as u32)) + .collect(); + let challenges: Vec<_> = (0..air.stage_count()) + .map(|i| air.stage_challenge_count(i as u32)) + .collect(); + let mut builder = SymbolicAirBuilder::new( + air.preprocessed_width(), + &widths, + public_values_counts, + challenges, + ); air.eval(&mut builder); builder.constraints() } @@ -57,14 +68,20 @@ where /// An `AirBuilder` for evaluating constraints symbolically, and recording them for later use. #[derive(Debug)] pub struct SymbolicAirBuilder { + challenges: Vec>>, preprocessed: RowMajorMatrix>, - main: RowMajorMatrix>, - public_values: Vec>, + traces_by_stage: Vec>>, + public_values_by_stage: Vec>>, constraints: Vec>, } impl SymbolicAirBuilder { - pub(crate) fn new(preprocessed_width: usize, width: usize, num_public_values: usize) -> Self { + pub(crate) fn new( + preprocessed_width: usize, + stage_widths: &[usize], + public_value_counts: &[usize], + challenges: Vec, + ) -> Self { let prep_values = [0, 1] .into_iter() .flat_map(|offset| { @@ -72,19 +89,50 @@ impl SymbolicAirBuilder { .map(move |index| SymbolicVariable::new(Entry::Preprocessed { offset }, index)) }) .collect(); - let main_values = [0, 1] - .into_iter() - .flat_map(|offset| { - (0..width).map(move |index| SymbolicVariable::new(Entry::Main { offset }, index)) + let traces_by_stage = stage_widths + .iter() + .map(|width| { + let values = [0, 1] + .into_iter() + .flat_map(|offset| { + (0..*width) + .map(move |index| SymbolicVariable::new(Entry::Main { offset }, index)) + }) + .collect(); + RowMajorMatrix::new(values, *width) }) .collect(); - let public_values = (0..num_public_values) - .map(move |index| SymbolicVariable::new(Entry::Public, index)) + let mut challenge_index = 0; + let challenges = challenges + .iter() + .map(|count| { + (0..*count) + .map(|_| { + let res = SymbolicVariable::new(Entry::Challenge, challenge_index); + challenge_index += 1; + res + }) + .collect() + }) + .collect(); + let mut public_value_index = 0; + let public_values_by_stage = public_value_counts + .iter() + .map(|count| { + (0..*count) + .map(|_| { + let res = SymbolicVariable::new(Entry::Public, public_value_index); + public_value_index += 1; + res + }) + .collect() + }) .collect(); Self { + challenges, preprocessed: RowMajorMatrix::new(prep_values, preprocessed_width), - main: RowMajorMatrix::new(main_values, width), - public_values, + traces_by_stage, + public_values_by_stage, constraints: vec![], } } @@ -101,7 +149,7 @@ impl AirBuilder for SymbolicAirBuilder { type M = RowMajorMatrix; fn main(&self) -> Self::M { - self.main.clone() + self.traces_by_stage[0].clone() } fn is_first_row(&self) -> Self::Expr { @@ -127,8 +175,25 @@ impl AirBuilder for SymbolicAirBuilder { impl AirBuilderWithPublicValues for SymbolicAirBuilder { type PublicVar = SymbolicVariable; + fn public_values(&self) -> &[Self::PublicVar] { - &self.public_values + self.stage_public_values(0) + } +} + +impl MultistageAirBuilder for SymbolicAirBuilder { + type Challenge = Self::Var; + + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + &self.public_values_by_stage[stage] + } + + fn stage_trace(&self, stage: usize) -> Self::M { + self.traces_by_stage[stage].clone() + } + + fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { + &self.challenges[stage] } } diff --git a/uni-stark/src/traits.rs b/uni-stark/src/traits.rs new file mode 100644 index 000000000..30d159481 --- /dev/null +++ b/uni-stark/src/traits.rs @@ -0,0 +1,38 @@ +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues}; + +pub trait MultistageAirBuilder: AirBuilderWithPublicValues { + type Challenge: Clone + Into; + + /// Traces from each stage. + fn stage_trace(&self, stage: usize) -> Self::M; + + /// Challenges from each stage, drawn from the base field + fn stage_challenges(&self, stage: usize) -> &[Self::Challenge]; + + /// Public values for each stage + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + match stage { + 0 => self.public_values(), + _ => unimplemented!(), + } + } +} + +pub trait MultiStageAir: Air { + fn stage_count(&self) -> usize { + 1 + } + + /// The number of trace columns in this stage + fn stage_trace_width(&self, stage: u32) -> usize { + match stage { + 0 => self.width(), + _ => unimplemented!(), + } + } + + /// The number of challenges produced at the end of each stage + fn stage_challenge_count(&self, _stage: u32) -> usize { + 0 + } +} diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 36c242b4f..8dece8b68 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -2,8 +2,8 @@ use alloc::vec; use alloc::vec::Vec; use core::iter; -use itertools::Itertools; -use p3_air::{Air, BaseAir}; +use itertools::{izip, Itertools}; +use p3_air::BaseAir; use p3_challenger::{CanObserve, CanSample, FieldChallenger}; use p3_commit::{Pcs, PolynomialSpace}; use p3_field::{AbstractExtensionField, AbstractField, Field}; @@ -12,6 +12,7 @@ use p3_matrix::stack::VerticalPair; use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; +use crate::traits::MultiStageAir; use crate::{ PcsError, Proof, StarkGenericConfig, StarkVerifyingKey, Val, VerifierConstraintFolder, }; @@ -26,9 +27,10 @@ pub fn verify( ) -> Result<(), VerificationError>> where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { - verify_with_key(config, None, air, challenger, proof, public_values) + verify_with_key(config, None, air, challenger, proof, vec![public_values]) } #[instrument(skip_all)] @@ -38,11 +40,12 @@ pub fn verify_with_key( air: &A, challenger: &mut SC::Challenger, proof: &Proof, - public_values: &Vec>, + public_values_by_stage: Vec<&Vec>>, ) -> Result<(), VerificationError>> where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { let Proof { commitments, @@ -52,8 +55,18 @@ where } = proof; let degree = 1 << degree_bits; - let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); + let log_quotient_degree = get_log_quotient_degree::, A>( + air, + &public_values_by_stage + .iter() + .map(|values| values.len()) + .collect::>(), + ); let quotient_degree = 1 << log_quotient_degree; + let stage_count = proof.commitments.traces_by_stage.len(); + let challenge_counts: Vec = (0..stage_count) + .map(|i| >>::stage_challenge_count(air, i as u32)) + .collect(); let pcs = config.pcs(); let trace_domain = pcs.natural_domain_for_degree(degree); @@ -61,17 +74,32 @@ where trace_domain.create_disjoint_domain(1 << (degree_bits + log_quotient_degree)); let quotient_chunks_domains = quotient_domain.split_domains(quotient_degree); - let air_width = >>::width(air); + let air_widths = (0..stage_count) + .map(|stage| { + >>>::stage_trace_width(air, stage as u32) + }) + .collect::>(); let air_fixed_width = >>::preprocessed_width(air); let valid_shape = opened_values.preprocessed_local.len() == air_fixed_width && opened_values.preprocessed_next.len() == air_fixed_width - && opened_values.trace_local.len() == air_width - && opened_values.trace_next.len() == air_width + && opened_values + .traces_by_stage_local + .iter() + .zip(&air_widths) + .all(|(stage, air_width)| stage.len() == *air_width) + && opened_values + .traces_by_stage_next + .iter() + .zip(&air_widths) + .all(|(stage, air_width)| stage.len() == *air_width) && opened_values.quotient_chunks.len() == quotient_degree && opened_values .quotient_chunks .iter() - .all(|qc| qc.len() == >>::D); + .all(|qc| qc.len() == >>::D) + && public_values_by_stage.len() == stage_count + && challenge_counts.len() == stage_count; + if !valid_shape { return Err(VerificationError::InvalidProofShape); } @@ -87,8 +115,19 @@ where if let Some(verifying_key) = verifying_key { challenger.observe(verifying_key.preprocessed_commit.clone()) }; - challenger.observe(commitments.trace.clone()); - challenger.observe_slice(public_values); + + let mut challenges = vec![]; + + commitments + .traces_by_stage + .iter() + .zip(&public_values_by_stage) + .zip(challenge_counts) + .for_each(|((commitment, public_values), challenge_count)| { + challenger.observe(commitment.clone()); + challenger.observe_slice(public_values); + challenges.push((0..challenge_count).map(|_| challenger.sample()).collect()); + }); let alpha: SC::Challenge = challenger.sample_ext_element(); challenger.observe(commitments.quotient_chunks.clone()); @@ -113,26 +152,34 @@ where }) .into_iter(), ) - .chain([ - ( - commitments.trace.clone(), - vec![( - trace_domain, - vec![ - (zeta, opened_values.trace_local.clone()), - (zeta_next, opened_values.trace_next.clone()), - ], - )], - ), - ( - commitments.quotient_chunks.clone(), - quotient_chunks_domains - .iter() - .zip(&opened_values.quotient_chunks) - .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) - .collect_vec(), - ), - ]) + .chain( + izip!( + commitments.traces_by_stage.iter(), + opened_values.traces_by_stage_local.iter(), + opened_values.traces_by_stage_next.iter() + ) + .map(|(trace_commit, opened_local, opened_next)| { + ( + trace_commit.clone(), + vec![( + trace_domain, + vec![ + (zeta, opened_local.clone()), + (zeta_next, opened_next.clone()), + ], + )], + ) + }) + .collect_vec(), + ) + .chain([( + commitments.quotient_chunks.clone(), + quotient_chunks_domains + .iter() + .zip(&opened_values.quotient_chunks) + .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) + .collect_vec(), + )]) .collect_vec(), opening_proof, challenger, @@ -174,15 +221,23 @@ where RowMajorMatrixView::new_row(&opened_values.preprocessed_next), ); - let main = VerticalPair::new( - RowMajorMatrixView::new_row(&opened_values.trace_local), - RowMajorMatrixView::new_row(&opened_values.trace_next), - ); + let traces_by_stage = opened_values + .traces_by_stage_local + .iter() + .zip(opened_values.traces_by_stage_next.iter()) + .map(|(trace_local, trace_next)| { + VerticalPair::new( + RowMajorMatrixView::new_row(trace_local), + RowMajorMatrixView::new_row(trace_next), + ) + }) + .collect::>>(); let mut folder = VerifierConstraintFolder { + challenges, preprocessed, - main, - public_values, + traces_by_stage, + public_values_by_stage, is_first_row: sels.is_first_row, is_last_row: sels.is_last_row, is_transition: sels.is_transition, diff --git a/uni-stark/tests/fib_air.rs b/uni-stark/tests/fib_air.rs index b6d456afa..e22a23525 100644 --- a/uni-stark/tests/fib_air.rs +++ b/uni-stark/tests/fib_air.rs @@ -13,7 +13,7 @@ use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use p3_uni_stark::{prove, verify, StarkConfig}; +use p3_uni_stark::{prove, verify, MultiStageAir, StarkConfig}; use rand::thread_rng; /// For testing the public values feature @@ -56,6 +56,8 @@ impl Air for FibonacciAir { } } +impl MultiStageAir for FibonacciAir {} + pub fn generate_trace_rows(a: u64, b: u64, n: usize) -> RowMajorMatrix { assert!(n.is_power_of_two()); diff --git a/uni-stark/tests/mul_air.rs b/uni-stark/tests/mul_air.rs index c5fe81ede..ae03b1178 100644 --- a/uni-stark/tests/mul_air.rs +++ b/uni-stark/tests/mul_air.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use std::marker::PhantomData; use itertools::Itertools; -use p3_air::{Air, AirBuilder, BaseAir}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_challenger::{DuplexChallenger, HashChallenger, SerializingChallenger32}; use p3_circle::CirclePcs; @@ -21,7 +21,7 @@ use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; use p3_symmetric::{ CompressionFunctionFromHasher, PaddingFreeSponge, SerializingHasher32, TruncatedPermutation, }; -use p3_uni_stark::{prove, verify, StarkConfig, StarkGenericConfig, Val}; +use p3_uni_stark::{prove, verify, MultiStageAir, StarkConfig, StarkGenericConfig, Val}; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; @@ -116,6 +116,8 @@ impl Air for MulAir { } } +impl MultiStageAir for MulAir {} + fn do_test( config: SC, air: MulAir,