diff --git a/Cargo.lock b/Cargo.lock index d1dbe26f3..90ceb6dff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1995,6 +1995,7 @@ dependencies = [ "derivative", "http", "http-api-problem", + "itertools", "janus_core", "janus_messages", "mockito", @@ -2985,9 +2986,8 @@ dependencies = [ [[package]] name = "prio" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3675d093a7713f2b861f77b16c3c33fadd6de0a69bf7203014d938b9d5daa6f7" +version = "0.12.0" +source = "git+https://github.com/divviup/libprio-rs.git?rev=54a46230615d28c7e131d0595cc558e1619b8071#54a46230615d28c7e131d0595cc558e1619b8071" dependencies = [ "aes", "base64 0.21.0", diff --git a/Cargo.toml b/Cargo.toml index 6bf399064..56f77ab99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ version = "0.4.6" # (yet) need other default features. # https://docs.rs/chrono/latest/chrono/#duration chrono = { version = "0.4", default-features = false } +itertools = "0.10" janus_aggregator = { version = "0.4", path = "aggregator" } janus_aggregator_api = { version = "0.4", path = "aggregator_api" } janus_aggregator_core = { version = "0.4", path = "aggregator_core" } @@ -39,7 +40,8 @@ janus_interop_binaries = { version = "0.4", path = "interop_binaries" } janus_messages = { version = "0.4", path = "messages" } k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube` kube = { version = "0.82.1", default-features = false, features = ["client", "rustls-tls"] } -prio = { version = "0.12.1", features = ["multithreaded"] } +# prio = { version = "0.12.0", features = ["multithreaded"] } # XXX: go back to a released version of prio once https://github.com/divviup/libprio-rs/commit/54a46230615d28c7e131d0595cc558e1619b8071 is released. 
+prio = { git = "https://github.com/divviup/libprio-rs.git", rev = "54a46230615d28c7e131d0595cc558e1619b8071", features = ["multithreaded", "experimental"] } trillium = "0.2.8" trillium-api = { version = "0.2.0-rc.2", default-features = false } trillium-caching-headers = "0.2.1" diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index 3b91130ac..859ecd25e 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -106,7 +106,7 @@ uuid = { version = "1.3.1", features = ["v4"] } [dev-dependencies] assert_matches = "1" hyper = "0.14.26" -itertools = "0.10.5" +itertools.workspace = true janus_aggregator = { path = ".", features = ["fpvec_bounded_l2", "test-util"] } janus_aggregator_core = { workspace = true, features = ["test-util"] } mockito = "1.0.2" diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 36d962451..f24447b4a 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -35,7 +35,7 @@ use janus_core::test_util::dummy_vdaf; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, http::response_to_problem_details, - task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER, VERIFY_KEY_LEN}, time::{Clock, DurationExt, IntervalExt, TimeExt}, }; use janus_messages::{ @@ -43,9 +43,9 @@ use janus_messages::{ query_type::{FixedSize, TimeInterval}, AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, - BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeCiphertext, - HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, - PrepareStep, PrepareStepResult, Report, ReportShare, ReportShareError, Role, TaskId, + BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeConfigList, + InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, + PrepareStepResult, Report, ReportShare, ReportShareError, Role, TaskId, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter}, @@ -53,13 +53,14 @@ use opentelemetry::{ }; #[cfg(feature = "fpvec_bounded_l2")] use prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded; -#[cfg(feature = "test-util")] -use prio::vdaf::{PrepareTransition, VdafError}; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, vdaf::{ self, + poplar1::Poplar1, + prg::PrgSha3, prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded}, + PrepareTransition, VdafError, }, }; use reqwest::Client; @@ -122,12 +123,16 @@ pub(crate) fn aggregate_step_failure_counter(meter: &Meter) -> Counter { "decrypt_failure", "input_share_decode_failure", "public_share_decode_failure", + "prepare_message_decode_failure", + "leader_prep_share_decode_failure", + "helper_prep_share_decode_failure", "continue_mismatch", "accumulate_failure", "finish_mismatch", "helper_step_failure", "plaintext_input_share_decode_failure", "duplicate_extension", + "missing_prepare_message", ] { aggregate_step_failure_counter.add( &Context::current(), @@ -523,6 +528,12 @@ impl TaskAggregator { VdafOps::Prio3FixedPoint64BitBoundedL2VecSum(Arc::new(vdaf), verify_key) } + VdafInstance::Poplar1 { bits } => { + let vdaf = Poplar1::new_sha3(*bits); + let verify_key = task.primary_vdaf_verify_key()?; + VdafOps::Poplar1(Arc::new(vdaf), verify_key) + } + #[cfg(feature = "test-util")] VdafInstance::Fake => 
VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new())), @@ -538,7 +549,7 @@ impl TaskAggregator { #[cfg(feature = "test-util")] VdafInstance::FakeFailsPrepStep => { VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> { + |_| -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> { Err(VdafError::Uncategorized( "FakeFailsPrepStep failed at prep_step".to_string(), )) @@ -683,33 +694,29 @@ impl TaskAggregator { /// VdafOps stores VDAF-specific operations for a TaskAggregator in a non-generic way. #[allow(clippy::enum_variant_names)] enum VdafOps { - Prio3Count(Arc<Prio3Count>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>), - Prio3CountVec( - Arc<Prio3SumVecMultithreaded>, - VerifyKey<PRIO3_VERIFY_KEY_LENGTH>, - ), - Prio3Sum(Arc<Prio3Sum>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>), - Prio3SumVec( - Arc<Prio3SumVecMultithreaded>, - VerifyKey<PRIO3_VERIFY_KEY_LENGTH>, - ), - Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>), + Prio3Count(Arc<Prio3Count>, VerifyKey<VERIFY_KEY_LEN>), + Prio3CountVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LEN>), + Prio3Sum(Arc<Prio3Sum>, VerifyKey<VERIFY_KEY_LEN>), + Prio3SumVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LEN>), + Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<VERIFY_KEY_LEN>), #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum( Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>>, - VerifyKey<PRIO3_VERIFY_KEY_LENGTH>, + VerifyKey<VERIFY_KEY_LEN>, ), #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint32BitBoundedL2VecSum( Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>>, - VerifyKey<PRIO3_VERIFY_KEY_LENGTH>, + VerifyKey<VERIFY_KEY_LEN>, ), #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint64BitBoundedL2VecSum( Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>, - VerifyKey<PRIO3_VERIFY_KEY_LENGTH>, + VerifyKey<VERIFY_KEY_LEN>, ), + Poplar1(Arc<Poplar1<PrgSha3, 16>>, VerifyKey<VERIFY_KEY_LEN>), + #[cfg(feature = "test-util")] Fake(Arc<dummy_vdaf::Vdaf>), } @@ -720,13 +727,13 @@ enum VdafOps { /// specify the VDAF's type, and the name of a const that will be set to the VDAF's verify key /// length, also for explicitly specifying type parameters. macro_rules! vdaf_ops_dispatch { - ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_ops { crate::aggregator::VdafOps::Prio3Count(vdaf, verify_key) => { let $vdaf = vdaf; let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -734,7 +741,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -742,7 +749,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -750,7 +757,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -758,7 +765,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -768,7 +775,7 @@ macro_rules!
vdaf_ops_dispatch { let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -778,7 +785,7 @@ macro_rules! vdaf_ops_dispatch { let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -788,7 +795,15 @@ macro_rules! vdaf_ops_dispatch { let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + crate::aggregator::VdafOps::Poplar1(vdaf, verify_key) => { + let $vdaf = vdaf; + let $verify_key = verify_key; + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -797,7 +812,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = &VerifyKey::new([]); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } } @@ -816,8 +831,8 @@ impl VdafOps { ) -> Result<(), Arc> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_upload_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_upload_generic::( Arc::clone(vdaf), clock, upload_decrypt_failure_counter, @@ -830,8 +845,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_upload_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_upload_generic::( Arc::clone(vdaf), clock, upload_decrypt_failure_counter, @@ -859,8 +874,8 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_init_generic::( + vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_init_generic::( datastore, vdaf, aggregate_step_failure_counter, @@ -874,8 +889,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_init_generic::( + vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_init_generic::( datastore, vdaf, aggregate_step_failure_counter, @@ -903,8 +918,8 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_continue_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), aggregate_step_failure_counter, @@ -918,8 +933,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_continue_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), aggregate_step_failure_counter, @@ -953,24 +968,16 @@ impl VdafOps { C: Clock, Q: UploadableQueryType, { - // The leader's report is the first one. - // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2 - if report.encrypted_input_shares().len() != 2 { - return Err(Arc::new(Error::UnrecognizedMessage( - Some(*task.id()), - "unexpected number of encrypted shares in report", - ))); - } - let leader_encrypted_input_share = - &report.encrypted_input_shares()[Role::Leader.index().unwrap()]; - // Verify that the report's HPKE config ID is known. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2 let hpke_keypair = task .hpke_keys() - .get(leader_encrypted_input_share.config_id()) + .get(report.leader_encrypted_input_share().config_id()) .ok_or_else(|| { - Error::OutdatedHpkeConfig(*task.id(), *leader_encrypted_input_share.config_id()) + Error::OutdatedHpkeConfig( + *task.id(), + *report.leader_encrypted_input_share().config_id(), + ) })?; let report_deadline = clock @@ -1038,7 +1045,7 @@ impl VdafOps { hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()), - leader_encrypted_input_share, + report.leader_encrypted_input_share(), &InputShareAad::new( *task.id(), report.metadata().clone(), @@ -1080,16 +1087,13 @@ impl VdafOps { } }; - let helper_encrypted_input_share = - &report.encrypted_input_shares()[Role::Helper.index().unwrap()]; - let report = LeaderStoredReport::new( *task.id(), report.metadata().clone(), public_share, Vec::from(leader_plaintext_input_share.extensions()), leader_input_share, - helper_encrypted_input_share.clone(), + report.helper_encrypted_input_share().clone(), ); report_writer @@ -1224,9 +1228,9 @@ impl VdafOps { // If two ReportShare messages have the same report ID, then the helper MUST abort with // error "unrecognizedMessage". 
(§4.4.4.1) - let mut seen_report_ids = HashSet::with_capacity(req.report_shares().len()); - for share in req.report_shares() { - if !seen_report_ids.insert(share.metadata().id()) { + let mut seen_report_ids = HashSet::with_capacity(req.report_inits().len()); + for report_init in req.report_inits() { + if !seen_report_ids.insert(*report_init.report_share().metadata().id()) { return Err(Error::UnrecognizedMessage( Some(*task.id()), "aggregate request contains duplicate report IDs", @@ -1238,13 +1242,18 @@ impl VdafOps { let mut saw_continue = false; let mut report_share_data = Vec::new(); let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?; - for (ord, report_share) in req.report_shares().iter().enumerate() { + for (ord, report_init) in req.report_inits().iter().enumerate() { let hpke_keypair = task .hpke_keys() - .get(report_share.encrypted_input_share().config_id()) + .get( + report_init + .report_share() + .encrypted_input_share() + .config_id(), + ) .ok_or_else(|| { info!( - config_id = %report_share.encrypted_input_share().config_id(), + config_id = %report_init.report_share().encrypted_input_share().config_id(), "Helper encrypted input share references unknown HPKE config ID" ); aggregate_step_failure_counter.add( @@ -1261,18 +1270,18 @@ impl VdafOps { hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - report_share.encrypted_input_share(), + report_init.report_share().encrypted_input_share(), &InputShareAad::new( *task.id(), - report_share.metadata().clone(), - report_share.public_share().to_vec(), + report_init.report_share().metadata().clone(), + report_init.report_share().public_share().to_vec(), ) .get_encoded(), ) .map_err(|error| { info!( task_id = %task.id(), - metadata = ?report_share.metadata(), + metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decrypt helper's report share" ); @@ -1287,7 +1296,7 @@ impl VdafOps { let plaintext_input_share = plaintext.and_then(|plaintext| { let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext).map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's plaintext input share"); + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decode helper's plaintext input share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "plaintext_input_share_decode_failure")]); ReportShareError::UnrecognizedMessage })?; @@ -1297,7 +1306,7 @@ impl VdafOps { .extensions() .iter() .all(|extension| extension_types.insert(extension.extension_type())) { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), "Received report share with duplicate extensions"); + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), "Received report share with duplicate extensions"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "duplicate_extension")]); return Err(ReportShareError::UnrecognizedMessage) } @@ -1307,14 +1316,14 @@ impl VdafOps { let input_share = plaintext_input_share.and_then(|plaintext_input_share| { A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), plaintext_input_share.payload()) .map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's input share"); + info!(task_id = %task.id(), metadata = 
?report_init.report_share().metadata(), ?error, "Couldn't decode helper's input share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "input_share_decode_failure")]); ReportShareError::UnrecognizedMessage }) }); - let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_share.public_share()).map_err(|error|{ - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode public share"); + let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_init.report_share().public_share()).map_err(|error|{ + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decode public share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "public_share_decode_failure")]); ReportShareError::UnrecognizedMessage }); @@ -1324,81 +1333,109 @@ impl VdafOps { // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF // associated with the task and computes the first state transition. [...] If either // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2) - let init_rslt = shares.and_then(|(public_share, input_share)| { + let init_state = shares.and_then(|(public_share, input_share)| { vdaf .prepare_init( verify_key.as_bytes(), Role::Helper.index().unwrap(), &agg_param, - report_share.metadata().id().as_ref(), + report_init.report_share().metadata().id().as_ref(), &public_share, &input_share, ) .map_err(|error| { - info!(task_id = %task.id(), report_id = %report_share.metadata().id(), ?error, "Couldn't prepare_init report share"); + info!(task_id = %task.id(), report_id = %report_init.report_share().metadata().id(), ?error, "Couldn't prepare_init report share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_init_failure")]); ReportShareError::VdafPrepError }) }); - report_share_data.push(match init_rslt { - Ok((prep_state, prep_share)) => { - saw_continue = true; + // Next, the Helper computes the next prepare message, then steps its own VDAF + // preparation state.
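+ // (For a 1-round VDAF such as Prio3, the prepare_step call below already yields + // PrepareTransition::Finish; the combined prepare message is still returned to the Leader + // so that it can finish as well.)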
+ let stepped = init_state.and_then(|(prep_state, helper_prep_share)| { + let leader_prep_share = A::PrepareShare::get_decoded_with_param( + &prep_state, + report_init.leader_prep_share(), + ).map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Couldn't decode Leader prepare share"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "leader_prep_share_decode_failure")]); + ReportShareError::UnrecognizedMessage + })?; - let encoded_prep_share = prep_share.get_encoded(); - ReportShareData::new( - report_share.clone(), - ReportAggregation::<SEED_SIZE, A>::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - Some(PrepareStep::new( - *report_share.metadata().id(), - PrepareStepResult::Continued(encoded_prep_share), - )), - ReportAggregationState::<SEED_SIZE, A>::Waiting(prep_state, None), - ), - ) - } + let prep_msg = vdaf.prepare_preprocess([leader_prep_share, helper_prep_share]) + .map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Couldn't compute prepare message"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_message_failure")]); + ReportShareError::VdafPrepError + })?; - Err(err) => ReportShareData::new( - report_share.clone(), - ReportAggregation::<SEED_SIZE, A>::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - Some(PrepareStep::new( - *report_share.metadata().id(), - PrepareStepResult::Failed(err), - )), - ReportAggregationState::<SEED_SIZE, A>::Failed(err), + let prep_msg_encoded = prep_msg.get_encoded(); + let helper_next_transition = vdaf.prepare_step(prep_state, prep_msg) + .map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Prepare step failed"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_step_failure")]); + ReportShareError::VdafPrepError + })?; + + Ok(match helper_next_transition { + PrepareTransition::Continue(prep_state, prep_share) => { + saw_continue = true; + let prep_share_encoded = prep_share.get_encoded(); + (ReportAggregationState::<SEED_SIZE, A>::Waiting(prep_state, None), + PrepareStepResult::Continued { + prep_msg: prep_msg_encoded, + prep_share: prep_share_encoded, + }) + }, + PrepareTransition::Finish(out_share) => ( + ReportAggregationState::<SEED_SIZE, A>::Finished(out_share), + PrepareStepResult::Finished { prep_msg: prep_msg_encoded }, ), - ), + }) }); + + let (report_aggregation_state, prep_step_rslt) = stepped.unwrap_or_else(|err| { + ( + ReportAggregationState::<SEED_SIZE, A>::Failed(err), + PrepareStepResult::Failed(err), + ) + }); + + report_share_data.push(ReportShareData::new( + report_init.report_share().clone(), + ReportAggregation::<SEED_SIZE, A>::new( + *task.id(), + *aggregation_job_id, + *report_init.report_share().metadata().id(), + *report_init.report_share().metadata().time(), + ord.try_into()?, + Some(PrepareStep::new( + *report_init.report_share().metadata().id(), + prep_step_rslt, + )), + report_aggregation_state, + ), + )); } // Store data to datastore.
let req = Arc::new(req); let min_client_timestamp = req - .report_shares() + .report_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|report_init| *report_init.report_share().metadata().time()) .min() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let max_client_timestamp = req - .report_shares() + .report_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|report_init| *report_init.report_share().metadata().time()) .max() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let client_timestamp_interval = Interval::new( - *min_client_timestamp, + min_client_timestamp, max_client_timestamp - .difference(min_client_timestamp)? + .difference(&min_client_timestamp)? .add(&Duration::from_seconds(1))?, )?; let aggregation_job = Arc::new(AggregationJob::::new( @@ -1497,8 +1534,6 @@ impl VdafOps { Err(e) => return Err(e), }; - // Construct a response and write any new report shares and report aggregations - // as we go. if !replayed_request { let mut accumulator = Accumulator::::new( Arc::clone(&task), @@ -1689,9 +1724,9 @@ impl VdafOps { ) -> Result<(), Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_create_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -1700,9 +1735,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_create_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -1836,9 +1871,9 @@ impl VdafOps { ) -> Result>, Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_get_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -1847,9 +1882,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_get_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -1979,10 +2014,8 @@ impl VdafOps { ), *report_count, spanned_interval, - Vec::::from([ - encrypted_leader_aggregate_share, - encrypted_helper_aggregate_share.clone(), - ]), + encrypted_leader_aggregate_share, + encrypted_helper_aggregate_share.clone(), ) .get_encoded(), )) @@ -2010,9 +2043,9 @@ impl VdafOps { ) -> Result<(), Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_delete_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -2021,9 +2054,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_delete_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -2087,9 +2120,9 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_aggregate_share_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -2098,9 +2131,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_aggregate_share_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -2454,25 +2487,23 @@ mod tests { }; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, - test_util::{dummy_vdaf, install_test_trace_subscriber}, + task::{VdafInstance, VERIFY_KEY_LEN}, + test_util::{install_test_trace_subscriber, VdafTranscript}, time::{Clock, MockClock, TimeExt}, }; use janus_messages::{ query_type::TimeInterval, Duration, Extension, HpkeCiphertext, HpkeConfig, HpkeConfigId, InputShareAad, Interval, PlaintextInputShare, Report, ReportId, ReportMetadata, - ReportShare, Role, TaskId, Time, + ReportPrepInit, ReportShare, Role, TaskId, Time, }; use opentelemetry::global::meter; use prio::{ codec::Encode, - vdaf::{self, prio3::Prio3Count, Client as _}, + vdaf::{self, prio3::Prio3Count, Client as VdafClient}, }; use rand::random; use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration}; - pub(super) const DUMMY_VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; - pub(crate) fn default_aggregator_config() -> Config { // Enable upload write batching & batch aggregation sharding by default, in hopes that we // can shake out any bugs. 
@@ -2520,7 +2551,8 @@ mod tests { Report::new( report_metadata, public_share.get_encoded(), - Vec::from([leader_ciphertext, helper_ciphertext]), + leader_ciphertext, + helper_ciphertext, ) } @@ -2661,29 +2693,6 @@ mod tests { assert_eq!(want_report_ids, got_report_ids); } - #[tokio::test] - async fn upload_wrong_number_of_encrypted_shares() { - install_test_trace_subscriber(); - - let (_, aggregator, clock, task, _, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; - let report = create_report(&task, clock.now()); - let report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - - assert_matches!( - aggregator - .handle_upload(task.id(), &report.get_encoded()) - .await - .unwrap_err() - .as_ref(), - Error::UnrecognizedMessage(_, _) - ); - } - #[tokio::test] async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); @@ -2700,16 +2709,15 @@ mod tests { let report = Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); assert_matches!(aggregator.handle_upload(task.id(), &report.get_encoded()).await.unwrap_err().as_ref(), Error::OutdatedHpkeConfig(task_id, config_id) => { @@ -2794,17 +2802,15 @@ mod tests { .run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.put_collection_job(&CollectionJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - batch_interval, - (), - CollectionJobState::Start, - )) + tx.put_collection_job( + &CollectionJob::<VERIFY_KEY_LEN, TimeInterval, Prio3Count>::new( + *task.id(), + random(), + batch_interval, + (), + CollectionJobState::Start, + ), + ) .await }) }) @@ -2819,40 +2825,52 @@ mod tests { }); } - pub(crate) fn generate_helper_report_share<V: vdaf::Client<16>>( + pub(crate) fn generate_helper_report_init< + const SEED_SIZE: usize, + V: vdaf::Client<16> + vdaf::Aggregator<SEED_SIZE, 16>, + >( task_id: TaskId, report_metadata: ReportMetadata, cfg: &HpkeConfig, - public_share: &V::PublicShare, + transcript: &VdafTranscript<SEED_SIZE, V>, extensions: Vec<Extension>, - input_share: &V::InputShare, - ) -> ReportShare { - generate_helper_report_share_for_plaintext( + ) -> ReportPrepInit { + generate_helper_report_init_for_plaintext( report_metadata.clone(), cfg, - public_share.get_encoded(), - &PlaintextInputShare::new(extensions, input_share.get_encoded()).get_encoded(), - &InputShareAad::new(task_id, report_metadata, public_share.get_encoded()).get_encoded(), + transcript.public_share.get_encoded(), + &PlaintextInputShare::new(extensions, transcript.input_shares[1].get_encoded()) + .get_encoded(), + &InputShareAad::new( + task_id, + report_metadata, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ) } - - pub(super) fn generate_helper_report_share_for_plaintext( + pub(crate) fn generate_helper_report_init_for_plaintext( metadata: ReportMetadata, cfg: &HpkeConfig, encoded_public_share: Vec<u8>, plaintext: &[u8], associated_data: &[u8], - ) -> ReportShare { - ReportShare::new( - metadata, - encoded_public_share, - hpke::seal( - cfg, - &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - plaintext, - associated_data, - ) - .unwrap(), + leader_prep_share: Vec<u8>, + ) -> ReportPrepInit { + ReportPrepInit::new( + ReportShare::new( + metadata, + encoded_public_share, + hpke::seal( + cfg, + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), + plaintext, + associated_data, + ) + .unwrap(), + ), + leader_prep_share, ) } } diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index be7a73532..e4d2ddcd3 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -1,4 +1,4 @@ -use crate::aggregator::{aggregator_handler, tests::generate_helper_report_share, Config}; +use crate::aggregator::{aggregator_handler, tests::generate_helper_report_init, Config}; use janus_aggregator_core::{ datastore::{ test_util::{ephemeral_datastore, EphemeralDatastore}, @@ -13,87 +13,94 @@ use janus_core::{ }; use janus_messages::{ query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, PartialBatchSelector, - ReportMetadata, ReportShare, Role, + ReportMetadata, ReportPrepInit, Role, }; -use prio::codec::Encode; +use prio::{codec::Encode, vdaf}; use rand::random; use std::sync::Arc; use trillium::{Handler, KnownHeaderName, Status}; use trillium_testing::{prelude::put, TestConn}; -pub(super) struct ReportShareGenerator { +pub(super) struct ReportInitGenerator<const SEED_SIZE: usize, V> +where + V: vdaf::Vdaf, +{ clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, - vdaf: dummy_vdaf::Vdaf, + vdaf: V, + aggregation_param: V::AggregationParam, } -impl ReportShareGenerator { +impl<const SEED_SIZE: usize, V> ReportInitGenerator<SEED_SIZE, V> +where + V: vdaf::Vdaf + vdaf::Aggregator<SEED_SIZE, 16> + vdaf::Client<16>, +{ pub(super) fn new( clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, + vdaf: V, + aggregation_param: V::AggregationParam, ) -> Self { Self { clock, task, + vdaf, aggregation_param, - vdaf: dummy_vdaf::Vdaf::new(), } } - fn with_vdaf(mut self, vdaf: dummy_vdaf::Vdaf) -> Self { - self.vdaf = vdaf; - self - } - - pub(super) fn next(&self) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { - self.next_with_metadata(ReportMetadata::new( - random(), - self.clock - .now() - .to_batch_interval_start(self.task.time_precision()) - .unwrap(), - )) + pub(super) fn next( + &self, + measurement: &V::Measurement, + ) -> (ReportPrepInit, VdafTranscript<SEED_SIZE, V>) { + self.next_with_metadata( + ReportMetadata::new( + random(), + self.clock + .now() + .to_batch_interval_start(self.task.time_precision()) + .unwrap(), + ), + measurement, + ) } pub(super) fn next_with_metadata( &self, report_metadata: ReportMetadata, - ) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { + measurement: &V::Measurement, + ) -> (ReportPrepInit, VdafTranscript<SEED_SIZE, V>) { let transcript = run_vdaf( &self.vdaf, self.task.primary_vdaf_verify_key().unwrap().as_bytes(), &self.aggregation_param, report_metadata.id(), - &(), + measurement, ); - let report_share = generate_helper_report_share::<dummy_vdaf::Vdaf>( + let report_init = generate_helper_report_init::<SEED_SIZE, V>( *self.task.id(), report_metadata, self.task.current_hpke_key().config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - - (report_share, transcript) + (report_init, transcript) } } -pub(super) struct AggregationJobInitTestCase { +pub(super) struct AggregationJobInitTestCase<const SEED_SIZE: usize, V: vdaf::Vdaf> { pub(super) clock: MockClock, pub(super) task: Task, - pub(super) report_share_generator: ReportShareGenerator, - pub(super) report_shares: Vec<ReportShare>, + pub(super) report_init_generator: ReportInitGenerator<SEED_SIZE, V>, + pub(super) report_inits: Vec<ReportPrepInit>, pub(super) aggregation_job_id: AggregationJobId, - pub(super) aggregation_param: dummy_vdaf::AggregationParam, + pub(super) aggregation_param: V::AggregationParam, pub(super) handler: Box<dyn Handler>, pub(super) datastore: Arc<Datastore<MockClock>>, _ephemeral_datastore: EphemeralDatastore, } -pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase { +pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy_vdaf::Vdaf> { install_test_trace_subscriber(); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); @@ -108,19 +115,23 @@ pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase { let aggregation_param = dummy_vdaf::AggregationParam(0); - let report_share_generator = - ReportShareGenerator::new(clock.clone(), task.clone(), aggregation_param); + let report_init_generator = ReportInitGenerator::new( + clock.clone(), + task.clone(), + dummy_vdaf::Vdaf::new(), + aggregation_param, + ); - let report_shares = Vec::from([ - report_share_generator.next().0, - report_share_generator.next().0, + let report_inits = Vec::from([ + report_init_generator.next(&0).0, + report_init_generator.next(&0).0, ]); let aggregation_job_id = random(); let aggregation_job_init_req = AggregationJobInitializeReq::new( aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - report_shares.clone(), + report_inits.clone(), ); let response = put_aggregation_job( @@ -135,8 +146,8 @@ pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase { AggregationJobInitTestCase { clock, task, - report_shares, - report_share_generator, + report_inits, + report_init_generator, aggregation_job_id, aggregation_param, handler: Box::new(handler), @@ -173,7 +184,7 @@ async fn aggregation_job_mutation_aggregation_job() { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(1).get_encoded(), PartialBatchSelector::new_time_interval(), - test_case.report_shares, + test_case.report_inits, ); let response = put_aggregation_job( @@ -192,28 +203,28 @@ async fn aggregation_job_mutation_report_shares() { // Put the aggregation job again, mutating the associated report shares' metadata such that // uniqueness constraints on client_reports are violated - for mutated_report_shares in [ + for mutated_report_inits in [ // Omit a report share that was included previously - Vec::from(&test_case.report_shares[0..test_case.report_shares.len() - 1]), + Vec::from(&test_case.report_inits[0..test_case.report_inits.len() - 1]), // Include a different report share than was included previously [ - &test_case.report_shares[0..test_case.report_shares.len() - 1], - &[test_case.report_share_generator.next().0], + &test_case.report_inits[0..test_case.report_inits.len() - 1], + &[test_case.report_init_generator.next(&0).0], ] .concat(), // Include an extra report share than was included previously [ - test_case.report_shares.as_slice(), - &[test_case.report_share_generator.next().0], + test_case.report_inits.as_slice(), + &[test_case.report_init_generator.next(&0).0], ] .concat(), // Reverse the order of the reports - test_case.report_shares.into_iter().rev().collect(), + test_case.report_inits.into_iter().rev().collect(), ] { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
test_case.aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - mutated_report_shares, + mutated_report_inits, ); let response = put_aggregation_job( &test_case.task, &test_case.aggregation_job_id, &mutated_aggregation_job_init_req, &test_case.handler, ) .await; assert_eq!(response.status(), Some(Status::Conflict)); } } #[tokio::test] async fn aggregation_job_mutation_report_aggregations() { let test_case = setup_aggregate_init_test().await; - // Generate some new reports using the existing reports' metadata, but varying the input shares - // such that the prepare state computed during aggregation initializaton won't match the first - // aggregation job. - let mutated_report_shares_generator = test_case - .report_share_generator - .with_vdaf(dummy_vdaf::Vdaf::new().with_input_share(dummy_vdaf::InputShare(1))); - let mutated_report_shares = test_case - .report_shares + // Generate some new reports using the existing reports' metadata, but varying the measurement + // values such that the prepare state computed during aggregation initialization won't match the + // first aggregation job. + let mutated_report_inits = test_case + .report_inits .iter() .map(|s| { - mutated_report_shares_generator - .next_with_metadata(s.metadata().clone()) + test_case + .report_init_generator + .next_with_metadata(s.report_share().metadata().clone(), &1) .0 }) .collect(); @@ -249,7 +258,7 @@ async fn aggregation_job_mutation_report_aggregations() { let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( test_case.aggregation_param.get_encoded(), PartialBatchSelector::new_time_interval(), - mutated_report_shares, + mutated_report_inits, ); let response = put_aggregation_job( &test_case.task, diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 72a433c92..58f71c3b8 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -116,57 +116,124 @@ impl VdafOps { } }; - // Parse preparation message out of prepare step received from leader. - let prep_msg = match prep_step.result() { - PrepareStepResult::Continued(payload) => A::PrepareMessage::decode_with_param( - prep_state, - &mut Cursor::new(payload.as_ref()), - )?, - _ => { - return Err(datastore::Error::User( - Error::UnrecognizedMessage( - Some(*task.id()), - "leader sent non-Continued prepare step", - ) - .into(), - )); + // Parse preparation message out of prepare step received from Leader. + let prep_msg = A::PrepareMessage::decode_with_param( + prep_state, + &mut Cursor::new(match prep_step.result() { + PrepareStepResult::Continued { prep_msg, .. } => prep_msg, + PrepareStepResult::Finished { prep_msg } => prep_msg, + _ => { + return Err(datastore::Error::User( + Error::UnrecognizedMessage( + Some(*task.id()), + "leader sent non-Continued/Finished prepare step", + ) + .into(), + )); + } + }), + )?; + + // Compute the next transition; if we're finished, we terminate here. Otherwise, + // retrieve our updated state as well as the leader & helper prepare shares. + let (prep_state, leader_prep_share, helper_prep_share) = match vdaf + .prepare_step(prep_state.clone(), prep_msg) + { + Ok(PrepareTransition::Continue(prep_state, helper_prep_share)) => { + if let PrepareStepResult::Continued { prep_share, ..
} = prep_step.result() { + let leader_prep_share = match A::PrepareShare::get_decoded_with_param( + &prep_state, + prep_share, + ) { + Ok(leader_prep_share) => leader_prep_share, + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Couldn't parse Leader's prepare share" + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } + }; + (prep_state, leader_prep_share, helper_prep_share) + } else { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + "Leader finished but Helper did not", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } } - }; - // Compute the next transition. - match vdaf.prepare_step(prep_state.clone(), prep_msg) { - Ok(PrepareTransition::Continue(prep_state, prep_share)) => { - *report_aggregation = report_aggregation - .clone() - .with_state(ReportAggregationState::Waiting(prep_state, None)) - .with_last_prep_step(Some(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Continued(prep_share.get_encoded()), - ))); - } + Ok(PrepareTransition::Finish(out_share)) => { + // If we finished but the Leader didn't, fail out. + if !matches!(prep_step.result(), PrepareStepResult::Finished { .. }) { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + "Helper finished but Leader did not", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } - Ok(PrepareTransition::Finish(output_share)) => { + // If both aggregators finished here, record our output share & respond with + // a finished message. accumulator.update( helper_aggregation_job.partial_batch_identifier(), prep_step.report_id(), report_aggregation.time(), - &output_share, + &out_share, )?; *report_aggregation = report_aggregation .clone() - .with_state(ReportAggregationState::Finished(output_share)) + .with_state(ReportAggregationState::Finished(out_share)) .with_last_prep_step(Some(PrepareStep::new( *prep_step.report_id(), - PrepareStepResult::Finished, + PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, ))); + continue; } - Err(error) => { + Err(err) => { info!( task_id = %task.id(), job_id = %helper_aggregation_job.id(), report_id = %prep_step.report_id(), - ?error, "Prepare step failed", + ?err, "Prepare step failed", ); aggregate_step_failure_counter.add( &Context::current(), @@ -181,9 +248,82 @@ impl VdafOps { .with_last_prep_step(Some(PrepareStep::new( *prep_step.report_id(), PrepareStepResult::Failed(ReportShareError::VdafPrepError), - ))) + ))); + continue; + } + }; + + // Merge the leader & helper prepare shares into the next message. 
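+ // (prepare_preprocess combines the two prepare shares into a single prepare message; a + // failure here fails only this report with VdafPrepError and processing continues with + // the next one.)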
+ let prep_msg = match vdaf.prepare_preprocess([leader_prep_share, helper_prep_share]) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Couldn't compute prepare message", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; } }; + + // Compute the next step based on the merged message. + let encoded_prep_msg = prep_msg.get_encoded(); + match vdaf.prepare_step(prep_state, prep_msg) { + Ok(PrepareTransition::Continue(prep_state, helper_prep_share)) => { + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Waiting(prep_state, None)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Continued { + prep_msg: encoded_prep_msg, + prep_share: helper_prep_share.get_encoded(), + }, + ))); + } + + Ok(PrepareTransition::Finish(out_share)) => { + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Finished(out_share)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Finished { + prep_msg: encoded_prep_msg, + }, + ))); + } + + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Prepare step failed", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + } + } } for report_agg in report_aggregations_iter { @@ -202,13 +342,13 @@ impl VdafOps { let saw_continue = report_aggregations.iter().any(|report_agg| { matches!( report_agg.last_prep_step().map(PrepareStep::result), - Some(PrepareStepResult::Continued(_)) + Some(PrepareStepResult::Continued { .. }) ) }); let saw_finish = report_aggregations.iter().any(|report_agg| { matches!( report_agg.last_prep_step().map(PrepareStep::result), - Some(PrepareStepResult::Finished) + Some(PrepareStepResult::Finished { .. }) ) }); let helper_aggregation_job = helper_aggregation_job @@ -235,21 +375,23 @@ impl VdafOps { try_join_all( report_aggregations .iter() - .map(|ra| tx.update_report_aggregation(ra)) + .map(|report_agg| tx.update_report_aggregation(report_agg)), ), - accumulator.flush_to_datastore(tx, &vdaf), + accumulator.flush_to_datastore(tx, &vdaf) )?; - Ok(Self::aggregation_job_resp_for(report_aggregations)) + Ok(Self::aggregation_job_resp_for::( + report_aggregations, + )) } - /// Constructs an AggregationJobResp from a given set of Helper report aggregations. - pub(super) fn aggregation_job_resp_for< - const SEED_SIZE: usize, - A: vdaf::Aggregator, - >( + /// Construct an AggregationJobResp from a given set of Helper report aggregations. 
+ pub(super) fn aggregation_job_resp_for( report_aggregations: impl IntoIterator>, - ) -> AggregationJobResp { + ) -> AggregationJobResp + where + A: vdaf::Aggregator, + { AggregationJobResp::new( report_aggregations .into_iter() @@ -363,7 +505,7 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - aggregate_init_tests::ReportShareGenerator, + aggregate_init_tests::ReportInitGenerator, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, post_aggregation_job_expecting_status, @@ -382,23 +524,34 @@ mod tests { task::{test_util::TaskBuilder, QueryType, Task}, }; use janus_core::{ - task::VdafInstance, - test_util::{dummy_vdaf, install_test_trace_subscriber}, + task::{VdafInstance, VERIFY_KEY_LEN}, + test_util::install_test_trace_subscriber, time::{IntervalExt, MockClock}, }; use janus_messages::{ query_type::TimeInterval, AggregationJobContinueReq, AggregationJobId, AggregationJobResp, AggregationJobRound, Interval, PrepareStep, PrepareStepResult, Role, }; - use prio::codec::Encode; + use prio::{ + codec::Encode, + idpf::IdpfInput, + vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + Vdaf, + }, + }; use rand::random; use std::sync::Arc; use trillium::{Handler, Status}; - struct AggregationJobContinueTestCase { + struct AggregationJobContinueTestCase + where + V: Vdaf, + { task: Task, datastore: Arc>, - report_generator: ReportShareGenerator, + report_generator: ReportInitGenerator, aggregation_job_id: AggregationJobId, first_continue_request: AggregationJobContinueReq, first_continue_response: Option, @@ -408,60 +561,81 @@ mod tests { /// Set up a helper with an aggregation job in round 0 #[allow(clippy::unit_arg)] - async fn setup_aggregation_job_continue_test() -> AggregationJobContinueTestCase { + async fn setup_aggregation_job_continue_test( + ) -> AggregationJobContinueTestCase> { // Prepare datastore & request. 
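// (Poplar1 preparation takes more than one round, so the aggregation job is still waiting // after initialization and the continue path is actually exercised.)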
install_test_trace_subscriber(); let aggregation_job_id = random(); - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let report_generator = ReportShareGenerator::new( + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let report_generator = ReportInitGenerator::new( clock.clone(), task.clone(), - dummy_vdaf::AggregationParam::default(), + Poplar1::new_sha3(1), + aggregation_param.clone(), ); - let report = report_generator.next(); + let (report_init, transcript) = report_generator.next(&IdpfInput::from_bools(&[true])); datastore .run_tx(|tx| { - let (task, report) = (task.clone(), report.clone()); + let (task, aggregation_param, report_init, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_init.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await.unwrap(); - tx.put_report_share(task.id(), &report.0).await.unwrap(); + tx.put_report_share(task.id(), report_init.report_share()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LEN, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::from_time(report_init.report_share().metadata().time()).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) + .await + .unwrap(); - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + let (prep_state, _) = transcript.helper_prep_state(1); + tx.put_report_aggregation::>( + &ReportAggregation::new( *task.id(), aggregation_job_id, - dummy_vdaf::AggregationParam::default(), - (), - Interval::from_time(report.0.metadata().time()).unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), + *report_init.report_share().metadata().id(), + *report_init.report_share().metadata().time(), + 0, + None, + ReportAggregationState::Waiting(prep_state.clone(), None), ), ) .await .unwrap(); - let (prep_state, _) = report.1.helper_prep_state(0); - tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( - *task.id(), - aggregation_job_id, - *report.0.metadata().id(), - *report.0.metadata().time(), - 0, - None, - ReportAggregationState::Waiting(*prep_state, None), - )) - .await - .unwrap(); - Ok(()) }) }) @@ -471,8 +645,10 @@ mod tests { let first_continue_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([PrepareStep::new( - *report.0.metadata().id(), - PrepareStepResult::Continued(report.1.prepare_messages[0].get_encoded()), + *report_init.report_share().metadata().id(), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[1].get_encoded(), + }, )]), ); @@ -492,10 +668,10 @@ mod tests { } } - /// Set up a helper with an aggregation job in round 1 + /// Set up a helper with an aggregation job in round 1. 
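+ /// (By this point the round-0 job created in setup_aggregation_job_continue_test has + /// already processed its first continue request.)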
#[allow(clippy::unit_arg)] - async fn setup_aggregation_job_continue_round_recovery_test() -> AggregationJobContinueTestCase - { + async fn setup_aggregation_job_continue_round_recovery_test( + ) -> AggregationJobContinueTestCase> { let mut test_case = setup_aggregation_job_continue_test().await; let first_continue_response = post_aggregation_job_and_decode( @@ -514,7 +690,12 @@ mod tests { .first_continue_request .prepare_steps() .iter() - .map(|step| PrepareStep::new(*step.report_id(), PrepareStepResult::Finished)) + .map(|step| PrepareStep::new( + *step.report_id(), + PrepareStepResult::Finished { + prep_msg: Vec::new() + } + )) .collect() ) ); @@ -570,23 +751,25 @@ mod tests { async fn aggregation_job_continue_round_recovery_mutate_continue_request() { let test_case = setup_aggregation_job_continue_round_recovery_test().await; - let unrelated_report = test_case.report_generator.next(); + let (unrelated_report_init, unrelated_transcript) = test_case + .report_generator + .next(&IdpfInput::from_bools(&[false])); let (before_aggregation_job, before_report_aggregations) = test_case .datastore .run_tx(|tx| { - let (task_id, unrelated_report, aggregation_job_id) = ( + let (task_id, unrelated_report_init, aggregation_job_id) = ( *test_case.task.id(), - unrelated_report.clone(), + unrelated_report_init.clone(), test_case.aggregation_job_id, ); Box::pin(async move { - tx.put_report_share(&task_id, &unrelated_report.0) + tx.put_report_share(&task_id, unrelated_report_init.report_share()) .await .unwrap(); let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -594,8 +777,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_sha3(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -614,8 +797,10 @@ mod tests { let modified_request = AggregationJobContinueReq::new( test_case.first_continue_request.round(), Vec::from([PrepareStep::new( - *unrelated_report.0.metadata().id(), - PrepareStepResult::Continued(unrelated_report.1.prepare_messages[0].get_encoded()), + *unrelated_report_init.report_share().metadata().id(), + PrepareStepResult::Finished { + prep_msg: unrelated_transcript.prepare_messages[1].get_encoded(), + }, )]), ); @@ -636,7 +821,7 @@ mod tests { (*test_case.task.id(), test_case.aggregation_job_id); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -644,8 +829,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_sha3(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -678,7 +863,7 @@ mod tests { // round mismatch error instead of tripping the check for a request to continue // to round 0. 
let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 1c4913a59..b141d3f80 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -12,7 +12,7 @@ use janus_aggregator_core::{ task::{self, Task}, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, time::{Clock, DurationExt as _, TimeExt as _}, }; use janus_messages::{ @@ -232,102 +232,102 @@ impl AggregationJobCreator { ) -> anyhow::Result { match (task.query_type(), task.vdaf()) { (task::QueryType::TimeInterval, VdafInstance::Prio3Count) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { .. }) => { self.create_aggregation_jobs_for_time_interval_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVecMultithreaded >(task).await } (task::QueryType::TimeInterval, VdafInstance::Prio3Sum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Count) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3CountVec { .. }) => { let max_batch_size = *max_batch_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVecMultithreaded >(task, max_batch_size).await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Sum { .. 
}) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } (task::QueryType::FixedSize { max_batch_size }, VdafInstance::Prio3SumVec { .. }) => { let max_batch_size = *max_batch_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVec, >(task, max_batch_size).await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Histogram { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } @@ -775,7 +775,7 @@ mod tests { task::{test_util::TaskBuilder, QueryType as TaskQueryType}, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, @@ -1469,10 +1469,9 @@ mod tests { >( tx: &Transaction<'_, C>, task_id: &TaskId, - ) -> HashMap, T)> - { + ) -> HashMap, T)> { let vdaf = Prio3::new_count(2).unwrap(); - read_aggregate_jobs_for_task_generic::( + read_aggregate_jobs_for_task_generic::( tx, task_id, &vdaf, ) .await diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index e0be308a5..39641a3ba 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -20,7 +20,8 @@ use janus_core::{time::Clock, vdaf_dispatch}; use janus_messages::{ query_type::{FixedSize, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, - PartialBatchSelector, PrepareStep, PrepareStepResult, ReportShare, ReportShareError, Role, + PartialBatchSelector, PrepareStep, PrepareStepResult, ReportPrepInit, ReportShare, + ReportShareError, Role, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -86,13 +87,13 @@ impl AggregationJobDriver { ) -> Result<()> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, 
VERIFY_KEY_LEN) => { + self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await }) } task::QueryType::FixedSize { .. } => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await }) } } @@ -314,7 +315,7 @@ impl AggregationJobDriver { // Compute report shares to send to helper, and decrypt our input shares & initialize // preparation state. let mut report_aggregations_to_write = Vec::new(); - let mut report_shares = Vec::new(); + let mut report_inits = Vec::new(); let mut stepped_aggregations = Vec::new(); for (report_aggregation, report) in reports { // Check for repeated extensions. @@ -360,10 +361,13 @@ impl AggregationJobDriver { } }; - report_shares.push(ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + report_inits.push(ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + prep_share.get_encoded(), )); stepped_aggregations.push(SteppedAggregation { report_aggregation, @@ -377,7 +381,7 @@ impl AggregationJobDriver { let req = AggregationJobInitializeReq::::new( aggregation_job.aggregation_parameter().get_encoded(), PartialBatchSelector::new(aggregation_job.partial_batch_identifier().clone()), - report_shares, + report_inits, ); let resp_bytes = send_request_to_helper( @@ -398,7 +402,7 @@ impl AggregationJobDriver { lease, task, aggregation_job, - &stepped_aggregations, + stepped_aggregations, report_aggregations_to_write, resp.prepare_steps(), ) @@ -476,9 +480,19 @@ impl AggregationJobDriver { } }; + let prepare_step_result = match &leader_transition { + PrepareTransition::Continue(_, prep_share) => PrepareStepResult::Continued { + prep_msg: prep_msg.get_encoded(), + prep_share: prep_share.get_encoded(), + }, + PrepareTransition::Finish(_) => PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, + }; + prepare_steps.push(PrepareStep::new( *report_aggregation.report_id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + prepare_step_result, )); stepped_aggregations.push(SteppedAggregation { report_aggregation, @@ -510,7 +524,7 @@ impl AggregationJobDriver { lease, task, aggregation_job, - &stepped_aggregations, + stepped_aggregations, report_aggregations_to_write, resp.prepare_steps(), ) @@ -530,7 +544,7 @@ impl AggregationJobDriver { lease: Arc>, task: Arc, leader_aggregation_job: AggregationJob, - stepped_aggregations: &[SteppedAggregation], + stepped_aggregations: Vec>, mut report_aggregations_to_write: Vec>, helper_prep_steps: &[PrepareStep], ) -> Result<()> @@ -555,11 +569,11 @@ impl AggregationJobDriver { leader_aggregation_job.aggregation_parameter().clone(), ); for (stepped_aggregation, helper_prep_step) in - stepped_aggregations.iter().zip(helper_prep_steps) + stepped_aggregations.into_iter().zip(helper_prep_steps) { let (report_aggregation, leader_transition) = ( - &stepped_aggregation.report_aggregation, - &stepped_aggregation.leader_transition, + stepped_aggregation.report_aggregation, + stepped_aggregation.leader_transition, ); if helper_prep_step.report_id() != report_aggregation.report_id() { return Err(anyhow!( @@ -567,80 +581,213 @@ impl AggregationJobDriver { )); } - let 
new_state = match helper_prep_step.result() { - PrepareStepResult::Continued(payload) => { - // If the leader continued too, combine the leader's prepare share with the + let new_state = (|| match helper_prep_step.result() { + PrepareStepResult::Continued { + prep_msg, + prep_share, + } => { + // If the Leader continued too, combine the Leader's prepare share with the // helper's to compute next round's prepare message. Prepare to store the // leader's new state & the prepare message. If the leader didn't continue, // transition to INVALID. - if let PrepareTransition::Continue(leader_prep_state, leader_prep_share) = - leader_transition + let leader_prep_state = if let PrepareTransition::Continue( + leader_prep_state, + _, + ) = leader_transition { - let leader_prep_state = leader_prep_state.clone(); - let helper_prep_share = - A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload) - .context("couldn't decode helper's prepare message"); - let prep_msg = helper_prep_share.and_then(|helper_prep_share| { - vdaf.prepare_preprocess([leader_prep_share.clone(), helper_prep_share]) - .context( - "couldn't preprocess leader & helper prepare shares into \ - prepare message", - ) - }); - match prep_msg { - Ok(prep_msg) => { - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) - } - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't compute prepare message"); - self.aggregate_step_failure_counter.add( - &Context::current(), - 1, - &[KeyValue::new("type", "prepare_message_failure")], - ); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) - } - } + leader_prep_state } else { - warn!(report_id = %report_aggregation.report_id(), "Helper continued but leader did not"); + warn!(report_id = %report_aggregation.report_id(), "Helper continued but Leader did not"); self.aggregate_step_failure_counter.add( &Context::current(), 1, &[KeyValue::new("type", "continue_mismatch")], ); - ReportAggregationState::Invalid - } + return ReportAggregationState::Invalid; + }; + + let prep_msg = match A::PrepareMessage::get_decoded_with_param( + &leader_prep_state, + prep_msg, + ) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + let helper_prep_share = match A::PrepareShare::get_decoded_with_param( + &leader_prep_state, + prep_share, + ) { + Ok(helper_prep_share) => helper_prep_share, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode Helper prepare share"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "helper_prep_share_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + + let (leader_prep_state, leader_prep_share) = match vdaf + .prepare_step(leader_prep_state, prep_msg) + { + Ok(PrepareTransition::Continue(leader_prep_state, leader_prep_share)) => { + (leader_prep_state, leader_prep_share) + } + Ok(_) => { + warn!(report_id = %report_aggregation.report_id(), "Helper continued but Leader did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "continue_mismatch")], + ); + return ReportAggregationState::Invalid; + 
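One structural point in the hunk above: `let new_state = (|| match … )()` wraps the whole match in an immediately-invoked closure, so each failure path can `return` a ReportAggregationState early instead of nesting ever deeper. The idiom in isolation, with illustrative names:

enum State {
    Ready(u32),
    Failed,
}

fn classify(input: &str) -> State {
    // `return` inside the closure exits the closure with a State value, not
    // the enclosing function, keeping the error paths flat.
    (|| {
        let parsed: u32 = match input.parse() {
            Ok(n) => n,
            Err(_) => return State::Failed,
        };
        State::Ready(parsed)
    })()
}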
} + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Prepare step failed"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_step_failure")], + ); + return ReportAggregationState::Failed(ReportShareError::VdafPrepError); + } + }; + + let prep_msg = match vdaf + .prepare_preprocess([leader_prep_share, helper_prep_share]) + { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't compute prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_failure")], + ); + return ReportAggregationState::Failed(ReportShareError::VdafPrepError); + } + }; + + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) } - PrepareStepResult::Finished => { - // If the leader finished too, we are done; prepare to store the output share. - // If the leader didn't finish too, we transition to INVALID. - if let PrepareTransition::Finish(out_share) = leader_transition { - match accumulator.update( - leader_aggregation_job.partial_batch_identifier(), - report_aggregation.report_id(), - report_aggregation.time(), - out_share, - ) { - Ok(_) => ReportAggregationState::Finished(out_share.clone()), - Err(error) => { - warn!(report_id = %report_aggregation.report_id(), ?error, "Could not update batch aggregation"); + PrepareStepResult::Finished { prep_msg } => { + match leader_transition { + PrepareTransition::Continue(leader_prep_state, _) => { + // The Helper is finished. If the Leader is ready to continue & also + // finishes, we are done; prepare to store the output share. If the + // Leader doesn't finish too, we transition to INVALID. + + let prep_msg = match A::PrepareMessage::get_decoded_with_param( + &leader_prep_state, + prep_msg, + ) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + + let out_share = match vdaf.prepare_step(leader_prep_state, prep_msg) { + Ok(PrepareTransition::Finish(out_share)) => out_share, + Ok(_) => { + warn!(report_id = %report_aggregation.report_id(), "Helper finished but Leader did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "finish_mismatch")], + ); + return ReportAggregationState::Invalid; + } + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Prepare step failed"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_step_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); + } + }; + + if let Err(err) = accumulator.update( + leader_aggregation_job.partial_batch_identifier(), + report_aggregation.report_id(), + report_aggregation.time(), + &out_share, + ) { + warn!(report_id = %report_aggregation.report_id(), ?err, "Could not update batch aggregation"); self.aggregate_step_failure_counter.add( &Context::current(), 1, &[KeyValue::new("type", "accumulate_failure")], ); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); } + + ReportAggregationState::Finished(out_share) + 
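Taken together, the rewritten arms implement the following decision table (summarizing the hunks above and below; "empty" means a zero-length prep_msg):

  Helper's result                     Leader's transition   New ReportAggregationState
  Continued { prep_msg, prep_share }  Continue(..)          Waiting(next state, Some(next prepare message))
  Continued { .. }                    Finish(..)            Invalid ("continue_mismatch")
  Finished { prep_msg }               Continue(..)          Finished(output share) if the Leader also finishes, else Invalid
  Finished { empty prep_msg }         Finish(out_share)     Finished(out_share)
  Finished { non-empty prep_msg }     Finish(..)            Invalid ("finish_mismatch")
  Failed(err)                         (any)                 Failed(err)

Decode and prepare-step failures along the way map to Failed(UnrecognizedMessage) or Failed(VdafPrepError) and increment the matching aggregate_step_failure counter.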
} + + PrepareTransition::Finish(out_share) => { + // If the Leader is already done, check that the Helper properly + // finished too by transmitting a Finished message with an empty + // prep_msg field. If so, we are done. Otherwise, we transition to + // INVALID. + if !prep_msg.is_empty() { + warn!(report_id = %report_aggregation.report_id(), "Leader finished but Helper did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "finish_mismatch")], + ); + return ReportAggregationState::Invalid; + } + + if let Err(err) = accumulator.update( + leader_aggregation_job.partial_batch_identifier(), + report_aggregation.report_id(), + report_aggregation.time(), + &out_share, + ) { + warn!(report_id = %report_aggregation.report_id(), ?err, "Could not update batch aggregation"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "accumulate_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); + } + + ReportAggregationState::Finished(out_share) } - } else { - warn!(report_id = %report_aggregation.report_id(), "Helper finished but leader did not"); - self.aggregate_step_failure_counter.add( - &Context::current(), - 1, - &[KeyValue::new("type", "finish_mismatch")], - ); - ReportAggregationState::Invalid } } @@ -655,7 +802,7 @@ impl AggregationJobDriver { ); ReportAggregationState::Failed(*err) } - }; + })(); report_aggregations_to_write.push(report_aggregation.clone().with_state(new_state)); } @@ -665,18 +812,12 @@ impl AggregationJobDriver { let aggregation_job_is_finished = report_aggregations_to_write .iter() .all(|ra| !matches!(ra.state(), ReportAggregationState::Waiting(_, _))); - let (next_round, next_state) = if aggregation_job_is_finished { - ( - leader_aggregation_job.round(), - AggregationJobState::Finished, - ) + let next_state = if aggregation_job_is_finished { + AggregationJobState::Finished } else { - ( - // Advance self to next VDAF preparation round - leader_aggregation_job.round().increment(), - AggregationJobState::InProgress, - ) + AggregationJobState::InProgress }; + let next_round = leader_aggregation_job.round().increment(); let aggregation_job_to_write = leader_aggregation_job .with_round(next_round) @@ -728,9 +869,9 @@ impl AggregationJobDriver { ) -> Result<()> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LEN) => { self.cancel_aggregation_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, TimeInterval, VdafType, @@ -739,9 +880,9 @@ impl AggregationJobDriver { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LEN) => { self.cancel_aggregation_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, FixedSize, VdafType, @@ -886,7 +1027,7 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, report_id::ReportIdChecksumExt, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntimeManager}, time::{Clock, IntervalExt, MockClock, TimeExt}, Runtime, @@ -896,19 +1037,22 @@ mod tests { AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, Duration, Extension, ExtensionType, HpkeConfig, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError, Role, + TaskId, Time, }; use opentelemetry::global::meter; use prio::{ codec::Encode, + idpf::IdpfInput, vdaf::{ self, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, prio3::{Prio3, Prio3Count}, Aggregator, }, }; use rand::random; - use reqwest::Url; use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; #[tokio::test] @@ -932,10 +1076,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -943,8 +1084,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -956,7 +1096,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -973,32 +1113,28 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) - .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( + tx.put_aggregation_job( + &AggregationJob::::new( *task.id(), aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), ), ) + .await?; + tx.put_report_aggregation(&ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await }) }) @@ -1006,30 +1142,19 @@ mod tests { .unwrap(); // 
Setup: prepare mocked HTTP responses. - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_responses = Vec::from([ - ( - "PUT", - AggregationJobInitializeReq::::MEDIA_TYPE, - AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( - *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), - )])) - .get_encoded(), - ), - ( - "POST", - AggregationJobContinueReq::MEDIA_TYPE, - AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( - *report.metadata().id(), - PrepareStepResult::Finished, - )])) - .get_encoded(), - ), - ]); - let mocked_aggregates = join_all(helper_responses.into_iter().map( + let helper_responses = Vec::from([( + "PUT", + AggregationJobInitializeReq::::MEDIA_TYPE, + AggregationJobResp::MEDIA_TYPE, + AggregationJobResp::new(Vec::from([PrepareStep::new( + *report.metadata().id(), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, + )])) + .get_encoded(), + )]); + let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { server .mock( @@ -1042,7 +1167,7 @@ mod tests { "DAP-Auth-Token", str::from_utf8(agg_auth_token.as_ref()).unwrap(), ) - .match_header(CONTENT_TYPE.as_str(), req_content_type) + .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) .with_body(resp_body) @@ -1080,7 +1205,9 @@ mod tests { tracing::info!("awaiting stepper tasks"); // Wait for all of the aggregate job stepper tasks to complete. - runtime_manager.wait_for_completed_tasks("stepper", 2).await; + runtime_manager + .wait_for_completed_tasks("stepper", helper_responses.len()) + .await; // Stop the aggregate job driver task. 
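For readers unfamiliar with the test scaffolding: the Helper is simulated with mockito, and the `helper_responses.len()` change above ties the stepper-task wait to the number of mocked exchanges. The basic pattern, sketched with an illustrative path and media-type strings (the real tests match the DAP media-type constants and the DAP-Auth-Token header):

async fn mock_helper_sketch() {
    let mut server = mockito::Server::new_async().await;
    let mocked_aggregate = server
        .mock("PUT", "/aggregation-job")
        .match_header("content-type", "application/dap-aggregation-job-init-req")
        .with_status(200)
        .with_header("content-type", "application/dap-aggregation-job-resp")
        .with_body(b"encoded AggregationJobResp goes here".to_vec())
        .create_async()
        .await;

    // ... build the task with
    // .with_helper_aggregator_endpoint(server.url().parse().unwrap())
    // and drive the aggregation job driver against it ...

    // Fails the test if the Leader never issued the expected request.
    mocked_aggregate.assert_async().await;
}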
task_handle.abort(); @@ -1089,18 +1216,16 @@ mod tests { mocked_aggregate.assert_async().await; } - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -1116,7 +1241,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1157,10 +1282,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -1168,8 +1290,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1181,7 +1302,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1189,7 +1310,7 @@ mod tests { Vec::new(), transcript.input_shares.clone(), ); - let repeated_extension_report = generate_report::( + let repeated_extension_report = generate_report::( *task.id(), ReportMetadata::new(random(), time), helper_hpke_keypair.config(), @@ -1217,7 +1338,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -1231,31 +1352,29 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *repeated_extension_report.metadata().id(), - *repeated_extension_report.metadata().time(), - 1, - None, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *repeated_extension_report.metadata().id(), + *repeated_extension_report.metadata().time(), + 1, + None, + ReportAggregationState::Start, + ), + ) .await?; Ok(tx @@ -1273,19 +1392,24 @@ mod tests { // (This is fragile in that it 
expects the leader request to be deterministically encoded. // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) + let (_, leader_prep_share) = transcript.leader_prep_state(0); let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + leader_prep_share.get_encoded(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1345,30 +1469,26 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let leader_prep_state = transcript.leader_prep_state(0).clone(); - let prep_msg = transcript.prepare_messages[0].clone(); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, None, - ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); let want_repeated_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_extension_report.metadata().id(), @@ -1392,7 +1512,7 @@ mod tests { ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1451,10 +1571,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( @@ -1464,8 +1581,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1477,7 +1593,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, 
helper_hpke_keypair.config(), @@ -1495,33 +1611,33 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) .await?; Ok(tx @@ -1539,19 +1655,24 @@ mod tests { // (This is fragile in that it expects the leader request to be deterministically encoded. // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) + let (_, leader_prep_share) = transcript.leader_prep_state(0); let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + leader_prep_share.get_encoded(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1611,28 +1732,23 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, None, - ReportAggregationState::Waiting( - transcript.leader_prep_state(0).clone(), - Some(transcript.prepare_messages[0].clone()), - ), + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); let 
(got_aggregation_job, got_report_aggregation) = ds @@ -1641,7 +1757,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1676,37 +1792,37 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock .now() .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1716,17 +1832,21 @@ mod tests { ); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); + let (leader_prep_state, _) = transcript.leader_prep_state(1); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate( + &aggregation_param, + [transcript.output_share(Role::Leader).clone()], + ) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; + let prep_msg = &transcript.prepare_messages[1]; let lease = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, aggregation_param, report, leader_prep_state, prep_msg) = ( vdaf.clone(), task.clone(), + aggregation_param.clone(), report.clone(), leader_prep_state.clone(), prep_msg.clone(), @@ -1736,13 +1856,13 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -1751,8 +1871,8 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id, @@ -1783,12 +1903,16 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, )]), ); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Finished, + 
PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1846,38 +1970,39 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, - AggregationJobRound::from(1), + AggregationJobRound::from(2), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), - ); let batch_interval_start = report .metadata() .time() .to_batch_interval_start(task.time_precision()) .unwrap(); let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), Interval::new(batch_interval_start, *task.time_precision()).unwrap(), - (), + aggregation_param.clone(), 0, leader_aggregate_share, 1, @@ -1887,11 +2012,15 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds .run_tx(|tx| { - let (vdaf, task, report_metadata) = - (Arc::clone(&vdaf), task.clone(), report.metadata().clone()); + let (vdaf, task, report_metadata, aggregation_param) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1909,8 +2038,8 @@ mod tests { .unwrap(); let batch_aggregations = TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -1924,7 +2053,7 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await .unwrap(); @@ -1941,7 +2070,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -1965,17 +2094,14 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::FixedSize { max_batch_size: 10 }, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( random(), @@ -1984,20 +2110,23 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + 
IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2007,18 +2136,22 @@ mod tests { ); let batch_id = random(); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); + let (leader_prep_state, _) = transcript.leader_prep_state(1); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate( + &aggregation_param, + [transcript.output_share(Role::Leader).clone()], + ) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; + let prep_msg = &transcript.prepare_messages[1]; let lease = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, report, aggregation_param, leader_prep_state, prep_msg) = ( vdaf.clone(), task.clone(), report.clone(), + aggregation_param.clone(), leader_prep_state.clone(), prep_msg.clone(), ); @@ -2027,13 +2160,13 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2042,8 +2175,8 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id, @@ -2074,12 +2207,16 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, )]), ); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Finished, + PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -2137,34 +2274,35 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, - AggregationJobRound::from(1), + AggregationJobRound::from(2), ); let leader_output_share = transcript.output_share(Role::Leader); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Finished(leader_output_share.clone()), - ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished(leader_output_share.clone()), + ); let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), batch_id, - (), + aggregation_param.clone(), 0, 
leader_aggregate_share, 1, @@ -2174,11 +2312,15 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds .run_tx(|tx| { - let (vdaf, task, report_metadata) = - (Arc::clone(&vdaf), task.clone(), report.metadata().clone()); + let (vdaf, task, report_metadata, aggregation_param) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2196,10 +2338,10 @@ mod tests { .unwrap(); let batch_aggregations = FixedSize::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, - >(tx, &task, &vdaf, &batch_id, &()) + >(tx, &task, &vdaf, &batch_id, &aggregation_param) .await?; Ok((aggregation_job, report_aggregation, batch_aggregations)) }) @@ -2214,7 +2356,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -2249,8 +2391,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -2261,7 +2402,7 @@ mod tests { ); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2271,18 +2412,16 @@ mod tests { ); let aggregation_job_id = random(); - let aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - ); - let report_aggregation = ReportAggregation::::new( + let aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ); + let report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -2339,7 +2478,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2428,15 +2567,11 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); @@ -2447,7 +2582,7 @@ mod tests { .unwrap(); let report_metadata = ReportMetadata::new(random(), time); let transcript = run_vdaf(&vdaf, verify_key.as_bytes(), &(), report_metadata.id(), &0); - let report = generate_report::( + let report = generate_report::( *task.id(), 
report_metadata, helper_hpke_keypair.config(), @@ -2468,35 +2603,31 @@ mod tests { // run through initial VDAF preparation before sending a request to the helper. tx.put_client_report(&vdaf, &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) - .await?; - - tx.put_report_aggregation( - &ReportAggregation::::new( + tx.put_aggregation_job( + &AggregationJob::::new( *task.id(), aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - None, - ReportAggregationState::Start, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), ), ) .await?; + tx.put_report_aggregation(&ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) + .await?; + Ok(()) }) }) @@ -2595,7 +2726,7 @@ mod tests { .run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2607,7 +2738,7 @@ mod tests { .unwrap(); assert_eq!( aggregation_job_after, - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index 21ac91332..d1c692fd8 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -71,9 +71,9 @@ impl CollectionJobDriver { ) -> Result<(), Error> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { self.step_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, TimeInterval, VdafType @@ -82,9 +82,9 @@ impl CollectionJobDriver { }) } task::QueryType::FixedSize { .. } => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { self.step_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, FixedSize, VdafType @@ -261,8 +261,8 @@ impl CollectionJobDriver { ) -> Result<(), Error> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.abandon_collection_job_generic::( + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.abandon_collection_job_generic::( datastore, Arc::new(vdaf), lease, @@ -271,8 +271,8 @@ impl CollectionJobDriver { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.abandon_collection_job_generic::( + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.abandon_collection_job_generic::( datastore, Arc::new(vdaf), lease, @@ -499,7 +499,6 @@ mod tests { use prio::codec::{Decode, Encode}; use rand::random; use std::{str, sync::Arc, time::Duration as StdDuration}; - use url::Url; async fn setup_collection_job_test_case( server: &mut mockito::Server, @@ -513,10 +512,7 @@ mod tests { ) { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); @@ -571,7 +567,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Finished(OutputShare()), + ReportAggregationState::Finished(OutputShare(0)), )) .await?; @@ -641,10 +637,7 @@ mod tests { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); @@ -700,7 +693,7 @@ mod tests { *report.metadata().time(), 0, None, - ReportAggregationState::Finished(OutputShare()), + ReportAggregationState::Finished(OutputShare(0)), )) .await?; diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 242a8bcc2..ee53a2970 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -197,7 +197,7 @@ async fn setup_fixed_size_current_batch_collection_job_test_case( time, ord, None, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), )) .await .unwrap(); @@ -347,13 +347,12 @@ async fn collection_job_success_fixed_size() { .align_to_time_precision(test_case.task.time_precision()) .unwrap(), ); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), @@ -371,7 +370,7 @@ async fn collection_job_success_fixed_size() { test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 89028f4e0..332d54082 100644 --- a/aggregator/src/aggregator/http_handlers.rs 
+++ b/aggregator/src/aggregator/http_handlers.rs @@ -539,8 +539,7 @@ mod tests { http_handlers::aggregator_handler, tests::{ create_report, create_report_with_id, default_aggregator_config, - generate_helper_report_share, generate_helper_report_share_for_plaintext, - DUMMY_VERIFY_KEY_LENGTH, + generate_helper_report_init, generate_helper_report_init_for_plaintext, }, }; use assert_matches::assert_matches; @@ -561,7 +560,7 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, report_id::ReportIdChecksumExt, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LEN}, test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, }; @@ -572,14 +571,16 @@ mod tests { Collection, CollectionJobId, CollectionReq, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PrepareStep, PrepareStepResult, Query, Report, ReportId, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError, Role, + TaskId, Time, }; use prio::{ codec::{Decode, Encode}, - field::Field64, + idpf::IdpfInput, vdaf::{ - prio3::{Prio3, Prio3Count}, - AggregateShare, Aggregator, OutputShare, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + Aggregator, }, }; use rand::random; @@ -851,7 +852,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -867,26 +869,6 @@ mod tests { ) .await; - // should reject a report with only one share with the unrecognizedMessage type. - let bad_report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - let mut test_conn = put(task.report_upload_uri().unwrap().path()) - .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) - .with_request_body(bad_report.get_encoded()) - .run_async(&handler) - .await; - check_response( - &mut test_conn, - Status::BadRequest, - "unrecognizedMessage", - "The message type for a response was incorrect or the payload was malformed.", - task.id(), - ) - .await; - // should reject a report using the wrong HPKE config for the leader, and reply with // the error type outdatedConfig. let unused_hpke_config_id = (0..) 
@@ -896,16 +878,15 @@ mod tests { let bad_report = Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -931,7 +912,8 @@ mod tests { let bad_report = Report::new( ReportMetadata::new(*report.metadata().id(), bad_report_time), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut test_conn = put(task.report_upload_uri().unwrap().path()) .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE) @@ -1008,7 +990,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ) .get_encoded(), ) @@ -1262,10 +1245,8 @@ mod tests { assert_eq!(want_status, test_conn.status().unwrap() as u16); } - #[tokio::test] - // Silence the unit_arg lint so that we can work with dummy_vdaf::Vdaf::InputShare values (whose - // type is ()). #[allow(clippy::unit_arg)] + #[tokio::test] async fn aggregate_init() { // Prepare datastore & request. install_test_trace_subscriber(); @@ -1280,7 +1261,7 @@ mod tests { let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); - // report_share_0 is a "happy path" report. + // report_init_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), clock @@ -1293,18 +1274,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_0.id(), - &(), + &0, ); - let report_share_0 = generate_helper_report_share::( + let report_init_0 = generate_helper_report_init( *task.id(), report_metadata_0, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_1 fails decryption. + // report_init_1 fails decryption. 
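Before the failure cases, a note on the central type change these tests exercise: an aggregation-init request now carries a ReportPrepInit per report instead of a bare ReportShare, pairing the report share with the leader's first-round prepare share. The definition itself is outside this diff; a minimal sketch consistent with the constructor and accessors used in these hunks (ReportPrepInit::new(report_share, leader_prep_share), .report_share(), .leader_prep_share()):

    // Sketch only — inferred from usage in this patch, not the actual
    // janus_messages definition.
    use janus_messages::ReportShare;

    pub struct ReportPrepInit {
        report_share: ReportShare,
        leader_prep_share: Vec<u8>, // encoded leader prepare share, round 0
    }

    impl ReportPrepInit {
        pub fn new(report_share: ReportShare, leader_prep_share: Vec<u8>) -> Self {
            Self { report_share, leader_prep_share }
        }

        pub fn report_share(&self) -> &ReportShare {
            &self.report_share
        }

        pub fn leader_prep_share(&self) -> &[u8] {
            &self.leader_prep_share
        }
    }

This is why report_init_1 below is rebuilt in two steps: corrupt the ciphertext inside a fresh ReportShare, then re-wrap it with the original leader prepare share.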
let report_metadata_1 = ReportMetadata::new( random(), clock @@ -1317,17 +1297,16 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_1.id(), - &(), + &0, ); - let report_share_1 = generate_helper_report_share::( + let report_init_1 = generate_helper_report_init( *task.id(), report_metadata_1.clone(), hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - let encrypted_input_share = report_share_1.encrypted_input_share(); + let encrypted_input_share = report_init_1.report_share().encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -1335,14 +1314,15 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); - let report_share_1 = ReportShare::new( + let report_share_1: ReportShare = ReportShare::new( report_metadata_1, - encoded_public_share.clone(), + transcript.public_share.get_encoded(), corrupted_input_share, ); + let report_init_1 = + ReportPrepInit::new(report_share_1, report_init_1.leader_prep_share().to_vec()); - // report_share_2 fails decoding due to an issue with the input share. + // report_init_2 fails decoding due to an issue with the input share. let report_metadata_2 = ReportMetadata::new( random(), clock @@ -1355,19 +1335,25 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_2.id(), - &(), + &0, ); let mut input_share_bytes = transcript.input_shares[1].get_encoded(); input_share_bytes.push(0); // can no longer be decoded. - let report_share_2 = generate_helper_report_share_for_plaintext( + let report_init_2 = generate_helper_report_init_for_plaintext( report_metadata_2.clone(), hpke_key.config(), - encoded_public_share.clone(), + transcript.public_share.get_encoded(), &input_share_bytes, - &InputShareAad::new(*task.id(), report_metadata_2, encoded_public_share).get_encoded(), + &InputShareAad::new( + *task.id(), + report_metadata_2, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ); - // report_share_3 has an unknown HPKE config ID. + // report_init_3 has an unknown HPKE config ID. let report_metadata_3 = ReportMetadata::new( random(), clock @@ -1380,7 +1366,7 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_3.id(), - &(), + &0, ); let wrong_hpke_config = loop { let hpke_config = generate_test_hpke_config_and_private_key().config().clone(); @@ -1389,16 +1375,15 @@ mod tests { } break hpke_config; }; - let report_share_3 = generate_helper_report_share::( + let report_init_3 = generate_helper_report_init( *task.id(), report_metadata_3, &wrong_hpke_config, - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_4 has already been aggregated in another aggregation job, with the same + // report_init_4 has already been aggregated in another aggregation job, with the same // aggregation parameter. 
let report_metadata_4 = ReportMetadata::new( random(), @@ -1412,18 +1397,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_4.id(), - &(), + &0, ); - let report_share_4 = generate_helper_report_share::( + let report_init_4 = generate_helper_report_init( *task.id(), report_metadata_4, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_5 falls into a batch that has already been collected. + // report_init_5 falls into a batch that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -1439,18 +1423,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_5.id(), - &(), + &0, ); - let report_share_5 = generate_helper_report_share::( + let report_init_5 = generate_helper_report_init( *task.id(), report_metadata_5, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_6 fails decoding due to an issue with the public share. + // report_init_6 fails decoding due to an issue with the public share. let public_share_6 = Vec::from([0]); let report_metadata_6 = ReportMetadata::new( random(), @@ -1464,17 +1447,18 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_6.id(), - &(), + &0, ); - let report_share_6 = generate_helper_report_share_for_plaintext( + let report_init_6 = generate_helper_report_init_for_plaintext( report_metadata_6.clone(), hpke_key.config(), public_share_6.clone(), &transcript.input_shares[1].get_encoded(), &InputShareAad::new(*task.id(), report_metadata_6, public_share_6).get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ); - // report_share_7 fails due to having repeated extensions. + // report_init_7 fails due to having repeated extensions. let report_metadata_7 = ReportMetadata::new( random(), clock @@ -1487,21 +1471,20 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_7.id(), - &(), + &0, ); - let report_share_7 = generate_helper_report_share::( + let report_init_7 = generate_helper_report_init( *task.id(), report_metadata_7, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::from([ Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - &transcript.input_shares[0], ); - // report_share_8 has already been aggregated in another aggregation job, with a different + // report_init_8 has already been aggregated in another aggregation job, with a different // aggregation parameter. let report_metadata_8 = ReportMetadata::new( random(), @@ -1515,32 +1498,62 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(1), report_metadata_8.id(), - &(), + &0, ); - let report_share_8 = generate_helper_report_share::( + let report_init_8 = generate_helper_report_init( *task.id(), report_metadata_8, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], + ); + + // report_init_9 fails decoding due to an issue with the leader preparation share. 
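report_init_9 covers the one genuinely new failure mode in this flow: the helper must now decode the leader's prepare share out of the init request before it can compute a prepare message, and the test hands it a garbage byte string (Vec::from([0])). Per the assertion later in this test, that surfaces as ReportShareError::UnrecognizedMessage. A hedged sketch of the helper-side check being exercised (assumed logic; the handler itself is outside this diff):

    use janus_messages::ReportShareError;
    use prio::codec::ParameterizedDecode;

    // Decode a leader prepare share against the helper's prepare state,
    // mapping any codec error to the DAP-visible UnrecognizedMessage error.
    // `State` and `Share` stand in for the dispatched VDAF's associated types.
    fn decode_leader_prep_share<State, Share>(
        prep_state: &State,
        bytes: &[u8],
    ) -> Result<Share, ReportShareError>
    where
        Share: ParameterizedDecode<State>,
    {
        Share::get_decoded_with_param(prep_state, bytes)
            .map_err(|_| ReportShareError::UnrecognizedMessage)
    }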
+ let report_metadata_9 = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &vdaf, + verify_key.as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata_9.id(), + &0, + ); + let report_init_9 = generate_helper_report_init_for_plaintext( + report_metadata_9.clone(), + hpke_key.config(), + transcript.public_share.get_encoded(), + &transcript.input_shares[1].get_encoded(), + &InputShareAad::new( + *task.id(), + report_metadata_9, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + Vec::from([0]), ); let (conflicting_aggregation_job, non_conflicting_aggregation_job) = datastore .run_tx(|tx| { - let (task, report_share_4, report_share_8) = - (task.clone(), report_share_4.clone(), report_share_8.clone()); + let (task, report_init_4, report_init_8) = + (task.clone(), report_init_4.clone(), report_init_8.clone()); Box::pin(async move { tx.put_task(&task).await?; - // report_share_4 and report_share_8 are already in the datastore as they were + // report_init_4 and report_init_8 are already in the datastore as they were // referenced by existing aggregation jobs. - tx.put_report_share(task.id(), &report_share_4).await?; - tx.put_report_share(task.id(), &report_share_8).await?; + tx.put_report_share(task.id(), report_init_4.report_share()) + .await?; + tx.put_report_share(task.id(), report_init_8.report_share()) + .await?; - // Put in an aggregation job and report aggregation for report_share_4. It uses + // Put in an aggregation job and report aggregation for report_init_4. It uses // the same aggregation parameter as the aggregation job this test will later - // add and so should cause report_share_4 to fail to prepare. + // add and so should cause report_init_4 to fail to prepare. let conflicting_aggregation_job = AggregationJob::new( *task.id(), random(), @@ -1559,8 +1572,8 @@ mod tests { tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( *task.id(), *conflicting_aggregation_job.id(), - *report_share_4.metadata().id(), - *report_share_4.metadata().time(), + *report_init_4.report_share().metadata().id(), + *report_init_4.report_share().metadata().time(), 0, None, ReportAggregationState::Start, @@ -1568,9 +1581,9 @@ mod tests { .await .unwrap(); - // Put in an aggregation job and report aggregation for report_share_8, using a + // Put in an aggregation job and report aggregation for report_init_8, using a // a different aggregation parameter. As the aggregation parameter differs, - // report_share_8 should prepare successfully in the aggregation job we'll PUT + // report_init_8 should prepare successfully in the aggregation job we'll PUT // later. 
let non_conflicting_aggregation_job = AggregationJob::new( *task.id(), @@ -1590,8 +1603,8 @@ mod tests { tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( *task.id(), *non_conflicting_aggregation_job.id(), - *report_share_8.metadata().id(), - *report_share_8.metadata().time(), + *report_init_8.report_share().metadata().id(), + *report_init_8.report_share().metadata().time(), 0, None, ReportAggregationState::Start, @@ -1627,15 +1640,16 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - report_share_6.clone(), - report_share_7.clone(), - report_share_8.clone(), + report_init_0.clone(), + report_init_1.clone(), + report_init_2.clone(), + report_init_3.clone(), + report_init_4.clone(), + report_init_5.clone(), + report_init_6.clone(), + report_init_7.clone(), + report_init_8.clone(), + report_init_9.clone(), ]), ); @@ -1662,64 +1676,101 @@ mod tests { let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap(); // Validate response. - assert_eq!(aggregate_resp.prepare_steps().len(), 9); + assert_eq!(aggregate_resp.prepare_steps().len(), 10); let prepare_step_0 = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step_0.report_id(), report_share_0.metadata().id()); - assert_matches!(prepare_step_0.result(), &PrepareStepResult::Continued(..)); + assert_eq!( + prepare_step_0.report_id(), + report_init_0.report_share().metadata().id() + ); + assert_matches!(prepare_step_0.result(), &PrepareStepResult::Finished { .. }); let prepare_step_1 = aggregate_resp.prepare_steps().get(1).unwrap(); - assert_eq!(prepare_step_1.report_id(), report_share_1.metadata().id()); + assert_eq!( + prepare_step_1.report_id(), + report_init_1.report_share().metadata().id() + ); assert_matches!( prepare_step_1.result(), &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) ); let prepare_step_2 = aggregate_resp.prepare_steps().get(2).unwrap(); - assert_eq!(prepare_step_2.report_id(), report_share_2.metadata().id()); + assert_eq!( + prepare_step_2.report_id(), + report_init_2.report_share().metadata().id() + ); assert_matches!( prepare_step_2.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) ); let prepare_step_3 = aggregate_resp.prepare_steps().get(3).unwrap(); - assert_eq!(prepare_step_3.report_id(), report_share_3.metadata().id()); + assert_eq!( + prepare_step_3.report_id(), + report_init_3.report_share().metadata().id() + ); assert_matches!( prepare_step_3.result(), &PrepareStepResult::Failed(ReportShareError::HpkeUnknownConfigId) ); let prepare_step_4 = aggregate_resp.prepare_steps().get(4).unwrap(); - assert_eq!(prepare_step_4.report_id(), report_share_4.metadata().id()); + assert_eq!( + prepare_step_4.report_id(), + report_init_4.report_share().metadata().id() + ); assert_eq!( prepare_step_4.result(), &PrepareStepResult::Failed(ReportShareError::ReportReplayed) ); let prepare_step_5 = aggregate_resp.prepare_steps().get(5).unwrap(); - assert_eq!(prepare_step_5.report_id(), report_share_5.metadata().id()); + assert_eq!( + prepare_step_5.report_id(), + report_init_5.report_share().metadata().id() + ); assert_eq!( prepare_step_5.result(), &PrepareStepResult::Failed(ReportShareError::BatchCollected) ); let prepare_step_6 = aggregate_resp.prepare_steps().get(6).unwrap(); - assert_eq!(prepare_step_6.report_id(), 
report_share_6.metadata().id()); + assert_eq!( + prepare_step_6.report_id(), + report_init_6.report_share().metadata().id() + ); assert_eq!( prepare_step_6.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), ); let prepare_step_7 = aggregate_resp.prepare_steps().get(7).unwrap(); - assert_eq!(prepare_step_7.report_id(), report_share_7.metadata().id()); + assert_eq!( + prepare_step_7.report_id(), + report_init_7.report_share().metadata().id() + ); assert_eq!( prepare_step_7.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), ); let prepare_step_8 = aggregate_resp.prepare_steps().get(8).unwrap(); - assert_eq!(prepare_step_8.report_id(), report_share_8.metadata().id()); - assert_matches!(prepare_step_8.result(), &PrepareStepResult::Continued(..)); + assert_eq!( + prepare_step_8.report_id(), + report_init_8.report_share().metadata().id() + ); + assert_matches!(prepare_step_8.result(), &PrepareStepResult::Finished { .. }); + + let prepare_step_9 = aggregate_resp.prepare_steps().get(9).unwrap(); + assert_eq!( + prepare_step_9.report_id(), + report_init_9.report_share().metadata().id() + ); + assert_matches!( + prepare_step_9.result(), + &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) + ); // Check aggregation job in datastore. let aggregation_jobs = datastore @@ -1748,7 +1799,7 @@ mod tests { } else if aggregation_job.task_id().eq(task.id()) && aggregation_job.id().eq(&aggregation_job_id) && aggregation_job.partial_batch_identifier().eq(&()) - && aggregation_job.state().eq(&AggregationJobState::InProgress) + && aggregation_job.state().eq(&AggregationJobState::Finished) { saw_new_aggregation_job = true; } @@ -1770,16 +1821,16 @@ mod tests { // This report has the same ID as the previous one, but a different timestamp. let mutated_timestamp_report_metadata = ReportMetadata::new( - *test_case.report_shares[0].metadata().id(), + *test_case.report_inits[0].report_share().metadata().id(), test_case .clock .now() .add(test_case.task.time_precision()) .unwrap(), ); - let mutated_timestamp_report_share = test_case - .report_share_generator - .next_with_metadata(mutated_timestamp_report_metadata) + let mutated_timestamp_report_init = test_case + .report_init_generator + .next_with_metadata(mutated_timestamp_report_metadata, &0) .0; // Send another aggregate job re-using the same report ID but with a different timestamp. 
It @@ -1787,7 +1838,7 @@ mod tests { let request = AggregationJobInitializeReq::new( other_aggregation_parameter.get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([mutated_timestamp_report_share.clone()]), + Vec::from([mutated_timestamp_report_init.clone()]), ); let mut test_conn = @@ -1806,7 +1857,7 @@ mod tests { let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); assert_eq!( prepare_step.report_id(), - mutated_timestamp_report_share.metadata().id() + mutated_timestamp_report_init.report_share().metadata().id() ); assert_matches!( prepare_step.result(), @@ -1828,8 +1879,14 @@ mod tests { .await .unwrap(); assert_eq!(client_reports.len(), 2); - assert_eq!(&client_reports[0], test_case.report_shares[0].metadata()); - assert_eq!(&client_reports[1], test_case.report_shares[1].metadata()); + assert_eq!( + &client_reports[0], + test_case.report_inits[0].report_share().metadata() + ); + assert_eq!( + &client_reports[1], + test_case.report_inits[1].report_share().metadata() + ); } #[tokio::test] @@ -1846,28 +1903,35 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = ephemeral_datastore.datastore(clock.clone()); - let hpke_key = task.current_hpke_key(); datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), + report_metadata, + task.current_hpke_key().config(), + &transcript, Vec::new(), - &dummy_vdaf::InputShare::default(), ); + let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([report_init.clone()]), ); // Create aggregator handler, send request, and parse response. 
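The two tests around this point, which both expect ReportShareError::VdafPrepError, are rewritten to the same setup: build a real transcript with run_vdaf, then derive the init message from it, so the helper receives a well-formed leader prepare share and fails only where the test intends. The shared pattern, lifted from the hunks above:

    // Setup shared by both VdafPrepError tests (as added in this patch).
    let transcript = run_vdaf(
        &dummy_vdaf::Vdaf::new(),
        task.primary_vdaf_verify_key().unwrap().as_bytes(),
        &dummy_vdaf::AggregationParam(0),
        report_metadata.id(),
        &0, // measurement; dummy_vdaf inputs are now 0-valued integers, not ()
    );
    let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>(
        *task.id(),
        report_metadata,
        task.current_hpke_key().config(),
        &transcript,
        Vec::new(), // no extensions
    );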
@@ -1894,7 +1958,10 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + assert_eq!( + prepare_step.report_id(), + report_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -1915,28 +1982,35 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = ephemeral_datastore.datastore(clock.clone()); - let hpke_key = task.current_hpke_key(); datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), + report_metadata, + task.current_hpke_key().config(), + &transcript, Vec::new(), - &dummy_vdaf::InputShare::default(), ); + let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([report_init.clone()]), ); // Create aggregator filter, send request, and parse response. @@ -1963,7 +2037,10 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + assert_eq!( + prepare_step.report_id(), + report_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -1986,24 +2063,32 @@ mod tests { datastore.put_task(&task).await.unwrap(); - let report_share = ReportShare::new( - ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - Vec::from("PUBLIC"), - HpkeCiphertext::new( - // bogus, but we never get far enough to notice - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( + *task.id(), + report_metadata, + task.current_hpke_key().config(), + &transcript, + Vec::new(), ); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone(), report_share]), + Vec::from([report_init.clone(), report_init]), ); let handler = @@ -2043,7 +2128,7 @@ mod tests { let aggregation_job_id = random(); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -2051,12 +2136,16 @@ mod tests { 
let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); - let hpke_key = task.current_hpke_key(); + let vdaf = Arc::new(Poplar1::new_sha3(1)); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[false]); - // report_share_0 is a "happy path" report. + // report_init_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), clock @@ -2067,22 +2156,21 @@ mod tests { let transcript_0 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let (prep_state_0, _) = transcript_0.helper_prep_state(1); + let prep_msg_0 = transcript_0.prepare_messages[1].clone(); + let report_init_0 = generate_helper_report_init::>( *task.id(), report_metadata_0.clone(), - hpke_key.config(), - &transcript_0.public_share, + task.current_hpke_key().config(), + &transcript_0, Vec::new(), - &transcript_0.input_shares[1], ); - // report_share_1 is omitted by the leader's request. + // report_init_1 is omitted by the leader's request. let report_metadata_1 = ReportMetadata::new( random(), clock @@ -2093,22 +2181,20 @@ mod tests { let transcript_1 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - - let (prep_state_1, _) = transcript_1.helper_prep_state(0); - let report_share_1 = generate_helper_report_share::( + let (prep_state_1, _) = transcript_1.helper_prep_state(1); + let report_init_1 = generate_helper_report_init::>( *task.id(), report_metadata_1.clone(), - hpke_key.config(), - &transcript_1.public_share, + task.current_hpke_key().config(), + &transcript_1, Vec::new(), - &transcript_1.input_shares[1], ); - // report_share_2 falls into a batch that has already been collected. + // report_init_2 falls into a batch that has already been collected. 
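With VdafInstance::Poplar1 { bits: 1 } replacing Prio3Count, every transcript in this test needs an explicit aggregation parameter and an IDPF measurement, and the verify key becomes VerifyKey<VERIFY_KEY_LEN> (with PrgSha3 the seed length is 16 bytes, which is presumably what VERIFY_KEY_LEN names). The construction used above, pulled together:

    use prio::{
        idpf::IdpfInput,
        vdaf::poplar1::{Poplar1, Poplar1AggregationParam},
    };

    // One-bit Poplar1 instance over PrgSha3, as added in this patch.
    let vdaf = Poplar1::new_sha3(1);

    // The aggregation parameter is a set of candidate prefixes; here, the
    // single one-bit prefix `false`, which the measurement below matches.
    let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
        IdpfInput::from_bools(&[false]),
    ]))
    .unwrap();
    let measurement = IdpfInput::from_bools(&[false]);

The collected-batch case that follows reuses the same parameter and measurement.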
let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -2122,39 +2208,44 @@ mod tests { let transcript_2 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let (prep_state_2, _) = transcript_2.helper_prep_state(1); + let prep_msg_2 = transcript_2.prepare_messages[1].clone(); + let report_init_2 = generate_helper_report_init::>( *task.id(), report_metadata_2.clone(), - hpke_key.config(), - &transcript_2.public_share, + task.current_hpke_key().config(), + &transcript_2, Vec::new(), - &transcript_2.input_shares[1], ); + let aggregate_share = vdaf + .aggregate( + &aggregation_param, + [ + transcript_0.output_share(Role::Helper).clone(), + transcript_1.output_share(Role::Helper).clone(), + transcript_2.output_share(Role::Helper).clone(), + ], + ) + .unwrap(); + datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param, aggregate_share) = (task.clone(), aggregation_param.clone(), aggregate_share.clone()); let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), + report_init_0.report_share().clone(), + report_init_1.report_share().clone(), + report_init_2.report_share().clone(), ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), prep_state_2.clone(), ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); Box::pin(async move { tx.put_task(&task).await?; @@ -2164,13 +2255,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2179,36 +2270,36 @@ mod tests { )) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_0.id(), - *report_metadata_0.time(), + *report_share_0.metadata().id(), + *report_share_0.metadata().time(), 0, None, ReportAggregationState::Waiting(prep_state_0, None), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_1.id(), - *report_metadata_1.time(), + *report_share_1.metadata().id(), + *report_share_1.metadata().time(), 1, None, ReportAggregationState::Waiting(prep_state_1, None), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_2.id(), - *report_metadata_2.time(), + *report_share_2.metadata().id(), + *report_share_2.metadata().time(), 2, None, ReportAggregationState::Waiting(prep_state_2, None), @@ -2216,7 +2307,7 @@ mod tests { ) .await?; - tx.put_aggregate_share_job::( + tx.put_aggregate_share_job::>( &AggregateShareJob::new( *task.id(), Interval::new( @@ -2224,8 +2315,8 @@ mod tests { *task.time_precision(), ) .unwrap(), - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + 
aggregation_param, + aggregate_share, 0, ReportIdChecksum::default(), ), @@ -2241,16 +2332,20 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_0.get_encoded(), + }, ), PrepareStep::new( *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_2.get_encoded(), + }, ), ]), ); - // Create aggregator handler, send request, and parse response. + // Create aggregator filter, send request, and parse response. let handler = aggregator_handler(Arc::clone(&datastore), clock, default_aggregator_config()).unwrap(); @@ -2261,7 +2356,12 @@ mod tests { assert_eq!( aggregate_resp, AggregationJobResp::new(Vec::from([ - PrepareStep::new(*report_metadata_0.id(), PrepareStepResult::Finished), + PrepareStep::new( + *report_metadata_0.id(), + PrepareStepResult::Finished { + prep_msg: Vec::new() + } + ), PrepareStep::new( *report_metadata_2.id(), PrepareStepResult::Failed(ReportShareError::BatchCollected), @@ -2270,38 +2370,39 @@ mod tests { ); // Validate datastore. - let (aggregation_job, report_aggregations) = - datastore - .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); - Box::pin(async move { - let aggregation_job = tx - .get_aggregation_job::( + let (aggregation_job, report_aggregations) = datastore + .run_tx(|tx| { + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) - .await.unwrap().unwrap(); - let report_aggregations = tx - .get_report_aggregations_for_aggregation_job( - vdaf.as_ref(), - &Role::Helper, - task.id(), - &aggregation_job_id, - ) - .await - .unwrap(); - Ok((aggregation_job, report_aggregations)) - }) + .await + .unwrap() + .unwrap(); + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Helper, + task.id(), + &aggregation_job_id, + ) + .await + .unwrap(); + Ok((aggregation_job, report_aggregations)) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert_eq!( aggregation_job, AggregationJob::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2321,7 +2422,9 @@ mod tests { 0, Some(PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Finished + PrepareStepResult::Finished { + prep_msg: Vec::new() + } )), ReportAggregationState::Finished( transcript_0.output_share(Role::Helper).clone() @@ -2358,7 +2461,7 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -2374,12 +2477,16 @@ mod tests { .unwrap(), ); - let vdaf = Prio3::new_count(2).unwrap(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); - let hpke_key = task.current_hpke_key(); + let vdaf = Poplar1::new_sha3(1); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[false]); - // report_share_0 is a "happy path" report. + // report_init_0 is a "happy path" report. 
let report_metadata_0 = ReportMetadata::new( random(), first_batch_interval_clock @@ -2390,23 +2497,22 @@ mod tests { let transcript_0 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, _) = transcript_0.helper_prep_state(0); + let (prep_state_0, _) = transcript_0.helper_prep_state(1); let out_share_0 = transcript_0.output_share(Role::Helper); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let prep_msg_0 = transcript_0.prepare_messages[1].clone(); + let report_init_0 = generate_helper_report_init::>( *task.id(), report_metadata_0.clone(), - hpke_key.config(), - &transcript_0.public_share, + task.current_hpke_key().config(), + &transcript_0, Vec::new(), - &transcript_0.input_shares[1], ); - // report_share_1 is another "happy path" report to exercise in-memory accumulation of + // report_init_1 is another "happy path" report to exercise in-memory accumulation of // output shares let report_metadata_1 = ReportMetadata::new( random(), @@ -2418,23 +2524,22 @@ mod tests { let transcript_1 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, _) = transcript_1.helper_prep_state(0); + let (prep_state_1, _) = transcript_1.helper_prep_state(1); let out_share_1 = transcript_1.output_share(Role::Helper); - let prep_msg_1 = transcript_1.prepare_messages[0].clone(); - let report_share_1 = generate_helper_report_share::( + let prep_msg_1 = transcript_1.prepare_messages[1].clone(); + let report_init_1 = generate_helper_report_init::>( *task.id(), report_metadata_1.clone(), - hpke_key.config(), - &transcript_1.public_share, + task.current_hpke_key().config(), + &transcript_1, Vec::new(), - &transcript_1.input_shares[1], ); - // report share 2 aggregates successfully, but into a distinct batch aggregation. + // report_init_2 aggregates successfully, but into a distinct batch aggregation. 
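A systematic change across these cases: the helper's prepare state and the prepare message are now read at round 1 of the transcript rather than round 0. That is consistent with the new init flow, in which the leader's round-0 prepare share rides along in ReportPrepInit, so by the time a continue request arrives the helper is already waiting at round 1:

    // As used for each report in these hunks (indices shifted from 0 to 1).
    let (prep_state, _) = transcript.helper_prep_state(1);
    let prep_msg = transcript.prepare_messages[1].clone();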
let report_metadata_2 = ReportMetadata::new( random(), second_batch_interval_clock @@ -2445,40 +2550,34 @@ mod tests { let transcript_2 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, _) = transcript_2.helper_prep_state(0); + let (prep_state_2, _) = transcript_2.helper_prep_state(1); let out_share_2 = transcript_2.output_share(Role::Helper); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let prep_msg_2 = transcript_2.prepare_messages[1].clone(); + let report_init_2 = generate_helper_report_init::>( *task.id(), report_metadata_2.clone(), - hpke_key.config(), - &transcript_2.public_share, + task.current_hpke_key().config(), + &transcript_2, Vec::new(), - &transcript_2.input_shares[1], ); datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param) = (task.clone(), aggregation_param.clone()); let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), + report_init_0.report_share().clone(), + report_init_1.report_share().clone(), + report_init_2.report_share().clone(), ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), prep_state_2.clone(), ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); Box::pin(async move { tx.put_task(&task).await?; @@ -2488,13 +2587,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2504,39 +2603,39 @@ mod tests { .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_0.id(), - *report_metadata_0.time(), + *report_share_0.metadata().id(), + *report_share_0.metadata().time(), 0, None, ReportAggregationState::Waiting(prep_state_0, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_1.id(), - *report_metadata_1.time(), + *report_share_1.metadata().id(), + *report_share_1.metadata().time(), 1, None, ReportAggregationState::Waiting(prep_state_1, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_2.id(), - *report_metadata_2.time(), + *report_share_2.metadata().id(), + *report_share_2.metadata().time(), 2, None, ReportAggregationState::Waiting(prep_state_2, None), @@ -2554,22 +2653,28 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_0.get_encoded(), + }, ), PrepareStep::new( *report_metadata_1.id(), - PrepareStepResult::Continued(prep_msg_1.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_1.get_encoded(), + }, ), PrepareStep::new( *report_metadata_2.id(), - 
PrepareStepResult::Continued(prep_msg_2.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_2.get_encoded(), + }, ), ]), ); - // Create aggregator handler, send request, and parse response. + // Create aggregator filter, send request, and parse response. let handler = aggregator_handler( - Arc::clone(&datastore), + datastore.clone(), first_batch_interval_clock.clone(), default_aggregator_config(), ) @@ -2581,12 +2686,16 @@ mod tests { // Map the batch aggregation ordinal value to 0, as it may vary due to sharding. let batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -2601,7 +2710,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -2610,10 +2719,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -2623,12 +2732,6 @@ mod tests { }) .collect(); - let aggregate_share = vdaf - .aggregate(&(), [out_share_0.clone(), out_share_1.clone()]) - .unwrap(); - let checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) - .updated_with(report_metadata_1.id()); - assert_eq!( batch_aggregations, Vec::from([ @@ -2642,12 +2745,17 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, - aggregate_share, + vdaf.aggregate( + &aggregation_param.clone(), + [out_share_0.clone(), out_share_1.clone()], + ) + .unwrap(), 2, Interval::from_time(report_metadata_0.time()).unwrap(), - checksum, + ReportIdChecksum::for_report_id(report_metadata_0.id()) + .updated_with(report_metadata_1.id()), ), BatchAggregation::new( *task.id(), @@ -2659,9 +2767,10 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, - AggregateShare::from(out_share_2.clone()), + vdaf.aggregate(&aggregation_param, [out_share_2.clone()]) + .unwrap(), 1, Interval::from_time(report_metadata_2.time()).unwrap(), ReportIdChecksum::for_report_id(report_metadata_2.id()), @@ -2671,7 +2780,7 @@ mod tests { // Aggregate some more reports, which should get accumulated into the // batch_aggregations rows created earlier. - // report_share_3 gets aggreated into the first batch interval. + // report_init_3 gets aggregated into the first batch interval. 
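One more consequence of carrying a real aggregation parameter: expected aggregate shares are computed through the VDAF's Aggregator::aggregate rather than converted directly from output shares, since a Poplar1 aggregate share is only meaningful relative to its parameter. This replaces constructions like AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))) seen in the removed lines. As the later assertions in this test do it:

    // Expected first-batch aggregate, computed the same way the aggregator
    // itself would (taken from the assertion hunks below).
    let first_aggregate_share = vdaf
        .aggregate(
            &aggregation_param,
            [out_share_0, out_share_1, out_share_3].into_iter().cloned(),
        )
        .unwrap();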
let report_metadata_3 = ReportMetadata::new( random(), first_batch_interval_clock @@ -2682,23 +2791,22 @@ mod tests { let transcript_3 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_3.id(), - &0, + &measurement, ); - let (prep_state_3, _) = transcript_3.helper_prep_state(0); + let (prep_state_3, _) = transcript_3.helper_prep_state(1); let out_share_3 = transcript_3.output_share(Role::Helper); - let prep_msg_3 = transcript_3.prepare_messages[0].clone(); - let report_share_3 = generate_helper_report_share::( + let prep_msg_3 = transcript_3.prepare_messages[1].clone(); + let report_init_3 = generate_helper_report_init::>( *task.id(), report_metadata_3.clone(), - hpke_key.config(), - &transcript_3.public_share, + task.current_hpke_key().config(), + &transcript_3, Vec::new(), - &transcript_3.input_shares[1], ); - // report_share_4 gets aggregated into the second batch interval + // report_init_4 gets aggregated into the second batch interval. let report_metadata_4 = ReportMetadata::new( random(), second_batch_interval_clock @@ -2709,23 +2817,22 @@ mod tests { let transcript_4 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_4.id(), - &0, + &measurement, ); - let (prep_state_4, _) = transcript_4.helper_prep_state(0); + let (prep_state_4, _) = transcript_4.helper_prep_state(1); let out_share_4 = transcript_4.output_share(Role::Helper); - let prep_msg_4 = transcript_4.prepare_messages[0].clone(); - let report_share_4 = generate_helper_report_share::( + let prep_msg_4 = transcript_4.prepare_messages[1].clone(); + let report_init_4 = generate_helper_report_init::>( *task.id(), report_metadata_4.clone(), - hpke_key.config(), - &transcript_4.public_share, + task.current_hpke_key().config(), + &transcript_4, Vec::new(), - &transcript_4.input_shares[1], ); - // report share 5 also gets aggregated into the second batch interval + // report_init_5 also gets aggregated into the second batch interval. 
let report_metadata_5 = ReportMetadata::new( random(), second_batch_interval_clock @@ -2736,40 +2843,34 @@ mod tests { let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_5.id(), - &0, + &measurement, ); - let (prep_state_5, _) = transcript_5.helper_prep_state(0); + let (prep_state_5, _) = transcript_5.helper_prep_state(1); let out_share_5 = transcript_5.output_share(Role::Helper); - let prep_msg_5 = transcript_5.prepare_messages[0].clone(); - let report_share_5 = generate_helper_report_share::( + let prep_msg_5 = transcript_5.prepare_messages[1].clone(); + let report_init_5 = generate_helper_report_init::>( *task.id(), report_metadata_5.clone(), - hpke_key.config(), - &transcript_5.public_share, + task.current_hpke_key().config(), + &transcript_5, Vec::new(), - &transcript_5.input_shares[1], ); datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param) = (task.clone(), aggregation_param.clone()); let (report_share_3, report_share_4, report_share_5) = ( - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), + report_init_3.report_share().clone(), + report_init_4.report_share().clone(), + report_init_5.report_share().clone(), ); let (prep_state_3, prep_state_4, prep_state_5) = ( prep_state_3.clone(), prep_state_4.clone(), prep_state_5.clone(), ); - let (report_metadata_3, report_metadata_4, report_metadata_5) = ( - report_metadata_3.clone(), - report_metadata_4.clone(), - report_metadata_5.clone(), - ); Box::pin(async move { tx.put_report_share(task.id(), &report_share_3).await?; @@ -2777,13 +2878,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_5).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2793,39 +2894,39 @@ mod tests { .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_3.id(), - *report_metadata_3.time(), + *report_share_3.metadata().id(), + *report_share_3.metadata().time(), 3, None, ReportAggregationState::Waiting(prep_state_3, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_4.id(), - *report_metadata_4.time(), + *report_share_4.metadata().id(), + *report_share_4.metadata().time(), 4, None, ReportAggregationState::Waiting(prep_state_4, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_5.id(), - *report_metadata_5.time(), + *report_share_5.metadata().id(), + *report_share_5.metadata().time(), 5, None, ReportAggregationState::Waiting(prep_state_5, None), @@ -2843,22 +2944,28 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_3.id(), - PrepareStepResult::Continued(prep_msg_3.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_3.get_encoded(), + }, ), PrepareStep::new( *report_metadata_4.id(), - PrepareStepResult::Continued(prep_msg_4.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_4.get_encoded(), + }, ), PrepareStep::new( *report_metadata_5.id(), - 
PrepareStepResult::Continued(prep_msg_5.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_5.get_encoded(), + }, ), ]), ); - // Create aggregator handler, send request, and parse response. + // Create aggregator filter, send request, and parse response. let handler = aggregator_handler( - Arc::clone(&datastore), + datastore.clone(), first_batch_interval_clock, default_aggregator_config(), ) @@ -2872,12 +2979,16 @@ mod tests { // be the same) let mut batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -2892,7 +3003,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -2901,10 +3012,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -2920,7 +3031,7 @@ mod tests { let first_aggregate_share = vdaf .aggregate( - &(), + &aggregation_param, [out_share_0, out_share_1, out_share_3].into_iter().cloned(), ) .unwrap(); @@ -2930,7 +3041,7 @@ mod tests { let second_aggregate_share = vdaf .aggregate( - &(), + &aggregation_param, [out_share_2, out_share_4, out_share_5].into_iter().cloned(), ) .unwrap(); @@ -2951,7 +3062,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, first_aggregate_share, 3, @@ -2968,7 +3079,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param, 0, second_aggregate_share, 3, @@ -2980,7 +3091,7 @@ mod tests { } #[tokio::test] - async fn aggregate_continue_leader_sends_non_continue_transition() { + async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { // Prepare datastore & request. 
install_test_trace_subscriber(); @@ -3016,25 +3127,23 @@ mod tests { ) .await?; - tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + aggregation_job_id, + dummy_vdaf::AggregationParam(0), + (), + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata.id(), @@ -3054,7 +3163,7 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report_metadata.id(), - PrepareStepResult::Finished, + PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), )]), ); @@ -3114,25 +3223,23 @@ mod tests { ), ) .await?; - tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + aggregation_job_id, + dummy_vdaf::AggregationParam(0), + (), + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata.id(), @@ -3152,7 +3259,10 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report_metadata.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3175,11 +3285,13 @@ mod tests { let (task, report_metadata) = (task.clone(), report_metadata.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( task.id(), &aggregation_job_id, ) - .await.unwrap().unwrap(); + .await + .unwrap() + .unwrap(); let report_aggregation = tx .get_report_aggregation( &dummy_vdaf::Vdaf::default(), @@ -3188,7 +3300,9 @@ mod tests { &aggregation_job_id, report_metadata.id(), ) - .await.unwrap().unwrap(); + .await + .unwrap() + .unwrap(); Ok((aggregation_job, report_aggregation)) }) }) @@ -3263,25 +3377,23 @@ mod tests { ), ) .await?; - tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + aggregation_job_id, + 
dummy_vdaf::AggregationParam(0), + (), + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata.id(), @@ -3303,7 +3415,10 @@ mod tests { ReportId::from( [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], // not the same as above ), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -3383,26 +3498,24 @@ mod tests { ) .await?; - tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + aggregation_job_id, + dummy_vdaf::AggregationParam(0), + (), + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata_0.id(), @@ -3412,10 +3525,7 @@ mod tests { ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata_1.id(), @@ -3437,11 +3547,17 @@ mod tests { // Report IDs are in opposite order to what was stored in the datastore. 
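Both out-of-order prepare steps below use the two-field form of Continued. Collecting the variants as they appear across this patch (the actual janus_messages definition is not in the diff), the revised result type presumably looks like:

    // Sketch inferred from usage throughout this patch.
    pub enum PrepareStepResult {
        // Replaces Continued(Vec<u8>): carries the prepare message that
        // resolves the previous round plus the sender's next prepare share.
        Continued {
            prep_msg: Vec<u8>,
            prep_share: Vec<u8>,
        },
        // Replaces the unit Finished variant: terminal success, with a final
        // (possibly empty) prepare message for the peer.
        Finished {
            prep_msg: Vec<u8>,
        },
        // Unchanged: per-report failure reason.
        Failed(ReportShareError),
    }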
PrepareStep::new( *report_metadata_1.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), ]), ); @@ -3498,25 +3614,23 @@ mod tests { ), ) .await?; - tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + aggregation_job_id, + dummy_vdaf::AggregationParam(0), + (), + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - dummy_vdaf::Vdaf, - >::new( + tx.put_report_aggregation(&ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), aggregation_job_id, *report_metadata.id(), @@ -3536,7 +3650,10 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -4094,13 +4211,12 @@ mod tests { ) .unwrap() ); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), @@ -4118,7 +4234,7 @@ mod tests { test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), @@ -4170,21 +4286,22 @@ mod tests { .run_tx(|tx| { let task = test_case.task.clone(); Box::pin(async move { - tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()) + tx.put_batch_aggregation( + &BatchAggregation::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + Interval::new( + Time::from_seconds_since_epoch(0), + *task.time_precision(), + ) .unwrap(), - dummy_vdaf::AggregationParam(0), - 0, - dummy_vdaf::AggregateShare(0), - 10, - interval, - ReportIdChecksum::get_decoded(&[2; 32]).unwrap(), - )) + dummy_vdaf::AggregationParam(0), + 0, + dummy_vdaf::AggregateShare(0), + 10, + interval, + ReportIdChecksum::get_decoded(&[2; 32]).unwrap(), + ), + ) .await }) }) @@ -4245,20 +4362,18 @@ mod tests { .run_tx(|tx| { let task = test_case.task.clone(); Box::pin(async move { - 
tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, - TimeInterval, - dummy_vdaf::Vdaf, - >::new( - *task.id(), - interval, - dummy_vdaf::AggregationParam(0), - 0, - dummy_vdaf::AggregateShare(0), - 10, - interval, - ReportIdChecksum::get_decoded(&[2; 32]).unwrap(), - )) + tx.put_batch_aggregation( + &BatchAggregation::<0, TimeInterval, dummy_vdaf::Vdaf>::new( + *task.id(), + interval, + dummy_vdaf::AggregationParam(0), + 0, + dummy_vdaf::AggregateShare(0), + 10, + interval, + ReportIdChecksum::get_decoded(&[2; 32]).unwrap(), + ), + ) .await }) }) @@ -4618,7 +4733,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + 0, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -4639,7 +4754,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + 0, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -4660,7 +4775,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + 0, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -4681,7 +4796,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + 0, TimeInterval, dummy_vdaf::Vdaf, >::new( diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 0a8a3a162..f1368f899 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -779,9 +779,8 @@ mod tests { // YAML contains no task ID, VDAF verify keys, aggregator auth tokens, collector auth tokens // or HPKE keys. let serialized_task_yaml = r#" -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 @@ -802,9 +801,8 @@ mod tests { aggregator_auth_tokens: [] collector_auth_tokens: [] hpke_keys: [] -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index 737922810..8a5f7ca78 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -127,7 +127,8 @@ async fn post_task( let task = Arc::new( Task::new( /* task_id */ random(), - /* aggregator_endpoints */ req.aggregator_endpoints, + /* leader_aggregator_endpoint */ req.leader_aggregator_endpoint, + /* helper_aggregator_endpoint */ req.helper_aggregator_endpoint, /* query_type */ req.query_type, /* vdaf */ req.vdaf, /* role */ req.role, @@ -243,7 +244,8 @@ mod models { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct PostTaskReq { - pub(crate) aggregator_endpoints: Vec, + pub(crate) leader_aggregator_endpoint: Url, + pub(crate) helper_aggregator_endpoint: Url, pub(crate) query_type: QueryType, pub(crate) vdaf: VdafInstance, pub(crate) role: Role, @@ -257,7 +259,8 @@ mod models { #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct TaskResp { pub(crate) task_id: TaskId, - pub(crate) aggregator_endpoints: Vec, + pub(crate) leader_aggregator_endpoint: Url, + pub(crate) helper_aggregator_endpoint: Url, pub(crate) query_type: QueryType, pub(crate) vdaf: VdafInstance, pub(crate) role: Role, @@ -299,7 +302,8 @@ mod models { Self { task_id: *task.id(), - aggregator_endpoints: task.aggregator_endpoints().to_vec(), + leader_aggregator_endpoint: task.leader_aggregator_endpoint().clone(), + 
helper_aggregator_endpoint: task.helper_aggregator_endpoint().clone(), query_type: *task.query_type(), vdaf: task.vdaf().clone(), role: *task.role(), @@ -501,10 +505,8 @@ mod tests { // Verify: posting a task creates a new task which matches the request. let req = PostTaskReq { - aggregator_endpoints: Vec::from([ - "http://leader.endpoint".try_into().unwrap(), - "http://helper.endpoint".try_into().unwrap(), - ]), + leader_aggregator_endpoint: "http://leader.endpoint".try_into().unwrap(), + helper_aggregator_endpoint: "http://helper.endpoint".try_into().unwrap(), query_type: QueryType::TimeInterval, vdaf: VdafInstance::Prio3Count, role: Role::Leader, @@ -547,7 +549,14 @@ mod tests { .expect("task was not created"); // Verify that the task written to the datastore matches the request... - assert_eq!(&req.aggregator_endpoints, got_task.aggregator_endpoints()); + assert_eq!( + &req.leader_aggregator_endpoint, + got_task.leader_aggregator_endpoint() + ); + assert_eq!( + &req.helper_aggregator_endpoint, + got_task.helper_aggregator_endpoint() + ); assert_eq!(&req.query_type, got_task.query_type()); assert_eq!(&req.vdaf, got_task.vdaf()); assert_eq!(&req.role, got_task.role()); @@ -856,10 +865,8 @@ mod tests { fn post_task_req_serialization() { assert_tokens( &PostTaskReq { - aggregator_endpoints: Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + leader_aggregator_endpoint: "https://example.com/".parse().unwrap(), + helper_aggregator_endpoint: "https://example.net/".parse().unwrap(), query_type: QueryType::FixedSize { max_batch_size: 999, }, @@ -880,13 +887,12 @@ mod tests { &[ Token::Struct { name: "PostTaskReq", - len: 9, + len: 10, }, - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::StructVariant { name: "QueryType", @@ -955,10 +961,8 @@ mod tests { fn task_resp_serialization() { let task = Task::new( TaskId::from([0u8; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::FixedSize { max_batch_size: 999, }, @@ -1001,15 +1005,14 @@ mod tests { &[ Token::Struct { name: "TaskResp", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::StructVariant { name: "QueryType", diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 03c237d98..1e3e7b6fd 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -53,7 +53,6 @@ use std::{ use tokio::try_join; use tokio_postgres::{error::SqlState, row::RowIndex, IsolationLevel, Row, Statement, ToStatement}; use tracing::error; -use url::Url; #[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] @@ -315,20 +314,15 @@ impl Transaction<'_, C> { /// Writes a task into the datastore. 
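For API consumers, the net effect of the `PostTaskReq`/`TaskResp` rework is that the endpoint list becomes two named URL fields. A quick serde round-trip against a hypothetical mirror struct (field subset only; assumes `serde_json` and the `url` crate's `serde` feature):

```rust
// Hypothetical mirror of just the reworked endpoint fields; the real request
// and response types carry many more fields.
use serde::{Deserialize, Serialize};
use url::Url;

#[derive(Debug, Serialize, Deserialize)]
struct EndpointFields {
    leader_aggregator_endpoint: Url,
    helper_aggregator_endpoint: Url,
}

fn main() {
    let json = r#"{
        "leader_aggregator_endpoint": "https://example.com/",
        "helper_aggregator_endpoint": "https://example.net/"
    }"#;
    let fields: EndpointFields = serde_json::from_str(json).unwrap();
    assert_eq!(fields.leader_aggregator_endpoint.as_str(), "https://example.com/");
    assert_eq!(fields.helper_aggregator_endpoint.as_str(), "https://example.net/");
}
```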
#[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { - let endpoints: Vec<_> = task - .aggregator_endpoints() - .iter() - .map(Url::as_str) - .collect(); - // Main task insert. let stmt = self .prepare_cached( "INSERT INTO tasks ( - task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)", + task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", ) .await?; self.execute( @@ -336,7 +330,8 @@ impl Transaction<'_, C> { &[ /* task_id */ &task.id().as_ref(), /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, - /* aggregator_endpoints */ &endpoints, + /* leader_aggregator_endpoint */ &task.leader_aggregator_endpoint().as_str(), + /* helper_aggregator_endpoint */ &task.helper_aggregator_endpoint().as_str(), /* query_type */ &Json(task.query_type()), /* vdaf */ &Json(task.vdaf()), /* max_batch_query_count */ @@ -519,9 +514,9 @@ impl Transaction<'_, C> { let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()]; let stmt = self .prepare_cached( - "SELECT aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, + query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, + min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config FROM tasks WHERE task_id = $1", ) .await?; @@ -591,9 +586,10 @@ impl Transaction<'_, C> { pub async fn get_tasks(&self) -> Result, Error> { let stmt = self .prepare_cached( - "SELECT task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config FROM tasks", ) .await?; @@ -728,11 +724,10 @@ impl Transaction<'_, C> { ) -> Result { // Scalar task parameters. 
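The rewritten statements above assume a schema change that this diff does not show. A hypothetical migration sketch, kept as a Rust string in the style of the embedded queries (column names come from the new statements; the backfill from the old array column is a guess):

```rust
// Hypothetical migration (NOT part of this diff): splits the old array column
// into the two columns the rewritten statements reference.
const SPLIT_ENDPOINTS_MIGRATION: &str = "
    ALTER TABLE tasks
        ADD COLUMN leader_aggregator_endpoint TEXT,
        ADD COLUMN helper_aggregator_endpoint TEXT;
    -- Backfill assuming the old column stored [leader, helper] in order:
    UPDATE tasks SET
        leader_aggregator_endpoint = aggregator_endpoints[1],
        helper_aggregator_endpoint = aggregator_endpoints[2];
    ALTER TABLE tasks
        ALTER COLUMN leader_aggregator_endpoint SET NOT NULL,
        ALTER COLUMN helper_aggregator_endpoint SET NOT NULL,
        DROP COLUMN aggregator_endpoints;
";
```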
let aggregator_role: AggregatorRole = row.get("aggregator_role"); - let endpoints = row - .get::<_, Vec>("aggregator_endpoints") - .into_iter() - .map(|endpoint| Ok(Url::parse(&endpoint)?)) - .collect::>()?; + let leader_aggregator_endpoint = + row.get::<_, String>("leader_aggregator_endpoint").parse()?; + let helper_aggregator_endpoint = + row.get::<_, String>("helper_aggregator_endpoint").parse()?; let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?; @@ -817,7 +812,8 @@ impl Transaction<'_, C> { Ok(Task::new( *task_id, - endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, aggregator_role.as_role(), @@ -1930,13 +1926,12 @@ impl Transaction<'_, C> { let prep_msg = prep_msg_bytes .map(|bytes| A::PrepareMessage::get_decoded_with_param(&prep_state, &bytes)) .transpose()?; - ReportAggregationState::Waiting(prep_state, prep_msg) } ReportAggregationStateCode::Finished => { let aggregation_param = A::AggregationParam::get_decoded(aggregation_param_bytes)?; - ReportAggregationState::Finished(A::OutputShare::get_decoded_with_param( + let output_share = A::OutputShare::get_decoded_with_param( &(vdaf, &aggregation_param), &out_share_bytes.ok_or_else(|| { Error::DbState( @@ -1944,7 +1939,8 @@ impl Transaction<'_, C> { .to_string(), ) })?, - )?) + )?; + ReportAggregationState::Finished(output_share) } ReportAggregationStateCode::Failed => { @@ -4546,9 +4542,9 @@ pub mod models { pub fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::Start => ReportAggregationStateCode::Start, - ReportAggregationState::Waiting(_, _) => ReportAggregationStateCode::Waiting, - ReportAggregationState::Finished(_) => ReportAggregationStateCode::Finished, - ReportAggregationState::Failed(_) => ReportAggregationStateCode::Failed, + ReportAggregationState::Waiting { .. } => ReportAggregationStateCode::Waiting, + ReportAggregationState::Finished { .. } => ReportAggregationStateCode::Finished, + ReportAggregationState::Failed { .. } => ReportAggregationStateCode::Failed, ReportAggregationState::Invalid => ReportAggregationStateCode::Invalid, } } @@ -5357,7 +5353,7 @@ mod tests { use futures::future::try_join_all; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{ dummy_vdaf::{self, AggregateShare, AggregationParam}, install_test_trace_subscriber, run_vdaf, @@ -6769,7 +6765,7 @@ mod tests { tx.put_task(&task).await?; for aggregation_job_id in aggregation_job_ids { tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -6786,20 +6782,18 @@ mod tests { } // Write an aggregation job that is finished. We don't want to retrieve this one. 
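One small Rust idiom in the `state_code` rewrite above is worth calling out: `Waiting { .. }` and friends are braced rest-patterns, which match tuple variants as well as struct variants, whatever their arity. A standalone illustration:

```rust
// Braced rest-patterns match both tuple and struct variants, so these arms
// keep compiling even if a variant gains or loses fields.
enum State {
    Start,
    Waiting(u8, Option<u8>), // tuple variant
    Failed { error: u8 },    // struct variant
}

fn state_code(state: &State) -> u8 {
    match state {
        State::Start => 0,
        State::Waiting { .. } => 1, // matches the tuple variant too
        State::Failed { .. } => 2,
    }
}

fn main() {
    assert_eq!(state_code(&State::Waiting(7, None)), 1);
    assert_eq!(state_code(&State::Failed { error: 1 }), 2);
}
```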
- tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(1), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ), + ) .await?; // Write an aggregation job for a task that we are taking on the helper role for. @@ -6811,20 +6805,18 @@ mod tests { ) .build(); tx.put_task(&helper_task).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *helper_task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *helper_task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await }) }) @@ -7037,7 +7029,7 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( &random(), &random(), ) @@ -7051,7 +7043,7 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.update_aggregation_job::( + tx.update_aggregation_job::( &AggregationJob::new( random(), random(), @@ -7182,12 +7174,12 @@ mod tests { let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LEN] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - let leader_prep_state = vdaf_transcript.leader_prep_state(0); + let (leader_prep_state, _) = vdaf_transcript.leader_prep_state(0); for (ord, state) in [ - ReportAggregationState::::Start, + ReportAggregationState::::Start, ReportAggregationState::Waiting( leader_prep_state.clone(), Some(vdaf_transcript.prepare_messages[0].clone()), @@ -7200,27 +7192,30 @@ mod tests { .into_iter() .enumerate() { - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - Role::Leader, - ) - .build(); + let task_id = random(); let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); - let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let want_report_aggregation = ds .run_tx(|tx| { - let (task, state) = (task.clone(), state.clone()); + let state = state.clone(); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_task( + &TaskBuilder::new( + task::QueryType::TimeInterval, + VdafInstance::Prio3Count, + Role::Leader, + ) + .with_id(task_id) + .build(), + ) + .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( - *task.id(), + task_id, aggregation_job_id, (), (), @@ -7234,7 +7229,7 @@ mod tests { )) .await?; tx.put_report_share( - task.id(), + &task_id, &ReportShare::new( ReportMetadata::new(report_id, time), Vec::from("public_share"), @@ -7248,14 +7243,17 @@ mod tests { .await?; let report_aggregation = 
ReportAggregation::new( - *task.id(), + task_id, aggregation_job_id, report_id, time, ord.try_into().unwrap(), Some(PrepareStep::new( report_id, - PrepareStepResult::Continued(format!("prep_msg_{ord}").into()), + PrepareStepResult::Continued { + prep_msg: format!("prep_msg_{ord}").into(), + prep_share: format!("prep_share_{ord}").into(), + }, )), state, ); @@ -7268,12 +7266,12 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), &Role::Leader, - task.id(), + &task_id, &aggregation_job_id, &report_id, ) @@ -7307,12 +7305,12 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), &Role::Leader, - task.id(), + &task_id, &aggregation_job_id, &report_id, ) @@ -7503,7 +7501,7 @@ mod tests { let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LEN] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); let task = TaskBuilder::new( @@ -7520,13 +7518,13 @@ mod tests { let (task, prep_msg, prep_state, output_share) = ( task.clone(), vdaf_transcript.prepare_messages[0].clone(), - vdaf_transcript.leader_prep_state(0).clone(), + vdaf_transcript.leader_prep_state(0).0.clone(), vdaf_transcript.output_share(Role::Leader).clone(), ); Box::pin(async move { tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -7543,8 +7541,10 @@ mod tests { let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::::Start, + ReportAggregationState::::Start, ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), + ReportAggregationState::Waiting(prep_state.clone(), None), + ReportAggregationState::Finished(output_share.clone()), ReportAggregationState::Finished(output_share), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ReportAggregationState::Invalid, @@ -7573,7 +7573,12 @@ mod tests { report_id, time, ord.try_into().unwrap(), - Some(PrepareStep::new(report_id, PrepareStepResult::Finished)), + Some(PrepareStep::new( + report_id, + PrepareStepResult::Finished { + prep_msg: format!("prep_msg_{ord}").into(), + }, + )), state.clone(), ); tx.put_report_aggregation(&report_aggregation).await?; @@ -9444,7 +9449,7 @@ mod tests { clock.now(), 0, None, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), // Counted among min_size and max_size. ); let report_aggregation_1_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), @@ -9453,7 +9458,7 @@ mod tests { clock.now(), 1, None, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), // Counted among min_size and max_size. 
); let report_aggregation_1_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index e9b129492..ff9bb04c2 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -14,11 +14,7 @@ use janus_messages::{ }; use rand::{distributions::Standard, random, thread_rng, Rng}; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; -use std::{ - array::TryFromSliceError, - collections::HashMap, - fmt::{self, Formatter}, -}; +use std::{array::TryFromSliceError, collections::HashMap}; use url::Url; /// Errors that methods and functions in this module may return. @@ -77,10 +73,12 @@ impl TryFrom<&SecretBytes> for VerifyKey { pub struct Task { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The query type this task uses to generate batches. query_type: QueryType, /// The VDAF this task executes. @@ -122,7 +120,8 @@ impl Task { #[allow(clippy::too_many_arguments)] pub fn new>( task_id: TaskId, - mut aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -138,13 +137,6 @@ impl Task { collector_auth_tokens: Vec, hpke_keys: I, ) -> Result { - // Ensure provided aggregator endpoints end with a slash, as we will be joining additional - // path segments into these endpoints & the Url::join implementation is persnickety about - // the slash at the end of the path. - for url in &mut aggregator_endpoints { - url_ensure_trailing_slash(url); - } - // Compute hpke_configs mapping cfg.id -> (cfg, key). let hpke_keys: HashMap = hpke_keys .into_iter() @@ -153,7 +145,8 @@ impl Task { let task = Self { task_id, - aggregator_endpoints, + leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint), + helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint), query_type, vdaf, role, @@ -174,10 +167,6 @@ impl Task { } fn validate(&self) -> Result<(), Error> { - // DAP currently only supports configurations of exactly two aggregators. - if self.aggregator_endpoints.len() != 2 { - return Err(Error::InvalidParameter("aggregator_endpoints")); - } if !self.role.is_aggregator() { return Err(Error::InvalidParameter("role")); } @@ -203,9 +192,14 @@ impl Task { &self.task_id } - /// Retrieves the aggregator endpoints associated with this task in natural order. - pub fn aggregator_endpoints(&self) -> &[Url] { - &self.aggregator_endpoints + /// Retrieves the Leader's aggregator endpoint associated with this task. + pub fn leader_aggregator_endpoint(&self) -> &Url { + &self.leader_aggregator_endpoint + } + + /// Retrieves the Helper's aggregator endpoint associated with this task. + pub fn helper_aggregator_endpoint(&self) -> &Url { + &self.helper_aggregator_endpoint } /// Retrieves the query type associated with this task. @@ -303,12 +297,6 @@ impl Task { } } - /// Returns the [`Url`] relative to which the server performing `role` serves its API. 
- pub fn aggregator_url(&self, role: &Role) -> Result<&Url, Error> { - let index = role.index().ok_or(Error::InvalidParameter(role.as_str()))?; - Ok(&self.aggregator_endpoints[index]) - } - /// Returns the [`AuthenticationToken`] currently used by this aggregator to authenticate itself /// to other aggregators. pub fn primary_aggregator_auth_token(&self) -> &AuthenticationToken { @@ -363,14 +351,14 @@ impl Task { /// Returns the URI at which reports may be uploaded for this task. pub fn report_upload_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Leader)? + .leader_aggregator_endpoint() .join(&format!("{}/reports", self.tasks_path()))?) } /// Returns the URI at which the helper resource for the specified aggregation job ID can be /// accessed. pub fn aggregation_job_uri(&self, aggregation_job_id: &AggregationJobId) -> Result { - Ok(self.aggregator_url(&Role::Helper)?.join(&format!( + Ok(self.helper_aggregator_endpoint().join(&format!( "{}/aggregation_jobs/{aggregation_job_id}", self.tasks_path() ))?) @@ -379,34 +367,27 @@ impl Task { /// Returns the URI at which the helper aggregate shares resource can be accessed. pub fn aggregate_shares_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Helper)? + .helper_aggregator_endpoint() .join(&format!("{}/aggregate_shares", self.tasks_path()))?) } /// Returns the URI at which the leader resource for the specified collection job ID can be /// accessed. pub fn collection_job_uri(&self, collection_job_id: &CollectionJobId) -> Result { - Ok(self.aggregator_url(&Role::Leader)?.join(&format!( + Ok(self.leader_aggregator_endpoint().join(&format!( "{}/collection_jobs/{collection_job_id}", self.tasks_path() ))?) } } -fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for url in urls { - list.entry(&format!("{url}")); - } - list.finish() -} - /// SerializedTask is an intermediate representation for tasks being serialized via the Serialize & /// Deserialize traits. #[derive(Clone, Serialize, Deserialize)] pub struct SerializedTask { task_id: Option, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -499,7 +480,8 @@ impl Serialize for Task { SerializedTask { task_id: Some(self.task_id), - aggregator_endpoints: self.aggregator_endpoints.clone(), + leader_aggregator_endpoint: self.leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint: self.helper_aggregator_endpoint.clone(), query_type: self.query_type, vdaf: self.vdaf.clone(), role: self.role, @@ -567,7 +549,8 @@ impl TryFrom for Task { Task::new( task_id, - serialized_task.aggregator_endpoints, + serialized_task.leader_aggregator_endpoint, + serialized_task.helper_aggregator_endpoint, serialized_task.query_type, serialized_task.vdaf, serialized_task.role, @@ -594,7 +577,6 @@ impl<'de> Deserialize<'de> for Task { } } -// This is public to allow use in integration tests. 
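The endpoint plumbing above also changes `url_ensure_trailing_slash` from mutating through `&mut` to taking and returning the `Url` by value, judging by the call sites (its definition is not in this diff). An assumed sketch, consistent with the `aggregator_endpoints_end_in_slash` tests below:

```rust
// Assumed shape of the reworked helper; only its call sites appear in this
// diff. Behavior matches the trailing-slash tests.
use url::Url;

fn url_ensure_trailing_slash(mut url: Url) -> Url {
    if !url.path().ends_with('/') {
        url.set_path(&format!("{}/", url.path()));
    }
    url
}

fn main() {
    let url = url_ensure_trailing_slash("http://leader_endpoint/foo/bar".parse().unwrap());
    assert_eq!(url.as_str(), "http://leader_endpoint/foo/bar/");
}
```

Taking the `Url` by value lets the struct initializers above normalize each endpoint inline, where the old loop needed a `mut` binding over the whole endpoint vector.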
#[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub mod test_util { @@ -604,7 +586,7 @@ pub mod test_util { }; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair}, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LEN}, time::DurationExt, }; use janus_messages::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId, Time}; @@ -620,7 +602,7 @@ pub mod test_util { // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, + _ => VERIFY_KEY_LEN, } } @@ -664,10 +646,8 @@ pub mod test_util { Self( Task::new( task_id, - Vec::from([ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), query_type, vdaf, role, @@ -692,17 +672,20 @@ pub mod test_util { Self(Task { task_id, ..self.0 }) } - /// Associates the eventual task with the given aggregator endpoints. - pub fn with_aggregator_endpoints(self, aggregator_endpoints: Vec) -> Self { + /// Associates the eventual task with the given aggregator endpoint for the Leader. + pub fn with_leader_aggregator_endpoint(self, leader_aggregator_endpoint: Url) -> Self { Self(Task { - aggregator_endpoints, + leader_aggregator_endpoint, ..self.0 }) } - /// Retrieves the aggregator endpoints associated with this task builder. - pub fn aggregator_endpoints(&self) -> &[Url] { - self.0.aggregator_endpoints() + /// Associates the eventual task with the given aggregator endpoint for the Helper. + pub fn with_helper_aggregator_endpoint(self, helper_aggregator_endpoint: Url) -> Self { + Self(Task { + helper_aggregator_endpoint, + ..self.0 + }) } /// Associates the eventual task with the given aggregator role. @@ -800,6 +783,12 @@ pub mod test_util { }) } + /// Returns a view of the task that would currently be built. + pub fn task(&self) -> &Task { + self.0.validate().unwrap(); + &self.0 + } + /// Consumes this task builder & produces a [`Task`] with the given specifications. pub fn build(self) -> Task { self.0.validate().unwrap(); @@ -818,7 +807,7 @@ mod tests { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VERIFY_KEY_LEN}, test_util::roundtrip_encoding, time::DurationExt, }; @@ -828,7 +817,6 @@ mod tests { }; use rand::random; use serde_test::{assert_tokens, Token}; - use url::Url; #[test] fn task_serialization() { @@ -852,14 +840,12 @@ mod tests { // As leader, we receive an error if no collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -876,14 +862,12 @@ mod tests { // As leader, we receive no error if a collector auth token is specified. 
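Putting the new builder methods together, a task with custom endpoints is now built per-role rather than from a `Vec`. A sketch (import paths are assumptions; they presume `janus_aggregator_core` built with its `test-util` feature, matching the tests below):

```rust
// Sketch of the reworked builder API; paths assume janus_aggregator_core
// with the `test-util` feature enabled.
use janus_aggregator_core::task::{test_util::TaskBuilder, QueryType, Task};
use janus_core::task::VdafInstance;
use janus_messages::Role;

fn example_task() -> Task {
    TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader)
        .with_leader_aggregator_endpoint("https://leader.com/prefix/".parse().unwrap())
        .with_helper_aggregator_endpoint("https://helper.com/prefix/".parse().unwrap())
        .build()
}
```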
Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -900,14 +884,12 @@ mod tests { // As helper, we receive no error if no collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -924,14 +906,12 @@ mod tests { // As helper, we receive an error if a collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -950,14 +930,12 @@ mod tests { fn aggregator_endpoints_end_in_slash() { let task = Task::new( random(), - Vec::from([ - "http://leader_endpoint/foo/bar".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint/foo/bar".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -972,11 +950,12 @@ mod tests { .unwrap(); assert_eq!( - task.aggregator_endpoints, - Vec::from([ - "http://leader_endpoint/foo/bar/".parse().unwrap(), - "http://helper_endpoint/".parse().unwrap() - ]) + task.leader_aggregator_endpoint, + "http://leader_endpoint/foo/bar/".parse().unwrap(), + ); + assert_eq!( + task.helper_aggregator_endpoint, + "http://helper_endpoint/".parse().unwrap(), ); } @@ -999,10 +978,8 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("https://leader.com/prefix/").unwrap(), - Url::parse("https://helper.com/prefix/").unwrap(), - ])) + .with_leader_aggregator_endpoint("https://leader.com/prefix/".parse().unwrap()) + .with_helper_aggregator_endpoint("https://helper.com/prefix/".parse().unwrap()) .build(), ), ] { @@ -1030,10 +1007,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([0; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -1068,16 +1043,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - 
Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::UnitVariant { name: "QueryType", @@ -1193,10 +1167,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([255; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::FixedSize { max_batch_size: 10 }, VdafInstance::Prio3CountVec { length: 8 }, Role::Helper, @@ -1231,16 +1203,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("__________________________________________8"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::StructVariant { name: "QueryType", @@ -1370,10 +1341,8 @@ mod tests { let bad_agg_auth_token = SerializedTask { task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), + leader_aggregator_endpoint: "https://www.example.com/".parse().unwrap(), + helper_aggregator_endpoint: "https://www.example.net/".parse().unwrap(), query_type: QueryType::TimeInterval, vdaf: VdafInstance::Prio3Count, role: Role::Helper, @@ -1396,10 +1365,8 @@ mod tests { let bad_collector_auth_token = SerializedTask { task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), + leader_aggregator_endpoint: "https://www.example.com/".parse().unwrap(), + helper_aggregator_endpoint: "https://www.example.net/".parse().unwrap(), query_type: QueryType::TimeInterval, vdaf: VdafInstance::Prio3Count, role: Role::Leader, diff --git a/client/Cargo.toml b/client/Cargo.toml index ba99d7999..b51846263 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,6 +14,7 @@ backoff = { version = "0.4.0", features = ["tokio"] } derivative = "2.2.0" http = "0.2.9" http-api-problem = "0.56.0" +itertools.workspace = true janus_core.workspace = true janus_messages.workspace = true prio.workspace = true diff --git a/client/src/lib.rs b/client/src/lib.rs index 20d820c2e..eb508a5be 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -4,6 +4,7 @@ use backoff::ExponentialBackoff; use derivative::Derivative; use http::header::CONTENT_TYPE; use http_api_problem::HttpApiProblem; +use itertools::Itertools; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, http::response_to_problem_details, @@ -12,18 +13,15 @@ use janus_core::{ time::{Clock, TimeExt}, }; use janus_messages::{ - Duration, HpkeCiphertext, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, - Report, ReportId, ReportMetadata, Role, TaskId, + Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, + ReportMetadata, Role, TaskId, }; use prio::{ codec::{Decode, Encode}, vdaf, }; use rand::random; -use std::{ - fmt::{self, Formatter}, - io::Cursor, -}; +use std::io::Cursor; use url::Url; #[derive(Debug, thiserror::Error)] @@ -60,10 +58,12 @@ static CLIENT_USER_AGENT: &str = 
concat!( pub struct ClientParameters { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The time precision of the task. This value is shared by all parties in the protocol, and is /// used to compute report timestamps. time_precision: Duration, @@ -73,10 +73,16 @@ pub struct ClientParameters { impl ClientParameters { /// Creates a new set of client task parameters. - pub fn new(task_id: TaskId, aggregator_endpoints: Vec, time_precision: Duration) -> Self { + pub fn new( + task_id: TaskId, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, + time_precision: Duration, + ) -> Self { Self::new_with_backoff( task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, time_precision, http_request_exponential_backoff(), ) @@ -85,35 +91,32 @@ impl ClientParameters { /// Creates a new set of client task parameters with non-default HTTP request retry parameters. pub fn new_with_backoff( task_id: TaskId, - mut aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, time_precision: Duration, http_request_retry_parameters: ExponentialBackoff, ) -> Self { - // Ensure provided aggregator endpoints end with a slash, as we will be joining additional - // path segments into these endpoints & the Url::join implementation is persnickety about - // the slash at the end of the path. - for url in &mut aggregator_endpoints { - url_ensure_trailing_slash(url); - } - Self { task_id, - aggregator_endpoints, + leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint), + helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint), time_precision, http_request_retry_parameters, } } - /// The URL relative to which the API endpoints for the aggregator may be - /// found, if the role is an aggregator, or an error otherwise. + /// The URL relative to which the API endpoints for the aggregator may be found, if the role is + /// an aggregator, or an error otherwise. fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> { - Ok(&self.aggregator_endpoints[role - .index() - .ok_or(Error::InvalidParameter("role is not an aggregator"))?]) + match role { + Role::Leader => Ok(&self.leader_aggregator_endpoint), + Role::Helper => Ok(&self.helper_aggregator_endpoint), + _ => Err(Error::InvalidParameter("role is not an aggregator")), + } } - /// URL from which the HPKE configuration for the server filling `role` may - /// be fetched per draft-gpew-priv-ppm §4.3.1 + /// URL from which the HPKE configuration for the server filling `role` may be fetched per + /// draft-gpew-priv-ppm §4.3.1 fn hpke_config_endpoint(&self, role: &Role) -> Result { Ok(self.aggregator_endpoint(role)?.join("hpke_config")?) } @@ -121,21 +124,13 @@ impl ClientParameters { // URI to which reports may be uploaded for the provided task. fn reports_resource_uri(&self, task_id: &TaskId) -> Result { Ok(self - .aggregator_endpoint(&Role::Leader)? 
+ .leader_aggregator_endpoint .join(&format!("tasks/{task_id}/reports"))?) } } -fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for url in urls { - list.entry(&format!("{url}")); - } - list.finish() -} - -/// Fetches HPKE configuration from the specified aggregator using the -/// aggregator endpoints in the provided [`ClientParameters`]. +/// Fetches HPKE configuration from the specified aggregator using the aggregator endpoints in the +/// provided [`ClientParameters`]. #[tracing::instrument(err)] pub async fn aggregator_hpke_config( client_parameters: &ClientParameters, @@ -228,14 +223,14 @@ impl, C: Clock> Client { let report_metadata = ReportMetadata::new(report_id, time); let encoded_public_share = public_share.get_encoded(); - let encrypted_input_shares: Vec = [ + let (leader_encrypted_input_share, helper_encrypted_input_share) = [ (&self.leader_hpke_config, &Role::Leader), (&self.helper_hpke_config, &Role::Helper), ] .into_iter() .zip(input_shares) .map(|((hpke_config, receiver_role), input_share)| { - Ok(hpke::seal( + hpke::seal( hpke_config, &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role), &PlaintextInputShare::new( @@ -249,14 +244,16 @@ impl, C: Clock> Client { encoded_public_share.clone(), ) .get_encoded(), - )?) + ) }) - .collect::>()?; + .collect_tuple() + .expect("iterator to yield two items"); // expect safety: iterator contains two items. Ok(Report::new( report_metadata, encoded_public_share, - encrypted_input_shares, + leader_encrypted_input_share?, + helper_encrypted_input_share?, )) } @@ -309,14 +306,15 @@ mod tests { use url::Url; fn setup_client>( - server: &mut mockito::Server, + server: &mockito::Server, vdaf_client: V, ) -> Client { let server_url = Url::parse(&server.url()).unwrap(); Client::new( ClientParameters::new_with_backoff( random(), - Vec::from([server_url.clone(), server_url]), + server_url.clone(), + server_url, Duration::from_seconds(1), test_http_request_exponential_backoff(), ), @@ -332,19 +330,18 @@ mod tests { fn aggregator_endpoints_end_in_slash() { let client_parameters = ClientParameters::new( random(), - Vec::from([ - "http://leader_endpoint/foo/bar".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint/foo/bar".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), Duration::from_seconds(1), ); assert_eq!( - client_parameters.aggregator_endpoints, - Vec::from([ - "http://leader_endpoint/foo/bar/".parse().unwrap(), - "http://helper_endpoint/".parse().unwrap() - ]) + client_parameters.leader_aggregator_endpoint, + "http://leader_endpoint/foo/bar/".parse().unwrap() + ); + assert_eq!( + client_parameters.helper_aggregator_endpoint, + "http://helper_endpoint/".parse().unwrap() ); } @@ -352,7 +349,7 @@ mod tests { async fn upload_prio3_count() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -373,9 +370,9 @@ mod tests { #[tokio::test] async fn upload_prio3_invalid_measurement() { install_test_trace_subscriber(); - let mut server = mockito::Server::new_async().await; + let server = mockito::Server::new_async().await; let vdaf = Prio3::new_sum(2, 16).unwrap(); - let client = setup_client(&mut server, vdaf); + let client = setup_client(&server, vdaf); // 65536 is too big for a 16 bit sum and will be 
rejected by the VDAF. // Make sure we get the right error variant but otherwise we aren't @@ -387,7 +384,7 @@ mod tests { async fn upload_prio3_http_status_code() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -414,7 +411,7 @@ mod tests { async fn upload_problem_details() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -455,8 +452,12 @@ mod tests { async fn upload_bad_time_precision() { install_test_trace_subscriber(); - let client_parameters = - ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0)); + let client_parameters = ClientParameters::new( + random(), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), + Duration::from_seconds(0), + ); let client = Client::new( client_parameters, Prio3::new_count(2).unwrap(), @@ -472,9 +473,9 @@ mod tests { #[test] fn report_timestamp() { install_test_trace_subscriber(); - let mut server = mockito::Server::new(); + let server = mockito::Server::new(); let vdaf = Prio3::new_count(2).unwrap(); - let mut client = setup_client(&mut server, vdaf); + let mut client = setup_client(&server, vdaf); client.parameters.time_precision = Duration::from_seconds(100); client.clock = MockClock::new(Time::from_seconds_since_epoch(101)); diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 679bdf977..bdd501177 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -113,8 +113,6 @@ pub enum Error { Codec(#[from] prio::codec::CodecError), #[error("aggregate share decoding error")] AggregateShareDecode, - #[error("expected two aggregate shares, got {0}")] - AggregateShareCount(usize), #[error("VDAF error: {0}")] Vdaf(#[from] prio::vdaf::VdafError), #[error("HPKE error: {0}")] @@ -189,17 +187,14 @@ impl CollectorParameters { /// Creates a new set of collector task parameters. pub fn new( task_id: TaskId, - mut leader_endpoint: Url, + leader_endpoint: Url, authentication_token: AuthenticationToken, hpke_config: HpkeConfig, hpke_private_key: HpkePrivateKey, ) -> CollectorParameters { - // Ensure the provided leader endpoint ends with a slash. 
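The client upload path a little above leans on `itertools::Itertools::collect_tuple` (hence the new workspace dependency) to turn an exactly-two-element iterator into named leader/helper values. In isolation:

```rust
// collect_tuple returns Some((a, b)) only when the iterator yields exactly
// two items; any other length comes back as None, hence the expect() in the
// upload path.
use itertools::Itertools;

fn main() {
    let (leader, helper) = ["leader share", "helper share"]
        .into_iter()
        .map(str::to_uppercase)
        .collect_tuple()
        .expect("iterator to yield two items");
    assert_eq!(leader, "LEADER SHARE");
    assert_eq!(helper, "HELPER SHARE");

    // A three-element iterator does not collect into a 2-tuple:
    assert_eq!([1, 2, 3].into_iter().collect_tuple::<(i32, i32)>(), None);
}
```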
- url_ensure_trailing_slash(&mut leader_endpoint); - CollectorParameters { task_id, - leader_endpoint, + leader_endpoint: url_ensure_trailing_slash(leader_endpoint), authentication: Authentication::DapAuthToken(authentication_token), hpke_config, hpke_private_key, @@ -512,41 +507,40 @@ impl Collector { } let collect_response = CollectionMessage::::get_decoded(&response.bytes().await?)?; - if collect_response.encrypted_aggregate_shares().len() != 2 { - return Err(Error::AggregateShareCount( - collect_response.encrypted_aggregate_shares().len(), - )); - } - let aggregate_shares_bytes = collect_response - .encrypted_aggregate_shares() - .iter() - .zip(&[Role::Leader, Role::Helper]) - .map(|(encrypted_aggregate_share, role)| { - hpke::open( - &self.parameters.hpke_config, - &self.parameters.hpke_private_key, - &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, role, &Role::Collector), - encrypted_aggregate_share, - &AggregateShareAad::new( - self.parameters.task_id, - BatchSelector::::new(Q::batch_identifier_for_collection( - &job.query, - &collect_response, - )), - ) - .get_encoded(), - ) - }); - let aggregate_shares = aggregate_shares_bytes - .map(|bytes| { - V::AggregateShare::get_decoded_with_param( - &(&self.vdaf_collector, &job.aggregation_parameter), - &bytes?, + let aggregate_shares = [ + ( + Role::Leader, + collect_response.leader_encrypted_aggregate_share(), + ), + ( + Role::Helper, + collect_response.helper_encrypted_aggregate_share(), + ), + ] + .into_iter() + .map(|(role, encrypted_aggregate_share)| { + let bytes = hpke::open( + &self.parameters.hpke_config, + &self.parameters.hpke_private_key, + &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, &role, &Role::Collector), + encrypted_aggregate_share, + &AggregateShareAad::new( + self.parameters.task_id, + BatchSelector::::new(Q::batch_identifier_for_collection( + &job.query, + &collect_response, + )), ) - .map_err(|_err| Error::AggregateShareDecode) - }) - .collect::, Error>>()?; + .get_encoded(), + )?; + V::AggregateShare::get_decoded_with_param( + &(&self.vdaf_collector, &job.aggregation_parameter), + &bytes, + ) + .map_err(|_err| Error::AggregateShareDecode) + }) + .collect::, Error>>()?; let report_count = collect_response .report_count() @@ -742,30 +736,20 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::::from([ - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -780,30 +764,20 @@ mod tests { PartialBatchSelector::new_fixed_size(batch_id), 1, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - Vec::::from([ - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - 
&Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -1409,31 +1383,6 @@ mod tests { mock_collection_job_bad_message_bytes.assert_async().await; - let mock_collection_job_bad_share_count = server - .mock("POST", job.collection_job_url.path()) - .with_status(200) - .with_header( - CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, - ) - .with_body( - CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 0, - batch_interval, - Vec::new(), - ) - .get_encoded(), - ) - .expect_at_least(1) - .create_async() - .await; - - let error = collector.poll_once(&job).await.unwrap_err(); - assert_matches!(error, Error::AggregateShareCount(0)); - - mock_collection_job_bad_share_count.assert_async().await; - let mock_collection_job_bad_ciphertext = server .mock("POST", job.collection_job_url.path()) .with_status(200) @@ -1446,18 +1395,16 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - ]), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), ) .get_encoded(), ) @@ -1478,30 +1425,20 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_bad_shares = server .mock("POST", job.collection_job_url.path()) @@ -1524,35 +1461,25 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) - .get_encoded(), - 
&associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([ - Field64::from(0), - Field64::from(0), - ]))) + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) .get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([ + Field64::from(0), + Field64::from(0), + ]))) + .get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_wrong_length = server .mock("POST", job.collection_job_url.path()) diff --git a/core/Cargo.toml b/core/Cargo.toml index a1bce9836..5ad302cdb 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -57,6 +57,7 @@ tokio = { version = "1.27", features = ["macros", "net", "rt"] } tracing = "0.1.37" tracing-log = { version = "0.1.3", optional = true } tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"], optional = true } +url = "2" [dev-dependencies] fixed = "1.23" diff --git a/core/src/task.rs b/core/src/task.rs index 1e3a513cc..e10b9529b 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -1,14 +1,14 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use rand::{distributions::Standard, prelude::Distribution}; -use reqwest::Url; use ring::constant_time; use serde::{Deserialize, Serialize}; +use url::Url; /// HTTP header where auth tokens are provided in messages between participants. pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; -/// The length of the verify key parameter for Prio3 VDAF instantiations. -pub const PRIO3_VERIFY_KEY_LENGTH: usize = 16; +/// The length of the verify key parameter for Prio3 & Poplar1 VDAF instantiations. +pub const VERIFY_KEY_LEN: usize = 16; /// Identifiers for supported VDAFs, corresponding to definitions in /// [draft-irtf-cfrg-vdaf-03][1] and implementations in [`prio::vdaf::prio3`]. @@ -62,9 +62,8 @@ impl VdafInstance { | VdafInstance::FakeFailsPrepInit | VdafInstance::FakeFailsPrepStep => 0, - // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's - // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, + // All "real" VDAFs use a verify key of length 16 currently. + _ => VERIFY_KEY_LEN, } } } @@ -73,35 +72,41 @@ impl VdafInstance { #[macro_export] macro_rules! vdaf_dispatch_impl_base { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3CountVec { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3SumVec { bits, length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -110,12 +115,12 @@ macro_rules! vdaf_dispatch_impl_base { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_count(2)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -123,14 +128,14 @@ macro_rules! vdaf_dispatch_impl_base { // Prio3CountVec is implemented as a 1-bit sum vec let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, 1, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *bits)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -138,14 +143,21 @@ macro_rules! 
vdaf_dispatch_impl_base { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, *bits, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, buckets)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + let $vdaf = ::prio::vdaf::poplar1::Poplar1::new_sha3(*bits); + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -159,13 +171,13 @@ macro_rules! vdaf_dispatch_impl_base { #[macro_export] macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -173,7 +185,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -181,7 +193,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -190,7 +202,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { let $vdaf = @@ -200,7 +212,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -212,7 +224,7 @@ macro_rules! 
vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -224,7 +236,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -238,23 +250,23 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { #[macro_export] macro_rules! vdaf_dispatch_impl_test_util { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepInit => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -263,12 +275,12 @@ macro_rules! vdaf_dispatch_impl_test_util { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new(); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -278,26 +290,30 @@ macro_rules! vdaf_dispatch_impl_test_util { ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( "FakeFailsPrepInit failed at prep_init".to_string(), )) - } + }, ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> Result< - ::prio::vdaf::PrepareTransition<::janus_core::test_util::dummy_vdaf::Vdaf, 0, 16>, + |_| -> Result< + ::prio::vdaf::PrepareTransition< + ::janus_core::test_util::dummy_vdaf::Vdaf, + 0, + 16, + >, ::prio::vdaf::VdafError, > { ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( "FakeFailsPrepStep failed at prep_step".to_string(), )) - } + }, ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -311,26 +327,27 @@ macro_rules! vdaf_dispatch_impl_test_util { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -338,26 +355,27 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. 
} => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -370,20 +388,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -391,20 +410,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. 
} | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -417,20 +437,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -438,20 +459,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -464,14 +486,15 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -479,14 +502,15 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -513,21 +537,21 @@ macro_rules! vdaf_dispatch_impl { /// # } /// # fn test() -> Result<(), prio::vdaf::VdafError> { /// # let vdaf = janus_core::task::VdafInstance::Prio3Count; -/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { -/// handle_request_generic::<VdafType, VERIFY_KEY_LENGTH>(&vdaf) +/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LEN) => { +/// handle_request_generic::<VdafType, VERIFY_KEY_LEN>(&vdaf) /// }) /// # } /// ``` #[macro_export] macro_rules! vdaf_dispatch { // Provide the dispatched type only, don't construct a VDAF instance. - ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) }; // Construct a VDAF instance, and provide that to the block as well. - ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) }; } @@ -567,15 +591,16 @@ impl Distribution<AuthenticationToken> for Standard { } } -/// Modifies a [`Url`] in place to ensure it ends with a slash. +/// Returns the given [`Url`], possibly modified to end with a slash.
/// /// Aggregator endpoint URLs should end with a slash if they will be used with [`Url::join`], /// because that method will drop the last path component of the base URL if it does not end with a /// slash. -pub fn url_ensure_trailing_slash(url: &mut Url) { +pub fn url_ensure_trailing_slash(mut url: Url) -> Url { if !url.as_str().ends_with('/') { url.set_path(&format!("{}/", url.path())); } + url } #[cfg(test)] diff --git a/core/src/test_util/dummy_vdaf.rs b/core/src/test_util/dummy_vdaf.rs index 7d1583112..6d87ea60e 100644 --- a/core/src/test_util/dummy_vdaf.rs +++ b/core/src/test_util/dummy_vdaf.rs @@ -4,20 +4,24 @@ use prio::{ codec::{CodecError, Decode, Encode}, vdaf::{self, Aggregatable, PrepareTransition, VdafError}, }; +use rand::random; use std::fmt::Debug; use std::io::Cursor; use std::sync::Arc; type ArcPrepInitFn = Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>; -type ArcPrepStepFn = - Arc<dyn Fn() -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> + 'static + Send + Sync>; +type ArcPrepStepFn = Arc< + dyn Fn(&PrepareState) -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> + + 'static + + Send + + Sync, +>; #[derive(Clone)] pub struct Vdaf { prep_init_fn: ArcPrepInitFn, prep_step_fn: ArcPrepStepFn, - input_share: InputShare, } impl Debug for Vdaf { @@ -31,15 +35,16 @@ impl Vdaf { /// The length of the verify key parameter for fake VDAF instantiations. - pub const VERIFY_KEY_LENGTH: usize = 0; + pub const VERIFY_KEY_LEN: usize = 0; pub fn new() -> Self { Self { prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }), - prep_step_fn: Arc::new(|| -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> { - Ok(PrepareTransition::Finish(OutputShare())) - }), - input_share: InputShare::default(), + prep_step_fn: Arc::new( + |state| -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> { + Ok(PrepareTransition::Finish(OutputShare(state.0))) + }, + ), } } @@ -54,7 +59,9 @@ impl Vdaf { self } - pub fn with_prep_step_fn<F: Fn() -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError>>( + pub fn with_prep_step_fn< + F: Fn(&PrepareState) -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError>, + >( mut self, f: F, ) -> Self @@ -64,11 +71,6 @@ impl Vdaf { self.prep_step_fn = Arc::new(f); self } - - pub fn with_input_share(mut self, input_share: InputShare) -> Self { - self.input_share = input_share; - self - } } impl Default for Vdaf { @@ -80,8 +82,8 @@ impl Default for Vdaf { impl vdaf::Vdaf for Vdaf { const ID: u32 = 0xFFFF0000; - type Measurement = (); - type AggregateResult = (); + type Measurement = u8; + type AggregateResult = u8; type AggregationParam = AggregationParam; type PublicShare = (); type InputShare = InputShare; @@ -120,10 +122,10 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_step( &self, - _: Self::PrepareState, + state: Self::PrepareState, _: Self::PrepareMessage, ) -> Result<PrepareTransition<Self, 0, 16>, VdafError> { - (self.prep_step_fn)() + (self.prep_step_fn)(&state) } fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( @@ -142,10 +144,18 @@ impl vdaf::Aggregator<0, 16> for Vdaf { impl vdaf::Client<16> for Vdaf { fn shard( &self, - _measurement: &Self::Measurement, + measurement: &Self::Measurement, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError> { - Ok(((), Vec::from([self.input_share, self.input_share]))) + let first_input_share = random(); + let (second_input_share, _) = measurement.overflowing_sub(first_input_share); + Ok(( + (), + Vec::from([ + InputShare(first_input_share), + InputShare(second_input_share), + ]), + )) } } @@ -173,7 +183,7 @@ pub struct AggregationParam(pub u8); impl Encode for AggregationParam { fn encode(&self, bytes: &mut Vec<u8>) { - self.0.encode(bytes); + self.0.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -188,19 +198,21 @@ impl Decode for
AggregationParam { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct OutputShare(); +pub struct OutputShare(pub u8); impl Decode for OutputShare { - fn decode(_: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { - Ok(Self()) + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self(u8::decode(bytes)?)) } } impl Encode for OutputShare { - fn encode(&self, _: &mut Vec<u8>) {} + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } fn encoded_len(&self) -> Option<usize> { - Some(0) + self.0.encoded_len() } } @@ -234,15 +246,15 @@ impl Aggregatable for AggregateShare { Ok(()) } - fn accumulate(&mut self, _: &Self::OutputShare) -> Result<(), VdafError> { - self.0 += 1; + fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> { + self.0 += u64::from(out_share.0); Ok(()) } } impl From<OutputShare> for AggregateShare { - fn from(_: OutputShare) -> Self { - Self(1) + fn from(out_share: OutputShare) -> Self { + Self(u64::from(out_share.0)) } } diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 60d2e15db..e150b4c9f 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -32,15 +32,15 @@ pub struct VdafTranscript<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>> VdafTranscript<SEED_SIZE, V> { - /// Get the leader's preparation state at the requested round. - pub fn leader_prep_state(&self, round: usize) -> &V::PrepareState { + /// Get the Leader's preparation state and prepare share at the requested round. + pub fn leader_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { assert_matches!( &self.prepare_transitions[Role::Leader.index().unwrap()][round], - PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, _) => prep_state + PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, prep_share) => (prep_state, prep_share) ) } - /// Get the helper's preparation state and prepare share at the requested round. + /// Get the Helper's preparation state and prepare share at the requested round. pub fn helper_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { assert_matches!( &self.prepare_transitions[Role::Helper.index().unwrap()][round], diff --git a/db/20230405185602_initial-schema.up.sql b/db/20230405185602_initial-schema.up.sql index 0b2ec2c2b..cc4914448 100644 --- a/db/20230405185602_initial-schema.up.sql +++ b/db/20230405185602_initial-schema.up.sql @@ -11,19 +11,20 @@ CREATE TYPE AGGREGATOR_ROLE AS ENUM( -- Corresponds to a DAP task, containing static data associated with the task. CREATE TABLE tasks( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification - aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task - aggregator_endpoints TEXT[] NOT NULL, -- aggregator HTTPS endpoints, leader first - query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters - vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters - max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected - task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted - report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled.
- min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected - time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds - tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds - collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only + task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification + aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task + leader_aggregator_endpoint TEXT NOT NULL, -- Leader's API endpoint + helper_aggregator_endpoint TEXT NOT NULL, -- Helper's API endpoint + query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters + vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters + max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted + report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. + min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected + time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds + tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds + collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) ); CREATE INDEX task_id_index ON tasks(task_id); diff --git a/db/schema.sql b/db/schema.sql index a47f70a4e..4f42ef19b 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -11,19 +11,20 @@ CREATE TYPE AGGREGATOR_ROLE AS ENUM( -- Corresponds to a DAP task, containing static data associated with the task. CREATE TABLE tasks( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification - aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task - aggregator_endpoints TEXT[] NOT NULL, -- aggregator HTTPS endpoints, leader first - query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters - vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters - max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected - task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted - report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. 
- min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected - time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds - tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds - collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only + task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification + aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task + leader_aggregator_endpoint TEXT NOT NULL, -- Leader's API endpoint + helper_aggregator_endpoint TEXT NOT NULL, -- Helper's API endpoint + query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters + vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters + max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted + report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. + min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected + time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds + tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds + collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) ); CREATE INDEX task_id_index ON tasks(task_id); @@ -140,11 +141,10 @@ CREATE TABLE report_aggregations( ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) - prep_msg BYTEA, -- for the leader, the next preparation message to be sent to the helper (opaque VDAF message) - -- for the helper, the next preparation share to be sent to the leader (opaque VDAF message) - -- only non-NULL if in state WAITING + prep_msg BYTEA, -- the next preparation message to be sent to the Helper (opaque VDAF message, populated for Leader only) out_share BYTEA, -- the output share (opaque VDAF message, only if in state FINISHED) error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED + last_prep_step BYTEA, -- the last PrepareStep message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) CONSTRAINT report_aggregations_unique_ord UNIQUE(aggregation_job_id, ord), CONSTRAINT fk_aggregation_job_id FOREIGN KEY(aggregation_job_id) REFERENCES aggregation_jobs(id) ON DELETE CASCADE, diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index e9f0c801a..470e27f89 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -6,10 +6,9 @@ # DAP's recommendation. task_id: "G9YKXjoEjfoU7M_fi_o2H0wmzavRb2sBFHeykeRhDMk" - # HTTPS endpoints of the leader and helper aggregators, in a list.
- aggregator_endpoints: - - "https://example.com/" - - "https://example.net/" + # HTTPS endpoints of the leader and helper aggregators. + leader_aggregator_endpoint: "https://example.com/" + helper_aggregator_endpoint: "https://example.net/" # The DAP query type. See below for an example of a fixed-size task query_type: TimeInterval @@ -94,9 +93,8 @@ private_key: wFRYwiypcHC-mkGP1u3XQgIvtnlkQlUfZjgtM_zRsnI - task_id: "D-hCKPuqL2oTf7ZVRVyMP5VGt43EAEA8q34mDf6p1JE" - aggregator_endpoints: - - "https://example.org/" - - "https://example.com/" + leader_aggregator_endpoint: "https://example.org/" + helper_aggregator_endpoint: "https://example.com/" # For tasks using the fixed size query type, an additional `max_batch_size` # parameter must be provided. query_type: !FixedSize diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index a6dba1d29..f3669f656 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -40,6 +40,6 @@ url = { version = "2.3.1", features = ["serde"] } [dev-dependencies] http = "0.2" -itertools = "0.10" +itertools.workspace = true janus_collector = { workspace = true, features = ["test-util"] } tempfile = "3" diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index dc4a01ff2..1097f62ad 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -159,18 +159,22 @@ impl<'a> ClientBackend<'a> { pub async fn build( &self, task: &Task, - aggregator_endpoints: Vec<Url>, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, vdaf: V, ) -> anyhow::Result<ClientImplementation<'a, V>> where V: vdaf::Client<16> + InteropClientEncoding, { match self { - ClientBackend::InProcess => { - ClientImplementation::new_in_process(task, aggregator_endpoints, vdaf) - .await - .map_err(Into::into) - } + ClientBackend::InProcess => ClientImplementation::new_in_process( + task, + leader_aggregator_endpoint, + helper_aggregator_endpoint, + vdaf, + ) + .await + .map_err(Into::into), ClientBackend::Container { container_client, container_image, @@ -217,11 +221,16 @@ where { pub async fn new_in_process( task: &Task, - aggregator_endpoints: Vec<Url>, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, vdaf: V, ) -> Result<ClientImplementation<'a, V>, janus_client::Error> { - let client_parameters = - ClientParameters::new(*task.id(), aggregator_endpoints, *task.time_precision()); + let client_parameters = ClientParameters::new( + *task.id(), + leader_aggregator_endpoint, + helper_aggregator_endpoint, + *task.time_precision(), + ); let http_client = default_http_client()?; let leader_config = aggregator_hpke_config(&client_parameters, &Role::Leader, task.id(), &http_client) @@ -259,8 +268,8 @@ where let http_client = reqwest::Client::new(); ClientImplementation::Container(Box::new(ContainerClientImplementation { _container: container, - leader: task.aggregator_endpoints()[Role::Leader.index().unwrap()].clone(), - helper: task.aggregator_endpoints()[Role::Helper.index().unwrap()].clone(), + leader: task.leader_aggregator_endpoint().clone(), + helper: task.helper_aggregator_endpoint().clone(), task_id: *task.id(), time_precision: *task.time_precision(), vdaf, diff --git a/integration_tests/src/daphne.rs b/integration_tests/src/daphne.rs index 7eb355322..3f0f00f80 100644 --- a/integration_tests/src/daphne.rs +++ b/integration_tests/src/daphne.rs @@ -34,7 +34,11 @@ impl<'a> Daphne<'a> { let (image_name, image_tag) = image_name_and_tag.rsplit_once(':').unwrap(); // Start the Daphne test container running.
- let endpoint = task.aggregator_url(task.role()).unwrap(); + let endpoint = match task.role() { + Role::Leader => task.leader_aggregator_endpoint(), + Role::Helper => task.helper_aggregator_endpoint(), + _ => panic!("unexpected role"), + }; let runnable_image = RunnableImage::from(GenericImage::new(image_name, image_tag)) .with_network(network) .with_container_name(endpoint.host_str().unwrap()); diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs index 4ef64b442..4fe11b1c7 100644 --- a/integration_tests/src/janus.rs +++ b/integration_tests/src/janus.rs @@ -51,7 +51,11 @@ impl<'a> Janus<'a> { task: &Task, ) -> Janus<'a> { // Start the Janus interop aggregator container running. - let endpoint = task.aggregator_url(task.role()).unwrap(); + let endpoint = match task.role() { + Role::Leader => task.leader_aggregator_endpoint(), + Role::Helper => task.helper_aggregator_endpoint(), + _ => panic!("unexpected task role"), + }; let container = container_client.run( RunnableImage::from(Aggregator::default()) .with_network(network) diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index b83104ddd..376499918 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -30,10 +30,12 @@ pub fn test_task_builders( let endpoint_random_value = hex::encode(random::<[u8; 4]>()); let collector_keypair = generate_test_hpke_config_and_private_key(); let leader_task = TaskBuilder::new(query_type, vdaf, Role::Leader) - .with_aggregator_endpoints(Vec::from([ + .with_leader_aggregator_endpoint( Url::parse(&format!("http://leader-{endpoint_random_value}:8080/")).unwrap(), + ) + .with_helper_aggregator_endpoint( Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), - ])) + ) .with_min_batch_size(46) .with_collector_hpke_config(collector_keypair.config().clone()); let helper_task = leader_task @@ -106,7 +108,7 @@ where pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( vdaf: V, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: &Url, leader_task: &'a Task, collector_private_key: &'a HpkePrivateKey, test_case: &'a AggregationTestCase, @@ -124,7 +126,7 @@ pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( let collector_params = CollectorParameters::new( *leader_task.id(), - aggregator_endpoints[Role::Leader.index().unwrap()].clone(), + leader_aggregator_endpoint.clone(), leader_task.primary_collector_auth_token().clone(), leader_task.collector_hpke_config().clone(), collector_private_key.clone(), @@ -143,9 +145,7 @@ pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( janus_collector::default_http_client().unwrap(), ); - let forwarded_port = aggregator_endpoints[Role::Leader.index().unwrap()] - .port() - .unwrap(); + let forwarded_port = leader_aggregator_endpoint.port().unwrap(); // Send a collect request and verify that we got the correct result. match leader_task.query_type() { @@ -218,12 +218,10 @@ pub async fn submit_measurements_and_verify_aggregate( client_backend: &ClientBackend<'_>, ) { // Translate aggregator endpoints for our perspective outside the container network. 
- let aggregator_endpoints: Vec<_> = leader_task - .aggregator_endpoints() - .iter() - .zip([leader_port, helper_port]) - .map(|(url, port)| translate_url_for_external_access(url, port)) - .collect(); + let leader_aggregator_endpoint = + translate_url_for_external_access(leader_task.leader_aggregator_endpoint(), leader_port); + let helper_aggregator_endpoint = + translate_url_for_external_access(leader_task.helper_aggregator_endpoint(), helper_port); // We generate exactly one batch's worth of measurement uploads to work around an issue in // Daphne at time of writing. @@ -247,13 +245,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -275,13 +278,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -315,13 +323,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -354,13 +367,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -394,13 +412,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, diff --git a/integration_tests/tests/daphne.rs b/integration_tests/tests/daphne.rs index 1dde04b6f..c400cd923 100644 --- a/integration_tests/tests/daphne.rs +++ b/integration_tests/tests/daphne.rs @@ -6,7 +6,6 @@ use janus_core::{ }; use janus_integration_tests::{client::ClientBackend, daphne::Daphne, janus::Janus}; use janus_interop_binaries::test_util::generate_network_name; -use janus_messages::Role; mod common; @@ -24,10 +23,12 @@ async fn daphne_janus() { // Daphne is hardcoded to serve from a path starting with /v04/. 
let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task] .into_iter() - .map(|task| { - let mut endpoints = task.aggregator_endpoints().to_vec(); - endpoints[Role::Leader.index().unwrap()].set_path("/v04/"); - task.with_aggregator_endpoints(endpoints).build() + .map(|task_builder| { + let mut endpoint = task_builder.task().leader_aggregator_endpoint().clone(); + endpoint.set_path("/v04/"); + task_builder + .with_leader_aggregator_endpoint(endpoint) + .build() }) .collect::<Vec<_>>() .try_into() @@ -49,6 +50,7 @@ async fn daphne_janus() { // This test places Janus in the leader role & Daphne in the helper role. #[tokio::test(flavor = "multi_thread")] +#[ignore = "Daphne does not yet implement ping-pong aggregation"] async fn janus_daphne() { install_test_trace_subscriber(); @@ -60,10 +62,12 @@ async fn janus_daphne() { // Daphne is hardcoded to serve from a path starting with /v04/. let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task] .into_iter() - .map(|task| { - let mut endpoints = task.aggregator_endpoints().to_vec(); - endpoints[Role::Helper.index().unwrap()].set_path("/v04/"); - task.with_aggregator_endpoints(endpoints).build() + .map(|task_builder| { + let mut endpoint = task_builder.task().helper_aggregator_endpoint().clone(); + endpoint.set_path("/v04/"); + task_builder + .with_helper_aggregator_endpoint(endpoint) + .build() }) .collect::<Vec<_>>() .try_into() diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs index aa89eea65..3161485cc 100644 --- a/integration_tests/tests/janus.rs +++ b/integration_tests/tests/janus.rs @@ -68,15 +68,12 @@ impl<'a> JanusPair<'a> { // where "port" is whatever unused port we use with `kubectl port-forward`. But when // the aggregators talk to each other, they do it on the cluster's private network, // and so they need the in-cluster DNS name of the other aggregator. However, since - // aggregators use the endpoint URLs in the task to construct collection job URIs, we - // must only fix the _peer_ aggregator's endpoint. - let leader_endpoints = { - let mut endpoints = leader_task.aggregator_endpoints().to_vec(); - endpoints[1] = Self::in_cluster_aggregator_url(&helper_namespace); - endpoints - }; + // aggregators use the endpoint URLs in the task to construct collection job URIs, + // we must only fix the _peer_ aggregator's endpoint.
let leader_task = leader_task - .with_aggregator_endpoints(leader_endpoints) + .with_helper_aggregator_endpoint(Self::in_cluster_aggregator_url( + &helper_namespace, + )) .build(); let leader = Janus::new_with_kubernetes_cluster( &kubeconfig_path, @@ -86,13 +83,10 @@ impl<'a> JanusPair<'a> { ) .await; - let helper_endpoints = { - let mut endpoints = helper_task.aggregator_endpoints().to_vec(); - endpoints[0] = Self::in_cluster_aggregator_url(&leader_namespace); - endpoints - }; let helper_task = helper_task - .with_aggregator_endpoints(helper_endpoints) + .with_leader_aggregator_endpoint(Self::in_cluster_aggregator_url( + &leader_namespace, + )) .build(); let helper = Janus::new_with_kubernetes_cluster( &kubeconfig_path, diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index 56bc4c483..4d9c35383 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -84,7 +84,8 @@ async fn handle_add_task( let task = Task::new( request.task_id, - Vec::from([request.leader, request.helper]), + request.leader, + request.helper, query_type, vdaf, request.role.into(), diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index c22701163..300bd39bb 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -79,11 +79,8 @@ async fn handle_upload_generic<V: vdaf::Client<16>>( .context("invalid base64url content in \"task_id\"")?; let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; let time_precision = Duration::from_seconds(request.time_precision); - let client_parameters = ClientParameters::new( - task_id, - Vec::<Url>::from([request.leader, request.helper]), - time_precision, - ); + let client_parameters = + ClientParameters::new(task_id, request.leader, request.helper, time_precision); let leader_hpke_config = janus_client::aggregator_hpke_config( &client_parameters, diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index 22d0b5d16..c7fcfff3b 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -287,8 +287,8 @@ impl From<Task> for AggregatorAddTaskRequest { }; Self { task_id: *task.id(), - leader: task.aggregator_url(&Role::Leader).unwrap().clone(), - helper: task.aggregator_url(&Role::Helper).unwrap().clone(), + leader: task.leader_aggregator_endpoint().clone(), + helper: task.helper_aggregator_endpoint().clone(), vdaf: task.vdaf().clone().into(), leader_authentication_token: String::from_utf8( task.primary_aggregator_auth_token().as_ref().to_vec(), diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index 405a173b2..297c29c41 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -2,7 +2,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use futures::future::join_all; use janus_core::{ - task::PRIO3_VERIFY_KEY_LENGTH, + task::VERIFY_KEY_LEN, test_util::{install_test_trace_subscriber, testcontainers::container_client}, time::{Clock, RealClock, TimeExt}, }; @@ -109,7 +109,7 @@ async fn run( let task_id: TaskId = random(); let aggregator_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); let collector_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); - let vdaf_verify_key = rand::random::<[u8; PRIO3_VERIFY_KEY_LENGTH]>(); + let
vdaf_verify_key = rand::random::<[u8; VERIFY_KEY_LEN]>(); let task_id_encoded = URL_SAFE_NO_PAD.encode(task_id.get_encoded()); let vdaf_verify_key_encoded = URL_SAFE_NO_PAD.encode(vdaf_verify_key); diff --git a/messages/Cargo.toml b/messages/Cargo.toml index eb38b56be..3bd3f784c 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,7 @@ hex = "0.4" num_enum = "0.6.1" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.12.1", default-features = false } +prio = { workspace = true, default-features = false } # XXX: revert to 0.12.1? rand = "0.8" serde = { version = "1.0.160", features = ["derive"] } thiserror = "1.0" diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 01f097406..4cac9a8ea 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -1222,7 +1222,8 @@ impl Decode for PlaintextInputShare { pub struct Report { metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, } impl Report { @@ -1233,12 +1234,14 @@ impl Report { pub fn new( metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, ) -> Self { Self { metadata, public_share, - encrypted_input_shares, + leader_encrypted_input_share, + helper_encrypted_input_share, } } @@ -1247,13 +1250,19 @@ impl Report { &self.metadata } + /// Retrieve the public share from this report. pub fn public_share(&self) -> &[u8] { &self.public_share } - /// Get this report's encrypted input shares. - pub fn encrypted_input_shares(&self) -> &[HpkeCiphertext] { - &self.encrypted_input_shares + /// Retrieve the encrypted leader input share from this report. + pub fn leader_encrypted_input_share(&self) -> &HpkeCiphertext { + &self.leader_encrypted_input_share + } + + /// Retrieve the encrypted helper input share from this report. 
+ pub fn helper_encrypted_input_share(&self) -> &HpkeCiphertext { + &self.helper_encrypted_input_share } } @@ -1261,17 +1270,16 @@ impl Encode for Report { fn encode(&self, bytes: &mut Vec) { self.metadata.encode(bytes); encode_u32_items(bytes, &(), &self.public_share); - encode_u32_items(bytes, &(), &self.encrypted_input_shares); + self.leader_encrypted_input_share.encode(bytes); + self.helper_encrypted_input_share.encode(bytes); } fn encoded_len(&self) -> Option { let mut length = self.metadata.encoded_len()?; length += 4; length += self.public_share.len(); - length += 4; - for encrypted_input_share in self.encrypted_input_shares.iter() { - length += encrypted_input_share.encoded_len()?; - } + length += self.leader_encrypted_input_share.encoded_len()?; + length += self.helper_encrypted_input_share.encoded_len()?; Some(length) } } @@ -1280,12 +1288,14 @@ impl Decode for Report { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let metadata = ReportMetadata::decode(bytes)?; let public_share = decode_u32_items(&(), bytes)?; - let encrypted_input_shares = decode_u32_items(&(), bytes)?; + let leader_encrypted_input_share = HpkeCiphertext::decode(bytes)?; + let helper_encrypted_input_share = HpkeCiphertext::decode(bytes)?; Ok(Self { metadata, public_share, - encrypted_input_shares, + leader_encrypted_input_share, + helper_encrypted_input_share, }) } } @@ -1589,7 +1599,8 @@ pub struct Collection { partial_batch_selector: PartialBatchSelector, report_count: u64, interval: Interval, - encrypted_aggregate_shares: Vec, + leader_encrypted_agg_share: HpkeCiphertext, + helper_encrypted_agg_share: HpkeCiphertext, } impl Collection { @@ -1601,34 +1612,41 @@ impl Collection { partial_batch_selector: PartialBatchSelector, report_count: u64, interval: Interval, - encrypted_aggregate_shares: Vec, + leader_encrypted_agg_share: HpkeCiphertext, + helper_encrypted_agg_share: HpkeCiphertext, ) -> Self { Self { partial_batch_selector, report_count, interval, - encrypted_aggregate_shares, + leader_encrypted_agg_share, + helper_encrypted_agg_share, } } - /// Gets the batch selector associated with this collection. + /// Retrieves the batch selector associated with this collection. pub fn partial_batch_selector(&self) -> &PartialBatchSelector { &self.partial_batch_selector } - /// Gets the number of reports that were aggregated into this collection. + /// Retrieves the number of reports that were aggregated into this collection. pub fn report_count(&self) -> u64 { self.report_count } - /// Gets the interval spanned by the reports aggregated into this collection. + /// Retrieves the interval spanned by the reports aggregated into this collection. pub fn interval(&self) -> &Interval { &self.interval } - /// Gets the encrypted aggregate shares associated with this collection. - pub fn encrypted_aggregate_shares(&self) -> &[HpkeCiphertext] { - &self.encrypted_aggregate_shares + /// Retrieves the leader encrypted aggregate share associated with this collection. + pub fn leader_encrypted_aggregate_share(&self) -> &HpkeCiphertext { + &self.leader_encrypted_agg_share + } + + /// Retrieves the helper encrypted aggregate share associated with this collection. 
@@ -1637,18 +1655,18 @@ impl<Q: QueryType> Encode for Collection<Q> {
         self.partial_batch_selector.encode(bytes);
         self.report_count.encode(bytes);
         self.interval.encode(bytes);
-        encode_u32_items(bytes, &(), &self.encrypted_aggregate_shares);
+        self.leader_encrypted_agg_share.encode(bytes);
+        self.helper_encrypted_agg_share.encode(bytes);
     }
 
     fn encoded_len(&self) -> Option<usize> {
-        let mut length = self.partial_batch_selector.encoded_len()?
-            + self.report_count.encoded_len()?
-            + self.interval.encoded_len()?;
-        length += 4;
-        for encrypted_aggregate_share in self.encrypted_aggregate_shares.iter() {
-            length += encrypted_aggregate_share.encoded_len()?;
-        }
-        Some(length)
+        Some(
+            self.partial_batch_selector.encoded_len()?
+                + self.report_count.encoded_len()?
+                + self.interval.encoded_len()?
+                + self.leader_encrypted_agg_share.encoded_len()?
+                + self.helper_encrypted_agg_share.encoded_len()?,
+        )
     }
 }
 
@@ -1657,13 +1675,15 @@ impl<Q: QueryType> Decode for Collection<Q> {
         let partial_batch_selector = PartialBatchSelector::decode(bytes)?;
         let report_count = u64::decode(bytes)?;
         let interval = Interval::decode(bytes)?;
-        let encrypted_aggregate_shares = decode_u32_items(&(), bytes)?;
+        let leader_encrypted_agg_share = HpkeCiphertext::decode(bytes)?;
+        let helper_encrypted_agg_share = HpkeCiphertext::decode(bytes)?;
 
         Ok(Self {
             partial_batch_selector,
             report_count,
             interval,
-            encrypted_aggregate_shares,
+            leader_encrypted_agg_share,
+            helper_encrypted_agg_share,
         })
     }
 }
@@ -2007,6 +2027,57 @@ impl Decode for ReportShare {
     }
 }
 
+/// DAP protocol message representing information required to initialize preparation of a report
+/// for aggregation.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct ReportPrepInit {
+    report_share: ReportShare,
+    leader_prep_share: Vec<u8>,
+}
+
+impl ReportPrepInit {
+    /// Constructs a new report preparation initialization message from its components.
+    pub fn new(report_share: ReportShare, leader_prep_share: Vec<u8>) -> Self {
+        Self {
+            report_share,
+            leader_prep_share,
+        }
+    }
+
+    /// Gets the report share associated with this report prep init.
+    pub fn report_share(&self) -> &ReportShare {
+        &self.report_share
+    }
+
+    /// Gets the leader preparation share associated with this report prep init.
+    pub fn leader_prep_share(&self) -> &[u8] {
+        &self.leader_prep_share
+    }
+}
+
+impl Encode for ReportPrepInit {
+    fn encode(&self, bytes: &mut Vec<u8>) {
+        self.report_share.encode(bytes);
+        encode_u32_items(bytes, &(), &self.leader_prep_share);
+    }
+
+    fn encoded_len(&self) -> Option<usize> {
+        Some(self.report_share.encoded_len()? + 4 + self.leader_prep_share.len())
+    }
+}
+
+impl Decode for ReportPrepInit {
+    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+        let report_share = ReportShare::decode(bytes)?;
+        let leader_prep_share = decode_u32_items(&(), bytes)?;
+
+        Ok(Self {
+            report_share,
+            leader_prep_share,
+        })
+    }
+}
+
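`ReportPrepInit` is new on the wire: a `ReportShare` followed by the leader's opaque, `u32`-length-prefixed preparation share. A decoding sketch using the bytes of the first `roundtrip_report_prep_init` vector added later in this diff (hex values taken from that test; nothing here is new data):

    use janus_messages::ReportPrepInit;
    use prio::codec::{Decode, Encode};

    fn report_prep_init_sketch() {
        // 47-byte report share, then a u32 length (6), then the leader prep share.
        let bytes = hex::decode(concat!(
            "0102030405060708090A0B0C0D0E0F10", "000000000000D431", // metadata
            "00000000",                                             // empty public share
            "2A", "0006", "303132333435", "00000006", "353433323130", // encrypted input share
            "00000006", "303132333435",                             // leader prep share
        ))
        .unwrap();
        let init = ReportPrepInit::get_decoded(&bytes).unwrap();
        assert_eq!(init.leader_prep_share(), b"012345");
        // encoded_len = report share (47) + 4-byte length prefix + 6-byte share.
        assert_eq!(init.encoded_len().unwrap(), bytes.len());
    }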
 /// DAP protocol message representing the result of a preparation step in a VDAF evaluation.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub struct PrepareStep {
@@ -2056,8 +2127,16 @@
 #[derive(Clone, Derivative, PartialEq, Eq)]
 #[derivative(Debug)]
 pub enum PrepareStepResult {
-    Continued(#[derivative(Debug = "ignore")] Vec<u8>), // content is a serialized preparation message
-    Finished,
+    Continued {
+        #[derivative(Debug = "ignore")]
+        prep_msg: Vec<u8>,
+        #[derivative(Debug = "ignore")]
+        prep_share: Vec<u8>,
+    },
+    Finished {
+        #[derivative(Debug = "ignore")]
+        prep_msg: Vec<u8>,
+    },
     Failed(ReportShareError),
 }
 
@@ -2066,11 +2145,18 @@ impl Encode for PrepareStepResult {
         // The encoding includes an implicit discriminator byte, called PrepareStepResult in the
         // DAP spec.
         match self {
-            Self::Continued(vdaf_msg) => {
+            Self::Continued {
+                prep_msg,
+                prep_share,
+            } => {
                 0u8.encode(bytes);
-                encode_u32_items(bytes, &(), vdaf_msg);
+                encode_u32_items(bytes, &(), prep_msg);
+                encode_u32_items(bytes, &(), prep_share);
+            }
+            Self::Finished { prep_msg } => {
+                1u8.encode(bytes);
+                encode_u32_items(bytes, &(), prep_msg);
             }
-            Self::Finished => 1u8.encode(bytes),
             Self::Failed(error) => {
                 2u8.encode(bytes);
                 error.encode(bytes);
@@ -2080,8 +2166,11 @@
     fn encoded_len(&self) -> Option<usize> {
         match self {
-            PrepareStepResult::Continued(vdaf_msg) => Some(1 + 4 + vdaf_msg.len()),
-            PrepareStepResult::Finished => Some(1),
+            PrepareStepResult::Continued {
+                prep_msg,
+                prep_share,
+            } => Some(1 + 4 + prep_msg.len() + 4 + prep_share.len()),
+            PrepareStepResult::Finished { prep_msg } => Some(1 + 4 + prep_msg.len()),
             PrepareStepResult::Failed(error) => Some(1 + error.encoded_len()?),
         }
     }
@@ -2091,8 +2180,18 @@ impl Decode for PrepareStepResult {
     fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
         let val = u8::decode(bytes)?;
         Ok(match val {
-            0 => Self::Continued(decode_u32_items(&(), bytes)?),
-            1 => Self::Finished,
+            0 => {
+                let prep_msg = decode_u32_items(&(), bytes)?;
+                let prep_share = decode_u32_items(&(), bytes)?;
+                Self::Continued {
+                    prep_msg,
+                    prep_share,
+                }
+            }
+            1 => {
+                let prep_msg = decode_u32_items(&(), bytes)?;
+                Self::Finished { prep_msg }
+            }
             2 => Self::Failed(ReportShareError::decode(bytes)?),
             _ => return Err(CodecError::UnexpectedValue),
         })
@@ -2205,7 +2304,7 @@ pub struct AggregationJobInitializeReq<Q: QueryType> {
     #[derivative(Debug = "ignore")]
     aggregation_parameter: Vec<u8>,
     partial_batch_selector: PartialBatchSelector<Q>,
-    report_shares: Vec<ReportShare>,
+    report_inits: Vec<ReportPrepInit>,
 }
 
 impl<Q: QueryType> AggregationJobInitializeReq<Q> {
@@ -2216,12 +2315,12 @@
     pub fn new(
         aggregation_parameter: Vec<u8>,
         partial_batch_selector: PartialBatchSelector<Q>,
-        report_shares: Vec<ReportShare>,
+        report_inits: Vec<ReportPrepInit>,
     ) -> Self {
         Self {
             aggregation_parameter,
             partial_batch_selector,
-            report_shares,
+            report_inits,
         }
     }
 
@@ -2235,9 +2334,10 @@
         &self.partial_batch_selector
     }
 
-    /// Gets the report shares associated with this aggregate initialization request.
-    pub fn report_shares(&self) -> &[ReportShare] {
-        &self.report_shares
+    /// Gets the report preparation initialization messages associated with this aggregate
+    /// initialization request.
+    pub fn report_inits(&self) -> &[ReportPrepInit] {
+        &self.report_inits
     }
 }
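Both `Continued` and `Finished` now carry a prepare message, so neither variant is a bare discriminator any more. A decoding sketch (not part of the diff; discriminator values as encoded above, arbitrary payloads):

    use janus_messages::PrepareStepResult;
    use prio::codec::{Decode, Encode};

    fn prepare_step_result_sketch() {
        let result = PrepareStepResult::Continued {
            prep_msg: Vec::from("012345"),
            prep_share: Vec::from("543210"),
        };
        let bytes = result.get_encoded();
        assert_eq!(bytes[0], 0x00); // Continued discriminator
        match PrepareStepResult::get_decoded(&bytes).unwrap() {
            PrepareStepResult::Continued { prep_msg, prep_share } => {
                assert_eq!(prep_msg, b"012345");
                assert_eq!(prep_share, b"543210");
            }
            _ => panic!("wrong variant"),
        }
    }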
@@ -2245,15 +2345,15 @@ impl<Q: QueryType> Encode for AggregationJobInitializeReq<Q> {
     fn encode(&self, bytes: &mut Vec<u8>) {
         encode_u32_items(bytes, &(), &self.aggregation_parameter);
         self.partial_batch_selector.encode(bytes);
-        encode_u32_items(bytes, &(), &self.report_shares);
+        encode_u32_items(bytes, &(), &self.report_inits);
     }
 
     fn encoded_len(&self) -> Option<usize> {
         let mut length = 4 + self.aggregation_parameter.len();
         length += self.partial_batch_selector.encoded_len()?;
         length += 4;
-        for report_share in self.report_shares.iter() {
-            length += report_share.encoded_len()?;
+        for report_init in &self.report_inits {
+            length += report_init.encoded_len()?;
         }
         Some(length)
     }
@@ -2263,12 +2363,12 @@ impl<Q: QueryType> Decode for AggregationJobInitializeReq<Q> {
     fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
         let aggregation_parameter = decode_u32_items(&(), bytes)?;
         let partial_batch_selector = PartialBatchSelector::decode(bytes)?;
-        let report_shares = decode_u32_items(&(), bytes)?;
+        let report_inits = decode_u32_items(&(), bytes)?;
 
         Ok(Self {
             aggregation_parameter,
             partial_batch_selector,
-            report_shares,
+            report_inits,
         })
     }
 }
@@ -2637,8 +2737,8 @@ mod tests {
     Extension, ExtensionType, FixedSize, FixedSizeQuery, HpkeAeadId, HpkeCiphertext, HpkeConfig,
     HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, InputShareAad, Interval,
     PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, Query, Report,
-    ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId,
-    Time, TimeInterval,
+    ReportId, ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError,
+    Role, TaskId, Time, TimeInterval,
 };
 use assert_matches::assert_matches;
 use prio::codec::{CodecError, Decode, Encode};
"303132333435" // opaque data - ), - concat!( - // payload - "00000006", // length - "353433323130", // opaque data - ), + // encapsulated_context + "0006", // length + "303132333435" // opaque data ), concat!( - "0D", // config_id - concat!( - // encapsulated_context - "0004", // length - "61626365", // opaque data - ), - concat!( - // payload - "00000004", // length - "61626664", // opaque data - ), + // payload + "00000006", // length + "353433323130", // opaque data + ), + ), + concat!( + // helper_encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data ), ), ), @@ -3467,7 +3596,16 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 0, interval, - encrypted_aggregate_shares: Vec::new(), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3481,8 +3619,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "00000000", // length + // leader_encrypted_agg_share + "0A", // config_id + concat!( + // encapsulated_context + "0004", // length + "30313233", // opaque data + ), + concat!( + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), @@ -3491,18 +3653,16 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 23, interval, - encrypted_aggregate_shares: Vec::from([ - HpkeCiphertext::new( - HpkeConfigId::from(10), - Vec::from("0123"), - Vec::from("4567"), - ), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("01234"), - Vec::from("567"), - ), - ]), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3516,34 +3676,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "0000001E", // length + // leader_encrypted_agg_share + "0A", // config_id concat!( - "0A", // config_id - concat!( - // encapsulated_context - "0004", // length - "30313233", // opaque data - ), - concat!( - // payload - "00000004", // length - "34353637", // opaque data - ), + // encapsulated_context + "0004", // length + "30313233", // opaque data ), concat!( - "0C", // config_id - concat!( - // encapsulated_context - "0005", // length - "3031323334", // opaque data - ), - concat!( - // payload - "00000003", // length - "353637", // opaque data - ), - ) + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), @@ -3558,7 +3716,16 @@ mod tests { )), report_count: 0, interval, - encrypted_aggregate_shares: Vec::new(), + leader_encrypted_agg_share: 
@@ -3467,7 +3596,16 @@ mod tests {
                partial_batch_selector: PartialBatchSelector::new_time_interval(),
                report_count: 0,
                interval,
-                encrypted_aggregate_shares: Vec::new(),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3481,8 +3619,32 @@ mod tests {
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "00000000", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
        (
@@ -3491,18 +3653,16 @@ mod tests {
                partial_batch_selector: PartialBatchSelector::new_time_interval(),
                report_count: 23,
                interval,
-                encrypted_aggregate_shares: Vec::from([
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(10),
-                        Vec::from("0123"),
-                        Vec::from("4567"),
-                    ),
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(12),
-                        Vec::from("01234"),
-                        Vec::from("567"),
-                    ),
-                ]),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3516,34 +3676,32 @@ mod tests {
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "0000001E", // length
-                    concat!(
-                        "0A", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0004", // length
-                            "30313233", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000004", // length
-                            "34353637", // opaque data
-                        ),
-                    ),
-                    concat!(
-                        "0C", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0005", // length
-                            "3031323334", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000003", // length
-                            "353637", // opaque data
-                        ),
-                    )
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
        (
@@ -3558,7 +3716,16 @@ mod tests {
                )),
                report_count: 0,
                interval,
-                encrypted_aggregate_shares: Vec::new(),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3573,8 +3740,32 @@ mod tests {
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "00000000", // length
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
        (
@@ -3585,18 +3776,16 @@ mod tests {
                )),
                report_count: 23,
                interval,
-                encrypted_aggregate_shares: Vec::from([
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(10),
-                        Vec::from("0123"),
-                        Vec::from("4567"),
-                    ),
-                    HpkeCiphertext::new(
-                        HpkeConfigId::from(12),
-                        Vec::from("01234"),
-                        Vec::from("567"),
-                    ),
-                ]),
+                leader_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(10),
+                    Vec::from("0123"),
+                    Vec::from("4567"),
+                ),
+                helper_encrypted_agg_share: HpkeCiphertext::new(
+                    HpkeConfigId::from(12),
+                    Vec::from("01234"),
+                    Vec::from("567"),
+                ),
            },
            concat!(
                concat!(
@@ -3611,34 +3800,32 @@ mod tests {
                    "0000000000003039", // duration
                ),
                concat!(
-                    // encrypted_aggregate_shares
-                    "0000001E", // length
-                    concat!(
-                        "0A", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0004", // length
-                            "30313233", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000004", // length
-                            "34353637", // opaque data
-                        ),
-                    ),
-                    concat!(
-                        "0C", // config_id
-                        concat!(
-                            // encapsulated_context
-                            "0005", // length
-                            "3031323334", // opaque data
-                        ),
-                        concat!(
-                            // payload
-                            "00000003", // length
-                            "353637", // opaque data
-                        ),
-                    )
+                    // leader_encrypted_agg_share
+                    "0A", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000004", // length
+                        "34353637", // opaque data
+                    ),
+                ),
+                concat!(
+                    // helper_encrypted_agg_share
+                    "0C", // config_id
+                    concat!(
+                        // encapsulated_context
+                        "0005", // length
+                        "3031323334", // opaque data
+                    ),
+                    concat!(
+                        // payload
+                        "00000003", // length
+                        "353637", // opaque data
+                    ),
                )
            ),
        ),
@@ -3655,22 +3842,224 @@ mod tests {
    }
 
    #[test]
-    fn roundtrip_prepare_step() {
+    fn roundtrip_report_share() {
        roundtrip_encoding(&[
            (
-                PrepareStep {
-                    report_id: ReportId::from([
-                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-                    ]),
-                    result: PrepareStepResult::Continued(Vec::from("012345")),
+                ReportShare {
+                    metadata: ReportMetadata::new(
+                        ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
+                        Time::from_seconds_since_epoch(54321),
+                    ),
+                    public_share: Vec::new(),
+                    encrypted_input_share: HpkeCiphertext::new(
+                        HpkeConfigId::from(42),
+                        Vec::from("012345"),
+                        Vec::from("543210"),
+                    ),
                },
                concat!(
-                    "0102030405060708090A0B0C0D0E0F10", // report_id
-                    "00", // prepare_step_result
-                    concat!(
-                        // vdaf_msg
-                        "00000006", // length
-                        "303132333435", // opaque data
+                    concat!(
+                        // metadata
+                        "0102030405060708090A0B0C0D0E0F10", // report_id
+                        "000000000000D431", // time
+                    ),
+                    concat!(
+                        // public_share
+                        "00000000", // length
+                        "", // opaque data
+                    ),
+                    concat!(
+                        // encrypted_input_share
+                        "2A", // config_id
+                        concat!(
+                            // encapsulated_context
+                            "0006", // length
+                            "303132333435", // opaque data
+                        ),
+                        concat!(
+                            // payload
+                            "00000006", // length
+                            "353433323130", // opaque data
+                        ),
                    ),
                ),
            ),
+            (
+                ReportShare {
+                    metadata: ReportMetadata::new(
+                        ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
+                        Time::from_seconds_since_epoch(73542),
+                    ),
+                    public_share: Vec::from("0123"),
+                    encrypted_input_share: HpkeCiphertext::new(
+                        HpkeConfigId::from(13),
+                        Vec::from("abce"),
+                        Vec::from("abfd"),
+                    ),
+                },
+                concat!(
+                    concat!(
+                        // metadata
+                        "100F0E0D0C0B0A090807060504030201", // report_id
+                        "0000000000011F46", // time
+                    ),
+                    concat!(
+                        // public_share
+                        "00000004", // length
+                        "30313233", // opaque data
+                    ),
+                    concat!(
+                        // encrypted_input_share
+                        "0D", // config_id
+                        concat!(
+                            // encapsulated_context
+                            "0004", // length
+                            "61626365", // opaque data
+                        ),
+                        concat!(
+                            // payload
+                            "00000004", // length
+                            "61626664", // opaque data
+                        ),
+                    ),
+                ),
+            ),
+        ])
+    }
+
+    #[test]
+    fn roundtrip_report_prep_init() {
+        roundtrip_encoding(&[
+            (
+                ReportPrepInit {
+                    report_share: ReportShare {
+                        metadata: ReportMetadata::new(
+                            ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
+                            Time::from_seconds_since_epoch(54321),
+                        ),
+                        public_share: Vec::new(),
+                        encrypted_input_share: HpkeCiphertext::new(
+                            HpkeConfigId::from(42),
+                            Vec::from("012345"),
+                            Vec::from("543210"),
+                        ),
+                    },
+                    leader_prep_share: Vec::from("012345"),
+                },
+                concat!(
+                    concat!(
+                        // report_share
+                        concat!(
+                            // metadata
+                            "0102030405060708090A0B0C0D0E0F10", // report_id
+                            "000000000000D431", // time
+                        ),
+                        concat!(
+                            // public_share
+                            "00000000", // length
+                            "", // opaque data
+                        ),
+                        concat!(
+                            // encrypted_input_share
+                            "2A", // config_id
+                            concat!(
+                                // encapsulated_context
+                                "0006", // length
+                                "303132333435", // opaque data
+                            ),
+                            concat!(
+                                // payload
+                                "00000006", // length
+                                "353433323130", // opaque data
+                            ),
+                        ),
+                    ),
+                    concat!(
+                        // leader_prep_share
+                        "00000006", // length
+                        "303132333435", // opaque data
+                    )
+                ),
+            ),
+            (
+                ReportPrepInit {
+                    report_share: ReportShare {
+                        metadata: ReportMetadata::new(
+                            ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
+                            Time::from_seconds_since_epoch(73542),
+                        ),
+                        public_share: Vec::from("0123"),
+                        encrypted_input_share: HpkeCiphertext::new(
+                            HpkeConfigId::from(13),
+                            Vec::from("abce"),
+                            Vec::from("abfd"),
+                        ),
+                    },
+                    leader_prep_share: Vec::new(),
+                },
+                concat!(
+                    concat!(
+                        // report_share
+                        concat!(
+                            // metadata
+                            "100F0E0D0C0B0A090807060504030201", // report_id
+                            "0000000000011F46", // time
+                        ),
+                        concat!(
+                            // public_share
+                            "00000004", // length
+                            "30313233", // opaque data
+                        ),
+                        concat!(
+                            // encrypted_input_share
+                            "0D", // config_id
+                            concat!(
+                                // encapsulated_context
+                                "0004", // length
+                                "61626365", // opaque data
+                            ),
+                            concat!(
+                                // payload
+                                "00000004", // length
+                                "61626664", // opaque data
+                            ),
+                        ),
+                    ),
+                    concat!(
+                        // leader_prep_share
+                        "00000000", // length
+                        "" // opaque data
+                    )
+                ),
+            ),
+        ])
+    }
+
+    #[test]
+    fn roundtrip_prepare_step() {
+        roundtrip_encoding(&[
+            (
+                PrepareStep {
+                    report_id: ReportId::from([
+                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                    ]),
+                    result: PrepareStepResult::Continued {
+                        prep_msg: Vec::from("012345"),
+                        prep_share: Vec::from("543210"),
+                    },
+                },
+                concat!(
+                    "0102030405060708090A0B0C0D0E0F10", // report_id
+                    "00", // prepare_step_result
+                    concat!(
+                        // prep_msg
+                        "00000006", // length
+                        "303132333435", // opaque data
+                    ),
+                    concat!(
+                        // prep_share
+                        "00000006", // length
+                        "353433323130", // opaque data
+                    ),
+                ),
+            ),
@@ -3679,11 +4068,18 @@ mod tests {
                    report_id: ReportId::from([
                        16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
                    ]),
-                    result: PrepareStepResult::Finished,
+                    result: PrepareStepResult::Finished {
+                        prep_msg: Vec::from("012345"),
+                    },
                },
                concat!(
                    "100F0E0D0C0B0A090807060504030201", // report_id
                    "01", // prepare_step_result
+                    concat!(
+                        // prep_msg
+                        "00000006", // length
+                        "303132333435", // opaque data
+                    ),
                ),
            ),
            (
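`Finished` is no longer a bare discriminator byte: `0x01` is now followed by a `u32`-length-prefixed `prep_msg`, which is exactly what the `"01"` + length + payload expectation above spells out. A small encoding sketch (assumed check, values from the vector above):

    use janus_messages::PrepareStepResult;
    use prio::codec::Encode;

    fn finished_encoding_sketch() {
        let result = PrepareStepResult::Finished { prep_msg: Vec::from("012345") };
        let bytes = result.get_encoded();
        assert_eq!(bytes[0], 0x01); // Finished discriminator
        assert_eq!(&bytes[1..5], &6u32.to_be_bytes()); // prep_msg length prefix
        assert_eq!(bytes.len(), result.encoded_len().unwrap());
    }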
"100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time ), concat!( - // payload + // public_share "00000004", // length - "61626664", // opaque data + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000000", // length + "" // opaque data + ) ), ), ), @@ -3821,30 +4243,40 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_fixed_size(BatchId::from( [2u8; 32], )), - report_shares: Vec::from([ - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - public_share: Vec::new(), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + report_inits: Vec::from([ + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + leader_prep_share: Vec::from("012345"), }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + leader_prep_share: Vec::new(), }, ]), }, @@ -3860,59 +4292,74 @@ mod tests { "0202020202020202020202020202020202020202020202020202020202020202", // batch_id ), concat!( - // report_shares - "0000005E", // length + // report_inits + "0000006C", // length concat!( concat!( - // metadata - "0102030405060708090A0B0C0D0E0F10", // report_id - "000000000000D431", // time - ), - concat!( - // public_share - "00000000", // length - "", // opaque data - ), - concat!( - // encrypted_input_share - "2A", // config_id + // report_share concat!( - // encapsulated_context - "0006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time ), concat!( - // payload - "00000006", // length - "353433323130", // opaque data + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000006", // length + "303132333435", // opaque data + ) ), concat!( concat!( - // metadata - "100F0E0D0C0B0A090807060504030201", // report_id - "0000000000011F46", // time - ), - concat!( - // public_share - "00000004", // length - "30313233", // opaque data - 
                    concat!(
                        concat!(
-                            // metadata
-                            "100F0E0D0C0B0A090807060504030201", // report_id
-                            "0000000000011F46", // time
-                        ),
-                        concat!(
-                            // public_share
-                            "00000004", // length
-                            "30313233", // opaque data
-                        ),
-                        concat!(
-                            // encrypted_input_share
-                            "0D", // config_id
                            concat!(
-                                // encapsulated_context
-                                "0004", // length
-                                "61626365", // opaque data
+                                // metadata
+                                "100F0E0D0C0B0A090807060504030201", // report_id
+                                "0000000000011F46", // time
                            ),
                            concat!(
-                                // payload
+                                // public_share
                                "00000004", // length
-                                "61626664", // opaque data
+                                "30313233", // opaque data
+                            ),
+                            concat!(
+                                // encrypted_input_share
+                                "0D", // config_id
+                                concat!(
+                                    // encapsulated_context
+                                    "0004", // length
+                                    "61626365", // opaque data
+                                ),
+                                concat!(
+                                    // payload
+                                    "00000004", // length
+                                    "61626664", // opaque data
+                                ),
                            ),
                        ),
+                        concat!(
+                            // leader_prep_share
+                            "00000000", // length
+                            "" // opaque data
+                        )
                    ),
                ),
            ),
@@ -3929,13 +4376,18 @@ mod tests {
                        report_id: ReportId::from([
                            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                        ]),
-                        result: PrepareStepResult::Continued(Vec::from("012345")),
+                        result: PrepareStepResult::Continued {
+                            prep_msg: Vec::from("012345"),
+                            prep_share: Vec::from("543210"),
+                        },
                    },
                    PrepareStep {
                        report_id: ReportId::from([
                            16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
                        ]),
-                        result: PrepareStepResult::Finished,
+                        result: PrepareStepResult::Finished {
+                            prep_msg: Vec::from("012345"),
+                        },
                    },
                ]),
            },
@@ -3943,19 +4395,29 @@ mod tests {
                "A5A5", // round
                concat!(
                    // prepare_steps
-                    "0000002C", // length
+                    "00000040", // length
                    concat!(
                        "0102030405060708090A0B0C0D0E0F10", // report_id
                        "00", // prepare_step_result
                        concat!(
-                            // payload
+                            // prep_msg
                            "00000006", // length
                            "303132333435", // opaque data
                        ),
+                        concat!(
+                            // prep_share
+                            "00000006", // length
+                            "353433323130", // opaque data
+                        ),
                    ),
                    concat!(
                        "100F0E0D0C0B0A090807060504030201", // report_id
                        "01", // prepare_step_result
+                        concat!(
+                            // prep_msg
+                            "00000006", // length
+                            "303132333435", // opaque data
+                        ),
                    )
                ),
            ),
@@ -3971,31 +4433,46 @@ mod tests {
                        report_id: ReportId::from([
                            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                        ]),
-                        result: PrepareStepResult::Continued(Vec::from("012345")),
+                        result: PrepareStepResult::Continued {
+                            prep_msg: Vec::from("012345"),
+                            prep_share: Vec::from("543210"),
+                        },
                    },
                    PrepareStep {
                        report_id: ReportId::from([
                            16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
                        ]),
-                        result: PrepareStepResult::Finished,
+                        result: PrepareStepResult::Finished {
+                            prep_msg: Vec::from("012345"),
+                        },
                    },
                ]),
            },
            concat!(concat!(
                // prepare_steps
-                "0000002C", // length
+                "00000040", // length
                concat!(
                    "0102030405060708090A0B0C0D0E0F10", // report_id
                    "00", // prepare_step_result
                    concat!(
-                        // payload
+                        // prep_msg
                        "00000006", // length
                        "303132333435", // opaque data
                    ),
+                    concat!(
+                        // prep_share
+                        "00000006", // length
+                        "353433323130", // opaque data
+                    ),
                ),
                concat!(
                    "100F0E0D0C0B0A090807060504030201", // report_id
                    "01", // prepare_step_result
+                    concat!(
+                        // prep_msg
+                        "00000006", // length
+                        "303132333435", // opaque data
+                    ),
                )
            ),),
        )])
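For reference, the `roundtrip_encoding` helper that drives all of these vectors is not shown in this diff; it presumably works along these lines (an assumed sketch, not the actual helper):

    use prio::codec::{Decode, Encode};
    use std::fmt::Debug;

    // For each (value, hex) pair: encoding must match the vector byte-for-byte,
    // decoding must round-trip, and encoded_len must agree with the byte count.
    fn roundtrip_encoding<T: Encode + Decode + Debug + PartialEq>(vals_and_hex: &[(T, &str)]) {
        for (val, hex_encoding) in vals_and_hex {
            let bytes = hex::decode(hex_encoding).unwrap();
            assert_eq!(bytes, val.get_encoded());
            assert_eq!(val, &T::get_decoded(&bytes).unwrap());
            assert_eq!(val.encoded_len().unwrap(), bytes.len());
        }
    }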