Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update aggregator doc #1365

Merged
merged 10 commits into from
Jul 16, 2024
128 changes: 96 additions & 32 deletions aggregator/src/recursion/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ use sv_halo2_base::{
QuantumCell::Existing,
};

use crate::param::ConfigParams as BatchCircuitConfigParams;
use crate::param::ConfigParams as RecursionCircuitConfigParams;

use super::*;

/// Convenience type to represent the verifying key.
type Svk = KzgSuccinctVerifyingKey<G1Affine>;

/// Convenience type to represent the polynomial commitment scheme.
type Pcs = Kzg<Bn256, Bdfg21>;

/// Convenience type to represent the accumulation scheme for accumulating proofs from multiple
/// SNARKs.
type As = KzgAs<Pcs>;

/// Select condition ? LHS : RHS.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should add some more detail: both this and the below entry are helpers to construct corresponding circuits inside of the ecc_chip which is maintained in the loader argument.

select_accumulator select the lhs or rhs according to the condition and return the cells for ec point represent the result of selection.

accumulate put all the accumulators in argument into circuit and accumulate them into a new pair of ec point and return it.

It may be better also to note that change the size of the vec of accumulators argument would lead to a new circuit (vk changed).

fn select_accumulator<'a>(
loader: &Rc<Halo2Loader<'a>>,
condition: &AssignedValue<Fr>,
Expand All @@ -56,6 +63,7 @@ fn select_accumulator<'a>(
))
}

/// Accumulate a value into the current accumulator.
fn accumulate<'a>(
loader: &Rc<Halo2Loader<'a>>,
accumulators: Vec<KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>>,
Expand All @@ -68,21 +76,39 @@ fn accumulate<'a>(

#[derive(Clone)]
pub struct RecursionCircuit<ST> {
/// The verifying key for the circuit.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basicaly it is not the verifying key but the common params of the recursion circuit. The so called "vk" part which identify the circuit is passed via preprocessed_digest PI.

svk: Svk,
/// The default accumulator to initialise the circuit.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the accumulator that recursion circuit begin with when the "previous" snark is not avaliable (aka. round = 0)

default_accumulator: KzgAccumulator<G1Affine, NativeLoader>,
/// The SNARK witness from the k-th BatchCircuit.
app: SnarkWitness,
/// The SNARK witness from the (k-1)-th BatchCircuit.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from the previous RecursionCircuit, i.e. the RecursionCircuit which has aggregated (k-1) snarks from BatchCircuit

previous: SnarkWitness,
/// The recursion round, starting at round=0 and incrementing at every subsequent recursion.
round: usize,
/// The public inputs to the RecursionCircuit itself.
instances: Vec<Fr>,
/// The accumulation of the SNARK proofs recursed over thus far.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should say "the proofs of AppCircuit's snark under accumulation scheme"?

as_proof: Value<Vec<u8>>,
app_is_aggregation: bool,

_marker: PhantomData<ST>,
}

impl<ST: StateTransition> RecursionCircuit<ST> {
/// The index of the preprocessed digest in the [`RecursionCircuit`]'s instances. Note that we
/// need a single cell to hold this value as it is a poseidon hash over the bn256 curve, hence
/// it fits within an [`Fr`] cell.
///
/// [`Fr`]: halo2_proofs::halo2curves::bn256::Fr
const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS;

/// The index within the instances to find the "initial" state in the state transition.
const INITIAL_STATE_ROW: usize = Self::PREPROCESSED_DIGEST_ROW + 1;

/// Construct a new instance of the [`RecursionCircuit`] given the SNARKs from the current and
/// previous [`BatchCircuit`], and the recursion round.
///
/// [`BatchCircuit`]: aggregator::BatchCircuit
pub fn new(
params: &ParamsKZG<Bn256>,
app: Snark,
Expand Down Expand Up @@ -180,7 +206,6 @@ impl<ST: StateTransition> RecursionCircuit<ST> {
round,
instances,
as_proof: Value::known(as_proof),
app_is_aggregation: true,
_marker: Default::default(),
}
}
Expand All @@ -206,7 +231,14 @@ impl<ST: StateTransition> RecursionCircuit<ST> {

/// Returns the number of instance cells in the Recursion Circuit, help to refine the CircuitExt trait
pub fn num_instance_fixed() -> usize {
// [..lhs, ..rhs, preprocessed_digest, initial_state, state, round]
// [
// ..lhs (accumulator LHS),
// ..rhs (accumulator RHS),
// preprocessed_digest,
// initial_state,
// state,
// round
// ]
4 * LIMBS + 2 * ST::num_transition_instance() + ST::num_additional_instance() + 2
}
}
Expand All @@ -225,14 +257,13 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
instances: self.instances.clone(),
as_proof: Value::unknown(),
_marker: Default::default(),
app_is_aggregation: self.app_is_aggregation,
}
}

fn configure(meta: &mut ConstraintSystem<Fr>) -> Self::Config {
let path = std::env::var("BUNDLE_CONFIG")
.unwrap_or_else(|_| "configs/bundle_circuit.config".to_owned());
let params: BatchCircuitConfigParams = serde_json::from_reader(
let params: RecursionCircuitConfigParams = serde_json::from_reader(
File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")),
)
.unwrap();
Expand All @@ -251,7 +282,7 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {

let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner
let assigned_instances = layouter.assign_region(
|| "",
|| "recursion circuit",
|region| -> Result<Vec<Cell>, Error> {
if first_pass {
first_pass = false;
Expand All @@ -266,28 +297,38 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
},
);

let init_state_row_beg = Self::INITIAL_STATE_ROW;
let state_row_beg = init_state_row_beg + ST::num_transition_instance();
let addition_state_beg = state_row_beg + ST::num_transition_instance();
let round_row = addition_state_beg + ST::num_additional_instance();
// The index of the "initial state", i.e. the state last finalised on L1.
let index_init_state = Self::INITIAL_STATE_ROW;
// The index of the "state", i.e. the state achieved post the current batch.
let index_state = index_init_state + ST::num_transition_instance();
// The index where the "additional" fields required to define the state are
// present.
let index_additional_state = index_state + ST::num_transition_instance();
// The index to find the "round" of recursion in the current instance of the
// Recursion Circuit.
let index_round = index_additional_state + ST::num_additional_instance();

log::debug!(
"state position: init {}|cur {}|add {}",
state_row_beg,
addition_state_beg,
round_row
"indices within instances: init {} |cur {} | add {} | round {}",
index_init_state,
index_state,
index_additional_state,
index_round,
);

// Get the field elements representing the "preprocessed digest" and "recursion round".
let [preprocessed_digest, round] = [
self.instances[Self::PREPROCESSED_DIGEST_ROW],
self.instances[round_row],
self.instances[index_round],
]
.map(|instance| {
main_gate
.assign_integer(&mut ctx, Value::known(instance))
.unwrap()
});

let initial_state = self.instances[init_state_row_beg..state_row_beg]
// Get the field elements representing the "initial state"
let initial_state = self.instances[index_init_state..index_state]
.iter()
.map(|&instance| {
main_gate
Expand All @@ -296,7 +337,9 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
})
.collect::<Vec<_>>();

let state = self.instances[state_row_beg..round_row]
// Get the field elements representing the "state" post batch. This includes the
// additional state fields as well.
let state = self.instances[index_state..index_round]
.iter()
.map(|&instance| {
main_gate
Expand All @@ -305,6 +348,7 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
})
.collect::<Vec<_>>();

// Whether or not we are in the first round of recursion.
let first_round = main_gate.is_zero(&mut ctx, &round);
let not_first_round = main_gate.not(&mut ctx, Existing(first_round));

Expand All @@ -318,6 +362,8 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
Some(preprocessed_digest),
);

// Choose between the default accumulator or the previous accumulator depending on
// whether or not we are in the first round of recursion.
let default_accumulator = self.load_default_accumulator(&loader)?;
let previous_accumulators = previous_accumulators
.iter()
Expand All @@ -331,6 +377,8 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
})
.collect::<Result<Vec<_>, Error>>()?;

// Accumulate the accumulators over the previous accumulators, to compute the
// accumulator values for this instance of the Recursion Circuit.
let KzgAccumulator { lhs, rhs } = accumulate(
&loader,
[app_accumulators, previous_accumulators].concat(),
Expand All @@ -343,31 +391,39 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
let previous_instances = previous_instances.pop().unwrap();

let mut ctx = loader.ctx_mut();

//////////////////////////////////////////////////////////////////////////////////
/////////////////////////////// CONSTRAINTS //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////

// Propagate the "initial state"
let initial_state_propagate = initial_state
.iter()
.zip_eq(previous_instances[init_state_row_beg..state_row_beg].iter())
.zip_eq(previous_instances[index_init_state..index_state].iter())
.zip_eq(
ST::state_prev_indices()
.into_iter()
.map(|i| &app_instances[i]),
)
.flat_map(|((&st, &previous_st), &app_inst)| {
[
// Propagate initial_state
(
main_gate.mul(&mut ctx, Existing(st), Existing(not_first_round)),
previous_st,
),
// Verify initial_state is same as the first application snark
// Verify initial_state is same as the first application snark in the
// first round of recursion.
(
main_gate.mul(&mut ctx, Existing(st), Existing(first_round)),
main_gate.mul(&mut ctx, Existing(app_inst), Existing(first_round)),
),
// Propagate initial_state for subsequent rounds of recursion.
(
main_gate.mul(&mut ctx, Existing(st), Existing(not_first_round)),
previous_st,
),
]
})
.collect::<Vec<_>>();

// Verify current state is same as the current application snark
// Verify that the current "state" is the same as the state defined in the
// application SNARK.
let verify_app_state = state
.iter()
.zip_eq(
Expand All @@ -383,8 +439,10 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
.map(|(&st, &app_inst)| (st, app_inst))
.collect::<Vec<_>>();

// Verify previous state (additional state not included) is same as the current application snark
let verify_app_init_state = previous_instances[state_row_beg..addition_state_beg]
// Verify that the "previous state" (additional state not included) is the same
// as the previous state defined in the current application SNARK. This check is
// meaningful only in subsequent recursion rounds after the first round.
let verify_app_init_state = previous_instances[index_state..index_additional_state]
.iter()
.zip_eq(
ST::state_prev_indices()
Expand All @@ -399,8 +457,10 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
})
.collect::<Vec<_>>();

// Finally apply the equality constraints between the (LHS, RHS) values constructed
// above.
for (lhs, rhs) in [
// Propagate preprocessed_digest
// Propagate the preprocessed digest.
(
main_gate.mul(
&mut ctx,
Expand All @@ -409,13 +469,13 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
),
previous_instances[Self::PREPROCESSED_DIGEST_ROW],
),
// Verify round is increased by 1 when not at first round
// Verify that "round" increments by 1 when not the first round of recursion.
(
round,
main_gate.add(
&mut ctx,
Existing(not_first_round),
Existing(previous_instances[round_row]),
Existing(previous_instances[index_round]),
),
),
]
Expand All @@ -427,13 +487,15 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
ctx.region.constrain_equal(lhs.cell(), rhs.cell())?;
}

// IMPORTANT:
// Mark the end of this phase.
config.base_field_config.finalize(&mut ctx);

#[cfg(feature = "display")]
dbg!(ctx.total_advice);
#[cfg(feature = "display")]
println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1);

// Return the computed instance cells for this Recursion Circuit.
Ok([lhs.x(), lhs.y(), rhs.x(), rhs.y()]
.into_iter()
.flat_map(|coordinate| coordinate.limbs())
Expand All @@ -447,6 +509,8 @@ impl<ST: StateTransition> Circuit<Fr> for RecursionCircuit<ST> {
)?;

assert_eq!(assigned_instances.len(), self.num_instance()[0]);

// Ensure that the computed instances are in fact the instances for this circuit.
for (row, limb) in assigned_instances.into_iter().enumerate() {
layouter.constrain_instance(limb, config.instance, row)?;
}
Expand Down
Loading