compute actual multiplicity from max_multiplicity
This commit defines a function `min_multiplicity`, which computes the actual multiplicity to use from `max_multiplicity` and `payload_len`. The original argument `multiplicity` has been renamed to `max_multiplicity` to indicate that it is an upper bound.
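As context for the rename, here is a minimal sketch, not part of this diff (where `min_multiplicity` still returns `max_multiplicity` unchanged), of how the actual multiplicity could be derived: pick the smallest power-of-two multiplicity, capped at `max_multiplicity`, that lets the payload fit into `m * recovery_threshold` field elements. The function name `min_multiplicity_sketch` and the `elem_byte_len` parameter are illustrative assumptions.

```rust
// Illustrative sketch only; not the implementation in this commit.
// Assumes the payload length is given in bytes and each field element can
// hold `elem_byte_len` payload bytes (cf. bytes_to_field::elem_byte_capacity).
fn min_multiplicity_sketch(
    payload_byte_len: u32,
    recovery_threshold: u32, // k
    max_multiplicity: u32,   // upper bound, assumed to be a power of two
    elem_byte_len: u32,      // assumed >= 1
) -> u32 {
    // number of field elements needed to hold the payload (ceiling division)
    let payload_elems = (payload_byte_len + elem_byte_len - 1) / elem_byte_len;
    // smallest power-of-two m, capped at max_multiplicity, with
    // payload_elems <= m * recovery_threshold
    let mut m = 1;
    while m < max_multiplicity && m * recovery_threshold < payload_elems {
        m *= 2;
    }
    m
}
```

For example, with `recovery_threshold = 4`, `elem_byte_len = 31`, and a 1000-byte payload (33 field elements), this sketch returns 16 whenever `max_multiplicity >= 16`.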
akonring committed Aug 23, 2024
1 parent 7cd4f76 commit 6e730c2
Showing 3 changed files with 100 additions and 64 deletions.
130 changes: 78 additions & 52 deletions vid/src/advz.rs
@@ -83,7 +83,7 @@ where
{
recovery_threshold: u32,
num_storage_nodes: u32,
multiplicity: u32,
max_multiplicity: u32,
ck: KzgProverParam<E>,
vk: KzgVerifierParam<E>,
multi_open_domain: Radix2EvaluationDomain<KzgPoint<E>>,
@@ -131,15 +131,20 @@ where
) -> VidResult<Self> {
// TODO intelligent choice of multiplicity
// https://github.com/EspressoSystems/jellyfish/issues/534
let multiplicity = 1;
let max_multiplicity = 1;

Self::with_multiplicity_internal(num_storage_nodes, recovery_threshold, multiplicity, srs)
Self::with_multiplicity_internal(
num_storage_nodes,
recovery_threshold,
max_multiplicity,
srs,
)
}

pub(crate) fn with_multiplicity_internal(
num_storage_nodes: u32, // n (code rate: r = k/n)
recovery_threshold: u32, // k
multiplicity: u32, // batch m chunks, keep the rate r = (m*k)/(m*n)
max_multiplicity: u32, // batch m chunks, keep the rate r = (m*k)/(m*n)
srs: impl Borrow<KzgSrs<E>>,
) -> VidResult<Self> {
if num_storage_nodes < recovery_threshold {
@@ -149,15 +154,15 @@ where
)));
}

if !multiplicity.is_power_of_two() {
if !max_multiplicity.is_power_of_two() {
return Err(VidError::Argument(format!(
"multiplicity {multiplicity} should be a power of two"
"max multiplicity {max_multiplicity} should be a power of two"
)));
}

// erasure code params
let chunk_size = multiplicity * recovery_threshold; // message length m
let code_word_size = multiplicity * num_storage_nodes; // code word length n
let chunk_size = max_multiplicity * recovery_threshold; // message length m
let code_word_size = max_multiplicity * num_storage_nodes; // code word length n
let poly_degree = chunk_size - 1;

let (ck, vk) = UnivariateKzgPCS::trim_fft_size(srs, poly_degree as usize).map_err(vid)?;
@@ -187,7 +192,7 @@ where
Ok(Self {
recovery_threshold,
num_storage_nodes,
multiplicity,
max_multiplicity,
ck,
vk,
multi_open_domain,
@@ -224,21 +229,28 @@ where
Self::new_internal(num_storage_nodes, recovery_threshold, srs)
}

/// Like [`Advz::new`] except with a `multiplicity` arg.
/// Like [`Advz::new`] except with a `max_multiplicity` arg.
///
/// `multiplicity` is an implementation-specific optimization arg.
/// Each storage node gets `multiplicity` evaluations per polynomial.
/// `max_multiplicity` is an implementation-specific optimization arg.
/// Each storage node gets up to `max_multiplicity` evaluations per
/// polynomial. The actual multiplicity used will be the smallest value m
/// such that the payload can fit (payload_len <= m * recovery_threshold).
///
/// # Errors
/// In addition to [`Advz::new`], return [`VidError::Argument`] if
/// - TEMPORARY `multiplicity` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/339)
/// - TEMPORARY `max_multiplicity` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/339)
pub fn with_multiplicity(
num_storage_nodes: u32,
recovery_threshold: u32,
multiplicity: u32,
max_multiplicity: u32,
srs: impl Borrow<KzgSrs<E>>,
) -> VidResult<Self> {
Self::with_multiplicity_internal(num_storage_nodes, recovery_threshold, multiplicity, srs)
Self::with_multiplicity_internal(
num_storage_nodes,
recovery_threshold,
max_multiplicity,
srs,
)
}
}

@@ -262,13 +274,13 @@ where
pub fn with_multiplicity(
num_storage_nodes: u32,
recovery_threshold: u32,
multiplicity: u32,
max_multiplicity: u32,
srs: impl Borrow<KzgSrs<E>>,
) -> VidResult<Self> {
let mut advz = Self::with_multiplicity_internal(
num_storage_nodes,
recovery_threshold,
multiplicity,
max_multiplicity,
srs,
)?;
advz.init_gpu_srs()?;
@@ -307,7 +319,7 @@ where
evals: Vec<KzgEval<E>>,

#[serde(with = "canonical")]
// aggregate_proofs.len() equals self.multiplicity
// aggregate_proofs.len() equals multiplicity
// TODO further aggregate into a single KZG proof.
aggregate_proofs: Vec<KzgProof<E>>,

@@ -407,8 +419,11 @@ where
B: AsRef<[u8]>,
{
let payload = payload.as_ref();
let payload_byte_len = payload.len().try_into().map_err(vid)?;
let multiplicity = self.min_multiplicity(payload_byte_len, self.max_multiplicity);
let chunk_size = multiplicity * self.recovery_threshold;
let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials");
let polys = self.bytes_to_polys(payload);
let polys = self.bytes_to_polys(payload, chunk_size as usize);
end_timer!(bytes_to_polys_time);

let poly_commits_time = start_timer!(|| "batch poly commit");
@@ -428,12 +443,13 @@ where
"VID disperse {} payload bytes to {} nodes",
payload_byte_len, self.num_storage_nodes
));
let _chunk_size = self.multiplicity * self.recovery_threshold;
let code_word_size = self.multiplicity * self.num_storage_nodes;
let multiplicity = self.min_multiplicity(payload_byte_len, self.max_multiplicity);
let chunk_size = multiplicity * self.recovery_threshold;
let code_word_size = multiplicity * self.num_storage_nodes;

// partition payload into polynomial coefficients
let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials");
let polys = self.bytes_to_polys(payload);
let polys = self.bytes_to_polys(payload, chunk_size as usize);
end_timer!(bytes_to_polys_time);

// evaluate polynomials
@@ -442,7 +458,7 @@ where
polys.len(),
_chunk_size
));
let all_storage_node_evals = self.evaluate_polys(&polys)?;
let all_storage_node_evals = self.evaluate_polys(&polys, code_word_size as usize)?;
end_timer!(all_storage_node_evals_timer);

// vector commitment to polynomial evaluations
@@ -458,7 +474,7 @@ where
all_evals_digest: all_evals_commit.commitment().digest(),
payload_byte_len,
num_storage_nodes: self.num_storage_nodes,
multiplicity: self.multiplicity,
multiplicity,
};
end_timer!(common_timer);

@@ -489,8 +505,12 @@ where
end_timer!(agg_proofs_timer);

let assemblage_timer = start_timer!(|| "assemble shares for dispersal");
let shares =
self.assemble_shares(all_storage_node_evals, aggregate_proofs, all_evals_commit)?;
let shares = self.assemble_shares(
all_storage_node_evals,
aggregate_proofs,
all_evals_commit,
multiplicity,
)?;
end_timer!(assemblage_timer);
end_timer!(disperse_time);

@@ -514,21 +534,21 @@ where
common.num_storage_nodes, self.num_storage_nodes
)));
}
if common.multiplicity != self.multiplicity {
let multiplicity = self.min_multiplicity(common.payload_byte_len, self.max_multiplicity);
if common.multiplicity != multiplicity {
return Err(VidError::Argument(format!(
"common multiplicity {} differs from self {}",
common.multiplicity, self.multiplicity
"common multiplicity {} differs from derived min {}",
common.multiplicity, multiplicity
)));
}
let multiplicity: usize = common.multiplicity.try_into().map_err(vid)?;
if share.evals.len() / multiplicity != common.poly_commits.len() {
if share.evals.len() / multiplicity as usize != common.poly_commits.len() {
return Err(VidError::Argument(format!(
"number of share evals / multiplicity {}/{} differs from number of common polynomial commitments {}",
share.evals.len(), multiplicity,
common.poly_commits.len()
)));
}
if share.eval_proofs.len() != multiplicity {
if share.eval_proofs.len() != multiplicity as usize {
return Err(VidError::Argument(format!(
"number of eval_proofs {} differs from common multiplicity {}",
share.eval_proofs.len(),
Expand All @@ -543,10 +563,10 @@ where
}

// verify eval proofs
for i in 0..self.multiplicity {
for i in 0..multiplicity {
if KzgEvalsMerkleTree::<E, H>::verify(
common.all_evals_digest,
&KzgEvalsMerkleTreeIndex::<E, H>::from((share.index * self.multiplicity) + i),
&KzgEvalsMerkleTreeIndex::<E, H>::from((share.index * multiplicity) + i),
&share.eval_proofs[i as usize],
)
.map_err(vid)?
@@ -577,7 +597,7 @@ where
//
// some boilerplate needed to accommodate builds without `parallel`
// feature.
let multiplicities = Vec::from_iter((0..self.multiplicity as usize));
let multiplicities = Vec::from_iter((0..multiplicity as usize));
let polys_len = common.poly_commits.len();
let verification_iter = parallelizable_slice_iter(&multiplicities).map(|i| {
let range = i * polys_len..(i + 1) * polys_len;
@@ -601,7 +621,7 @@ where
&aggregate_poly_commit,
&self
.multi_open_domain
.element((share.index as usize * multiplicity) + i),
.element((share.index * multiplicity) as usize + i),
&aggregate_eval,
&share.aggregate_proofs[*i],
)
@@ -645,6 +665,7 @@ where
.ok_or_else(|| VidError::Argument("shares is empty".into()))?
.evals
.len();
let multiplicity = common.multiplicity;
if let Some((index, share)) = shares
.iter()
.enumerate()
@@ -658,15 +679,15 @@ where
share.evals.len()
)));
}
if num_evals != self.multiplicity as usize * common.poly_commits.len() {
if num_evals != multiplicity as usize * common.poly_commits.len() {
return Err(VidError::Argument(format!(
"num_evals should be (multiplicity * poly_commits): {} but is instead: {}",
self.multiplicity as usize * common.poly_commits.len(),
multiplicity as usize * common.poly_commits.len(),
num_evals,
)));
}
let chunk_size = self.multiplicity * self.recovery_threshold;
let num_polys = num_evals / self.multiplicity as usize;
let chunk_size = multiplicity * self.recovery_threshold;
let num_polys = num_evals / multiplicity as usize;

let elems_capacity = num_polys * chunk_size as usize;
let mut elems = Vec::with_capacity(elems_capacity);
@@ -675,9 +696,9 @@ where
for p in 0..num_polys {
for share in shares {
// extract all evaluations for polynomial p from the share
for m in 0..self.multiplicity as usize {
for m in 0..multiplicity as usize {
evals.push((
(share.index * self.multiplicity) as usize + m,
(share.index * multiplicity) as usize + m,
share.evals[(m * num_polys) + p],
))
}
@@ -741,12 +762,12 @@ where
fn evaluate_polys(
&self,
polys: &[DensePolynomial<<E as Pairing>::ScalarField>],
code_word_size: usize,
) -> Result<Vec<Vec<<E as Pairing>::ScalarField>>, VidError>
where
E: Pairing,
H: HasherDigest,
{
let code_word_size = (self.num_storage_nodes * self.multiplicity) as usize;
let mut all_storage_node_evals = vec![Vec::with_capacity(polys.len()); code_word_size];
// this is to avoid `SrsRef` not implementing `Sync` problem,
// instead of sending entire `self` cross thread, we only send a ref which is
@@ -809,11 +830,14 @@ where
Ok(PrimeField::from_le_bytes_mod_order(&hasher.finalize()))
}

fn bytes_to_polys(&self, payload: &[u8]) -> Vec<DensePolynomial<<E as Pairing>::ScalarField>>
fn bytes_to_polys(
&self,
payload: &[u8],
chunk_size: usize,
) -> Vec<DensePolynomial<<E as Pairing>::ScalarField>>
where
E: Pairing,
{
let chunk_size = (self.recovery_threshold * self.multiplicity) as usize;
let elem_bytes_len = bytes_to_field::elem_byte_capacity::<<E as Pairing>::ScalarField>();
let eval_domain_ref = &self.eval_domain;

@@ -858,16 +882,17 @@ where
DenseUVPolynomial::from_coefficients_vec(coeffs_vec)
}

fn polynomial<I>(&self, coeffs: I) -> KzgPolynomial<E>
fn polynomial<I>(&self, coeffs: I, chunk_size: usize) -> KzgPolynomial<E>
where
I: Iterator,
I::Item: Borrow<KzgEval<E>>,
{
Self::polynomial_internal(
&self.eval_domain,
(self.recovery_threshold * self.multiplicity) as usize,
coeffs,
)
Self::polynomial_internal(&self.eval_domain, chunk_size, coeffs)
}

fn min_multiplicity(&self, payload_byte_len: u32, multiplicity: u32) -> u32 {
let _elem_bytes_len = bytes_to_field::elem_byte_capacity::<<E as Pairing>::ScalarField>();
multiplicity
}

/// Derive a commitment from whatever data is needed.
@@ -914,6 +939,7 @@ where
all_storage_node_evals: Vec<Vec<<E as Pairing>::ScalarField>>,
aggregate_proofs: Vec<UnivariateKzgProof<E>>,
all_evals_commit: KzgEvalsMerkleTree<E, H>,
multiplicity: u32,
) -> Result<Vec<Share<E, H>>, VidError>
where
E: Pairing,
@@ -937,7 +963,7 @@ where
// split share data into chunks of size multiplicity
Ok(share_data
.into_iter()
.chunks(self.multiplicity as usize)
.chunks(multiplicity as usize)
.into_iter()
.enumerate()
.map(|(index, chunk)| {
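To make the erasure-code bookkeeping in `with_multiplicity_internal` above concrete, here is a small self-contained sketch with assumed example numbers (the values of k, n, and m are illustrative, not taken from this commit) showing that the multiplicity scales `chunk_size` and `code_word_size` together, so the rate k/n is unchanged.

```rust
// Assumed example parameters, for illustration only.
fn main() {
    let recovery_threshold: u32 = 256; // k
    let num_storage_nodes: u32 = 1024; // n
    let multiplicity: u32 = 4;         // m, derived from max_multiplicity per this commit

    let chunk_size = multiplicity * recovery_threshold;    // message length: 1024
    let code_word_size = multiplicity * num_storage_nodes; // code word length: 4096
    let poly_degree = chunk_size - 1;                       // 1023

    assert_eq!(chunk_size, 1024);
    assert_eq!(code_word_size, 4096);
    assert_eq!(poly_degree, 1023);

    // rate r = (m*k)/(m*n) = k/n = 1/4 regardless of m
    println!("rate = {chunk_size}/{code_word_size}");
}
```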
8 changes: 5 additions & 3 deletions vid/src/advz/payload_prover.rs
@@ -95,12 +95,14 @@ where

let elems_iter = bytes_to_field::<_, KzgEval<E>>(&payload[range_poly_byte]);
let mut proofs = Vec::with_capacity(range_poly.len() * points.len());
let chunk_size =
self.min_multiplicity(payload.len() as u32, self.recovery_threshold) as usize;
for (i, evals_iter) in elems_iter
.chunks(self.recovery_threshold as usize)
.into_iter()
.enumerate()
{
let poly = self.polynomial(evals_iter);
let poly = self.polynomial(evals_iter, chunk_size);
let points_range = Range {
// first polynomial? skip to the start of the proof range
start: if i == 0 { offset_elem } else { 0 },
@@ -260,14 +262,14 @@ where
.chain(proof.suffix_bytes.iter()),
))
.chain(proof.suffix_elems.iter().cloned());

let chunk_size = (stmt.common.multiplicity * self.recovery_threshold) as usize;
// rebuild the poly commits, check against `common`
for (commit_index, evals_iter) in range_poly.into_iter().zip(
elems_iter
.chunks(self.recovery_threshold as usize)
.into_iter(),
) {
let poly = self.polynomial(evals_iter);
let poly = self.polynomial(evals_iter, chunk_size);
let poly_commit = UnivariateKzgPCS::commit(&self.ck, &poly).map_err(vid)?;
if poly_commit != stmt.common.poly_commits[commit_index] {
return Ok(Err(()));