From e3106f61c17afa9c8b4ee220928def1b5ede1a14 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 29 Mar 2023 14:18:02 -0700 Subject: [PATCH 1/8] Update Report message definition. --- Cargo.lock | 1 + Cargo.toml | 1 + aggregator/Cargo.toml | 2 +- aggregator/src/aggregator.rs | 121 ++++++++--------------------- client/Cargo.toml | 1 + client/src/lib.rs | 17 ++-- integration_tests/Cargo.toml | 2 +- messages/src/lib.rs | 146 ++++++++++++++++++++++------------- 8 files changed, 141 insertions(+), 150 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3445c420c..f52074e93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1795,6 +1795,7 @@ dependencies = [ "derivative", "http", "http-api-problem", + "itertools", "janus_core", "janus_messages", "mockito", diff --git a/Cargo.toml b/Cargo.toml index 0f260c09f..c364b20b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ version = "0.4.0" # (yet) need other default features. # https://docs.rs/chrono/latest/chrono/#duration chrono = { version = "0.4", default-features = false } +itertools = "0.10" janus_aggregator = { version = "0.4", path = "aggregator" } janus_aggregator_core = { version = "0.4", path = "aggregator_core" } janus_build_script_utils = { version = "0.4", path = "build_script_utils" } diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index 20a30716c..8a4be731c 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -96,7 +96,7 @@ warp = { version = "0.3", features = ["tls"] } [dev-dependencies] assert_matches = "1" hyper = "0.14.25" -itertools = "0.10.5" +itertools.workspace = true # Enable `kube`'s `openssl-tls` feature (which takes precedence over the # `rustls-tls` feature when creating a default client) to work around rustls's # lack of support for connecting to servers by IP addresses, which affects many diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 9b8e92a63..827c6cb30 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -1171,24 +1171,16 @@ impl VdafOps { C: Clock, Q: UploadableQueryType, { - // The leader's report is the first one. - // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2 - if report.encrypted_input_shares().len() != 2 { - return Err(Arc::new(Error::UnrecognizedMessage( - Some(*task.id()), - "unexpected number of encrypted shares in report", - ))); - } - let leader_encrypted_input_share = - &report.encrypted_input_shares()[Role::Leader.index().unwrap()]; - // Verify that the report's HPKE config ID is known. 
// https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2 let hpke_keypair = task .hpke_keys() - .get(leader_encrypted_input_share.config_id()) + .get(report.leader_encrypted_input_share().config_id()) .ok_or_else(|| { - Error::OutdatedHpkeConfig(*task.id(), *leader_encrypted_input_share.config_id()) + Error::OutdatedHpkeConfig( + *task.id(), + *report.leader_encrypted_input_share().config_id(), + ) })?; let report_deadline = clock @@ -1256,7 +1248,7 @@ impl VdafOps { hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()), - leader_encrypted_input_share, + report.leader_encrypted_input_share(), &InputShareAad::new( *task.id(), report.metadata().clone(), @@ -1298,16 +1290,13 @@ impl VdafOps { } }; - let helper_encrypted_input_share = - &report.encrypted_input_shares()[Role::Helper.index().unwrap()]; - let report = LeaderStoredReport::new( *task.id(), report.metadata().clone(), public_share, Vec::from(leader_plaintext_input_share.extensions()), leader_input_share, - helper_encrypted_input_share.clone(), + report.helper_encrypted_input_share().clone(), ); report_writer @@ -3516,7 +3505,8 @@ mod tests { Report::new( report_metadata, public_share.get_encoded(), - Vec::from([leader_ciphertext, helper_ciphertext]), + leader_ciphertext, + helper_ciphertext, ) } @@ -3638,7 +3628,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut response = drive_filter( Method::PUT, @@ -3657,29 +3648,6 @@ mod tests { ) .await; - // should reject a report with only one share with the unrecognizedMessage type. - let bad_report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - let mut response = drive_filter( - Method::PUT, - task.report_upload_uri().unwrap().path(), - &bad_report.get_encoded(), - &filter, - ) - .await - .unwrap(); - check_response( - &mut response, - StatusCode::BAD_REQUEST, - "unrecognizedMessage", - "The message type for a response was incorrect or the payload was malformed.", - task.id(), - ) - .await; - // should reject a report using the wrong HPKE config for the leader, and reply with // the error type outdatedConfig. let unused_hpke_config_id = (0..) 
@@ -3689,16 +3657,15 @@ mod tests { let bad_report = Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); let mut response = drive_filter( Method::PUT, @@ -3727,7 +3694,8 @@ mod tests { let bad_report = Report::new( ReportMetadata::new(*report.metadata().id(), bad_report_time), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ); let mut response = drive_filter( Method::PUT, @@ -3816,7 +3784,8 @@ mod tests { .unwrap(), ), report.public_share().to_vec(), - report.encrypted_input_shares().to_vec(), + report.leader_encrypted_input_share().clone(), + report.helper_encrypted_input_share().clone(), ) .get_encoded(), ) @@ -4023,29 +3992,6 @@ mod tests { assert_eq!(want_report_ids, got_report_ids); } - #[tokio::test] - async fn upload_wrong_number_of_encrypted_shares() { - install_test_trace_subscriber(); - - let (_, aggregator, clock, task, _, _ephemeral_datastore) = - setup_upload_test(default_aggregator_config()).await; - let report = create_report(&task, clock.now()); - let report = Report::new( - report.metadata().clone(), - report.public_share().to_vec(), - Vec::from([report.encrypted_input_shares()[0].clone()]), - ); - - assert_matches!( - aggregator - .handle_upload(task.id(), &report.get_encoded()) - .await - .unwrap_err() - .as_ref(), - Error::UnrecognizedMessage(_, _) - ); - } - #[tokio::test] async fn upload_wrong_hpke_config_id() { install_test_trace_subscriber(); @@ -4062,16 +4008,15 @@ mod tests { let report = Report::new( report.metadata().clone(), report.public_share().to_vec(), - Vec::from([ - HpkeCiphertext::new( - unused_hpke_config_id, - report.encrypted_input_shares()[0] - .encapsulated_key() - .to_vec(), - report.encrypted_input_shares()[0].payload().to_vec(), - ), - report.encrypted_input_shares()[1].clone(), - ]), + HpkeCiphertext::new( + unused_hpke_config_id, + report + .leader_encrypted_input_share() + .encapsulated_key() + .to_vec(), + report.leader_encrypted_input_share().payload().to_vec(), + ), + report.helper_encrypted_input_share().clone(), ); assert_matches!(aggregator.handle_upload(task.id(), &report.get_encoded()).await.unwrap_err().as_ref(), Error::OutdatedHpkeConfig(task_id, config_id) => { diff --git a/client/Cargo.toml b/client/Cargo.toml index 713a25c3d..0b3ff3f26 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,6 +14,7 @@ backoff = { version = "0.4.0", features = ["tokio"] } derivative = "2.2.0" http = "0.2.9" http-api-problem = "0.56.0" +itertools.workspace = true janus_core.workspace = true janus_messages.workspace = true prio.workspace = true diff --git a/client/src/lib.rs b/client/src/lib.rs index 20d820c2e..7d0f5b773 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -4,6 +4,7 @@ use backoff::ExponentialBackoff; use derivative::Derivative; use http::header::CONTENT_TYPE; use http_api_problem::HttpApiProblem; +use itertools::Itertools; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, 
http::response_to_problem_details, @@ -12,8 +13,8 @@ use janus_core::{ time::{Clock, TimeExt}, }; use janus_messages::{ - Duration, HpkeCiphertext, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, - Report, ReportId, ReportMetadata, Role, TaskId, + Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, + ReportMetadata, Role, TaskId, }; use prio::{ codec::{Decode, Encode}, @@ -228,14 +229,14 @@ impl, C: Clock> Client { let report_metadata = ReportMetadata::new(report_id, time); let encoded_public_share = public_share.get_encoded(); - let encrypted_input_shares: Vec = [ + let (leader_encrypted_input_share, helper_encrypted_input_share) = [ (&self.leader_hpke_config, &Role::Leader), (&self.helper_hpke_config, &Role::Helper), ] .into_iter() .zip(input_shares) .map(|((hpke_config, receiver_role), input_share)| { - Ok(hpke::seal( + hpke::seal( hpke_config, &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role), &PlaintextInputShare::new( @@ -249,14 +250,16 @@ impl, C: Clock> Client { encoded_public_share.clone(), ) .get_encoded(), - )?) + ) }) - .collect::>()?; + .collect_tuple() + .expect("iterator to yield two items"); // expect safety: iterator contains two items. Ok(Report::new( report_metadata, encoded_public_share, - encrypted_input_shares, + leader_encrypted_input_share?, + helper_encrypted_input_share?, )) } diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index a604b3aa3..4e5f8cd86 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -37,6 +37,6 @@ url = { version = "2.3.1", features = ["serde"] } [dev-dependencies] http = "0.2" -itertools = "0.10" +itertools.workspace = true janus_collector = { workspace = true, features = ["test-util"] } tempfile = "3" diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 989b501af..58a7b0bfb 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -427,6 +427,7 @@ impl Role { /// If this [`Role`] is one of the aggregators, returns the index at which /// that aggregator's message or data can be found in various lists, or /// `None` if the role is not an aggregator. + // XXX: can this be removed once all messages are updated to have explicit leader/helper fields? pub fn index(&self) -> Option { match self { // draft-gpew-priv-ppm §4.2: the leader's endpoint MUST be the first @@ -1222,7 +1223,8 @@ impl Decode for PlaintextInputShare { pub struct Report { metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, } impl Report { @@ -1233,12 +1235,14 @@ impl Report { pub fn new( metadata: ReportMetadata, public_share: Vec, - encrypted_input_shares: Vec, + leader_encrypted_input_share: HpkeCiphertext, + helper_encrypted_input_share: HpkeCiphertext, ) -> Self { Self { metadata, public_share, - encrypted_input_shares, + leader_encrypted_input_share, + helper_encrypted_input_share, } } @@ -1247,13 +1251,19 @@ impl Report { &self.metadata } + /// Retrieve the public share from this report. pub fn public_share(&self) -> &[u8] { &self.public_share } - /// Get this report's encrypted input shares. - pub fn encrypted_input_shares(&self) -> &[HpkeCiphertext] { - &self.encrypted_input_shares + /// Retrieve the encrypted leader input share from this report. 
+ pub fn leader_encrypted_input_share(&self) -> &HpkeCiphertext { + &self.leader_encrypted_input_share + } + + /// Retrieve the encrypted helper input share from this report. + pub fn helper_encrypted_input_share(&self) -> &HpkeCiphertext { + &self.helper_encrypted_input_share } } @@ -1261,17 +1271,16 @@ impl Encode for Report { fn encode(&self, bytes: &mut Vec) { self.metadata.encode(bytes); encode_u32_items(bytes, &(), &self.public_share); - encode_u32_items(bytes, &(), &self.encrypted_input_shares); + self.leader_encrypted_input_share.encode(bytes); + self.helper_encrypted_input_share.encode(bytes); } fn encoded_len(&self) -> Option { let mut length = self.metadata.encoded_len()?; length += 4; length += self.public_share.len(); - length += 4; - for encrypted_input_share in self.encrypted_input_shares.iter() { - length += encrypted_input_share.encoded_len()?; - } + length += self.leader_encrypted_input_share.encoded_len()?; + length += self.helper_encrypted_input_share.encoded_len()?; Some(length) } } @@ -1280,12 +1289,14 @@ impl Decode for Report { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let metadata = ReportMetadata::decode(bytes)?; let public_share = decode_u32_items(&(), bytes)?; - let encrypted_input_shares = decode_u32_items(&(), bytes)?; + let leader_encrypted_input_share = HpkeCiphertext::decode(bytes)?; + let helper_encrypted_input_share = HpkeCiphertext::decode(bytes)?; Ok(Self { metadata, public_share, - encrypted_input_shares, + leader_encrypted_input_share, + helper_encrypted_input_share, }) } } @@ -3135,7 +3146,16 @@ mod tests { Time::from_seconds_since_epoch(12345), ), Vec::new(), - Vec::new(), + HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), ), concat!( concat!( @@ -3148,9 +3168,33 @@ mod tests { "00000000", // length ), concat!( - // encrypted_input_shares - "00000000", // length - ) + // leader_encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435" // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), + ), + concat!( + // helper_encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), ), ), ( @@ -3160,18 +3204,16 @@ mod tests { Time::from_seconds_since_epoch(54321), ), Vec::from("3210"), - Vec::from([ - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), - ]), + HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), ), concat!( concat!( @@ -3185,33 +3227,31 @@ mod tests { "33323130", // opaque data ), concat!( - // encrypted_input_shares - "00000022", // length + // leader_encrypted_input_share + "2A", // config_id concat!( - "2A", // config_id - concat!( - // encapsulated_context - "0006", // length - "303132333435" // opaque data - ), - concat!( - // payload - "00000006", // length - "353433323130", // opaque data - ), + // encapsulated_context + "0006", // length + "303132333435" // opaque data ), concat!( - "0D", // config_id - concat!( - // encapsulated_context - 
"0004", // length - "61626365", // opaque data - ), - concat!( - // payload - "00000004", // length - "61626664", // opaque data - ), + // payload + "00000006", // length + "353433323130", // opaque data + ), + ), + concat!( + // helper_encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data ), ), ), From 258d2447cab5e9e26fc118e2eeca1a722ebaa5fe Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 29 Mar 2023 15:11:07 -0700 Subject: [PATCH 2/8] Update Collection message definition. --- aggregator/src/aggregator.rs | 19 +- .../src/aggregator/collection_job_tests.rs | 5 +- collector/src/lib.rs | 274 +++++++----------- messages/src/lib.rs | 271 ++++++++++------- 4 files changed, 283 insertions(+), 286 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 827c6cb30..3747e2ec0 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -44,10 +44,10 @@ use janus_messages::{ query_type::{FixedSize, TimeInterval}, AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, - BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeCiphertext, - HpkeConfigId, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, - PlaintextInputShare, PrepareStep, PrepareStepResult, Report, ReportId, ReportIdChecksum, - ReportShare, ReportShareError, Role, TaskId, Time, + BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeConfigId, + HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, + PrepareStep, PrepareStepResult, Report, ReportId, ReportIdChecksum, ReportShare, + ReportShareError, Role, TaskId, Time, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -2208,10 +2208,8 @@ impl VdafOps { ), *report_count, spanned_interval, - Vec::::from([ - encrypted_leader_aggregate_share, - encrypted_helper_aggregate_share.clone(), - ]), + encrypted_leader_aggregate_share, + encrypted_helper_aggregate_share.clone(), ) .get_encoded(), )) @@ -7057,13 +7055,12 @@ mod tests { ) .unwrap() ); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), @@ -7081,7 +7078,7 @@ mod tests { test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_time_interval(batch_interval), diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index de666c6ae..75ef2f32e 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -346,13 +346,12 @@ async fn collection_job_success_fixed_size() { 
.align_to_time_precision(test_case.task.time_precision()) .unwrap(), ); - assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[0], + collect_resp.leader_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), @@ -370,7 +369,7 @@ async fn collection_job_success_fixed_size() { test_case.task.collector_hpke_config(), test_case.collector_hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), - &collect_resp.encrypted_aggregate_shares()[1], + collect_resp.helper_encrypted_aggregate_share(), &AggregateShareAad::new( *test_case.task.id(), BatchSelector::new_fixed_size(batch_id), diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 6d8514816..6884bc328 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -113,8 +113,6 @@ pub enum Error { Codec(#[from] prio::codec::CodecError), #[error("aggregate share decoding error")] AggregateShareDecode, - #[error("expected two aggregate shares, got {0}")] - AggregateShareCount(usize), #[error("VDAF error: {0}")] Vdaf(#[from] prio::vdaf::VdafError), #[error("HPKE error: {0}")] @@ -512,41 +510,40 @@ impl Collector { } let collect_response = CollectionMessage::::get_decoded(&response.bytes().await?)?; - if collect_response.encrypted_aggregate_shares().len() != 2 { - return Err(Error::AggregateShareCount( - collect_response.encrypted_aggregate_shares().len(), - )); - } - let aggregate_shares_bytes = collect_response - .encrypted_aggregate_shares() - .iter() - .zip(&[Role::Leader, Role::Helper]) - .map(|(encrypted_aggregate_share, role)| { - hpke::open( - &self.parameters.hpke_config, - &self.parameters.hpke_private_key, - &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, role, &Role::Collector), - encrypted_aggregate_share, - &AggregateShareAad::new( - self.parameters.task_id, - BatchSelector::::new(Q::batch_identifier_for_collection( - &job.query, - &collect_response, - )), - ) - .get_encoded(), - ) - }); - let aggregate_shares = aggregate_shares_bytes - .map(|bytes| { - V::AggregateShare::get_decoded_with_param( - &(&self.vdaf_collector, &job.aggregation_parameter), - &bytes?, + let aggregate_shares = [ + ( + Role::Leader, + collect_response.leader_encrypted_aggregate_share(), + ), + ( + Role::Helper, + collect_response.helper_encrypted_aggregate_share(), + ), + ] + .into_iter() + .map(|(role, encrypted_aggregate_share)| { + let bytes = hpke::open( + &self.parameters.hpke_config, + &self.parameters.hpke_private_key, + &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, &role, &Role::Collector), + encrypted_aggregate_share, + &AggregateShareAad::new( + self.parameters.task_id, + BatchSelector::::new(Q::batch_identifier_for_collection( + &job.query, + &collect_response, + )), ) - .map_err(|_err| Error::AggregateShareDecode) - }) - .collect::, Error>>()?; + .get_encoded(), + )?; + V::AggregateShare::get_decoded_with_param( + &(&self.vdaf_collector, &job.aggregation_parameter), + &bytes, + ) + .map_err(|_err| Error::AggregateShareDecode) + }) + .collect::, Error>>()?; let report_count = collect_response .report_count() @@ -738,30 +735,20 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::::from([ - hpke::seal( - 
¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -776,30 +763,20 @@ mod tests { PartialBatchSelector::new_fixed_size(batch_id), 1, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - Vec::::from([ - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &transcript.aggregate_shares[0].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - ¶meters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &transcript.aggregate_shares[1].get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &transcript.aggregate_shares[0].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + ¶meters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &transcript.aggregate_shares[1].get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ) } @@ -1405,31 +1382,6 @@ mod tests { mock_collection_job_bad_message_bytes.assert_async().await; - let mock_collection_job_bad_share_count = server - .mock("POST", job.collection_job_url.path()) - .with_status(200) - .with_header( - CONTENT_TYPE.as_str(), - CollectionMessage::::MEDIA_TYPE, - ) - .with_body( - CollectionMessage::new( - PartialBatchSelector::new_time_interval(), - 0, - batch_interval, - Vec::new(), - ) - .get_encoded(), - ) - .expect_at_least(1) - .create_async() - .await; - - let error = collector.poll_once(&job).await.unwrap_err(); - assert_matches!(error, Error::AggregateShareCount(0)); - - mock_collection_job_bad_share_count.assert_async().await; - let mock_collection_job_bad_ciphertext = server .mock("POST", job.collection_job_url.path()) .with_status(200) @@ -1442,18 +1394,16 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - HpkeCiphertext::new( - *collector.parameters.hpke_config.id(), - Vec::new(), - Vec::new(), - ), - ]), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), + HpkeCiphertext::new( + *collector.parameters.hpke_config.id(), + Vec::new(), + Vec::new(), + ), ) .get_encoded(), ) @@ -1474,30 +1424,20 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - 
&Role::Leader, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - b"bad", - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + b"bad", + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_bad_shares = server .mock("POST", job.collection_job_url.path()) @@ -1520,35 +1460,25 @@ mod tests { PartialBatchSelector::new_time_interval(), 1, batch_interval, - Vec::from([ - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Leader, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) - .get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - hpke::seal( - &collector.parameters.hpke_config, - &HpkeApplicationInfo::new( - &Label::AggregateShare, - &Role::Helper, - &Role::Collector, - ), - &AggregateShare::from(OutputShare::from(Vec::from([ - Field64::from(0), - Field64::from(0), - ]))) + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([Field64::from(0)]))) .get_encoded(), - &associated_data.get_encoded(), - ) - .unwrap(), - ]), + &associated_data.get_encoded(), + ) + .unwrap(), + hpke::seal( + &collector.parameters.hpke_config, + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), + &AggregateShare::from(OutputShare::from(Vec::from([ + Field64::from(0), + Field64::from(0), + ]))) + .get_encoded(), + &associated_data.get_encoded(), + ) + .unwrap(), ); let mock_collection_job_wrong_length = server .mock("POST", job.collection_job_url.path()) diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 58a7b0bfb..240857a09 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -1600,7 +1600,8 @@ pub struct Collection { partial_batch_selector: PartialBatchSelector, report_count: u64, interval: Interval, - encrypted_aggregate_shares: Vec, + leader_encrypted_agg_share: HpkeCiphertext, + helper_encrypted_agg_share: HpkeCiphertext, } impl Collection { @@ -1612,34 +1613,41 @@ impl Collection { partial_batch_selector: PartialBatchSelector, report_count: u64, interval: Interval, - encrypted_aggregate_shares: Vec, + leader_encrypted_agg_share: HpkeCiphertext, + helper_encrypted_agg_share: HpkeCiphertext, ) -> Self { Self { partial_batch_selector, report_count, interval, - encrypted_aggregate_shares, + leader_encrypted_agg_share, + helper_encrypted_agg_share, } } - /// Gets the batch selector associated with this collection. + /// Retrieves the batch selector associated with this collection. pub fn partial_batch_selector(&self) -> &PartialBatchSelector { &self.partial_batch_selector } - /// Gets the number of reports that were aggregated into this collection. + /// Retrieves the number of reports that were aggregated into this collection. pub fn report_count(&self) -> u64 { self.report_count } - /// Gets the interval spanned by the reports aggregated into this collection. 
+ /// Retrieves the interval spanned by the reports aggregated into this collection. pub fn interval(&self) -> &Interval { &self.interval } - /// Gets the encrypted aggregate shares associated with this collection. - pub fn encrypted_aggregate_shares(&self) -> &[HpkeCiphertext] { - &self.encrypted_aggregate_shares + /// Retrieves the leader encrypted aggregate share associated with this collection. + pub fn leader_encrypted_aggregate_share(&self) -> &HpkeCiphertext { + &self.leader_encrypted_agg_share + } + + /// Retrieves the helper encrypted aggregate share associated with this collection. + pub fn helper_encrypted_aggregate_share(&self) -> &HpkeCiphertext { + &self.helper_encrypted_agg_share } } @@ -1648,18 +1656,18 @@ impl Encode for Collection { self.partial_batch_selector.encode(bytes); self.report_count.encode(bytes); self.interval.encode(bytes); - encode_u32_items(bytes, &(), &self.encrypted_aggregate_shares); + self.leader_encrypted_agg_share.encode(bytes); + self.helper_encrypted_agg_share.encode(bytes); } fn encoded_len(&self) -> Option { - let mut length = self.partial_batch_selector.encoded_len()? - + self.report_count.encoded_len()? - + self.interval.encoded_len()?; - length += 4; - for encrypted_aggregate_share in self.encrypted_aggregate_shares.iter() { - length += encrypted_aggregate_share.encoded_len()?; - } - Some(length) + Some( + self.partial_batch_selector.encoded_len()? + + self.report_count.encoded_len()? + + self.interval.encoded_len()? + + self.leader_encrypted_agg_share.encoded_len()? + + self.helper_encrypted_agg_share.encoded_len()?, + ) } } @@ -1668,13 +1676,15 @@ impl Decode for Collection { let partial_batch_selector = PartialBatchSelector::decode(bytes)?; let report_count = u64::decode(bytes)?; let interval = Interval::decode(bytes)?; - let encrypted_aggregate_shares = decode_u32_items(&(), bytes)?; + let leader_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; + let helper_encrypted_agg_share = HpkeCiphertext::decode(bytes)?; Ok(Self { partial_batch_selector, report_count, interval, - encrypted_aggregate_shares, + leader_encrypted_agg_share, + helper_encrypted_agg_share, }) } } @@ -2664,6 +2674,9 @@ mod tests { for (val, hex_encoding) in vals_and_encodings { let mut encoded_val = Vec::new(); val.encode(&mut encoded_val); + if let Some(want_encoded_len) = val.encoded_len() { + assert_eq!(want_encoded_len, encoded_val.len()); + } let encoding = hex::decode(hex_encoding).unwrap(); assert_eq!( encoding, encoded_val, @@ -3508,7 +3521,16 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 0, interval, - encrypted_aggregate_shares: Vec::new(), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3522,8 +3544,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "00000000", // length + // leader_encrypted_agg_share + "0A", // config_id + concat!( + // encapsulated_context + "0004", // length + "30313233", // opaque data + ), + concat!( + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), @@ 
-3532,18 +3578,16 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_time_interval(), report_count: 23, interval, - encrypted_aggregate_shares: Vec::from([ - HpkeCiphertext::new( - HpkeConfigId::from(10), - Vec::from("0123"), - Vec::from("4567"), - ), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("01234"), - Vec::from("567"), - ), - ]), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3557,34 +3601,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "0000001E", // length + // leader_encrypted_agg_share + "0A", // config_id concat!( - "0A", // config_id - concat!( - // encapsulated_context - "0004", // length - "30313233", // opaque data - ), - concat!( - // payload - "00000004", // length - "34353637", // opaque data - ), + // encapsulated_context + "0004", // length + "30313233", // opaque data ), concat!( - "0C", // config_id - concat!( - // encapsulated_context - "0005", // length - "3031323334", // opaque data - ), - concat!( - // payload - "00000003", // length - "353637", // opaque data - ), - ) + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), @@ -3599,7 +3641,16 @@ mod tests { )), report_count: 0, interval, - encrypted_aggregate_shares: Vec::new(), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3614,8 +3665,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "00000000", // length + // leader_encrypted_agg_share + "0A", // config_id + concat!( + // encapsulated_context + "0004", // length + "30313233", // opaque data + ), + concat!( + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), @@ -3626,18 +3701,16 @@ mod tests { )), report_count: 23, interval, - encrypted_aggregate_shares: Vec::from([ - HpkeCiphertext::new( - HpkeConfigId::from(10), - Vec::from("0123"), - Vec::from("4567"), - ), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("01234"), - Vec::from("567"), - ), - ]), + leader_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(10), + Vec::from("0123"), + Vec::from("4567"), + ), + helper_encrypted_agg_share: HpkeCiphertext::new( + HpkeConfigId::from(12), + Vec::from("01234"), + Vec::from("567"), + ), }, concat!( concat!( @@ -3652,34 +3725,32 @@ mod tests { "0000000000003039", // duration ), concat!( - // encrypted_aggregate_shares - "0000001E", // length + // leader_encrypted_agg_share + "0A", // config_id concat!( - "0A", // config_id - concat!( - // encapsulated_context - "0004", // length - "30313233", // opaque data - ), - concat!( - // payload - "00000004", // 
length - "34353637", // opaque data - ), + // encapsulated_context + "0004", // length + "30313233", // opaque data ), concat!( - "0C", // config_id - concat!( - // encapsulated_context - "0005", // length - "3031323334", // opaque data - ), - concat!( - // payload - "00000003", // length - "353637", // opaque data - ), - ) + // payload + "00000004", // length + "34353637", // opaque data + ), + ), + concat!( + // helper_encrypted_agg_share + "0C", // config_id + concat!( + // encapsulated_context + "0005", // length + "3031323334", // opaque data + ), + concat!( + // payload + "00000003", // length + "353637", // opaque data + ), ) ), ), From 054d0e4f0cde69f420bfaa5ac787fc8421190d0a Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 29 Mar 2023 17:53:12 -0700 Subject: [PATCH 3/8] Update Task struct definition. --- aggregator/src/aggregator.rs | 4 +- .../src/aggregator/aggregation_job_driver.rs | 31 +-- .../src/aggregator/collection_job_driver.rs | 11 +- aggregator/src/bin/janus_cli.rs | 10 +- aggregator_core/src/datastore.rs | 44 ++--- aggregator_core/src/task.rs | 183 +++++++----------- client/src/lib.rs | 112 ++++++----- collector/src/lib.rs | 7 +- core/Cargo.toml | 1 + core/src/task.rs | 7 +- db/schema.sql | 27 +-- docs/samples/tasks.yaml | 12 +- integration_tests/src/client.rs | 31 +-- integration_tests/src/janus.rs | 6 +- integration_tests/tests/common/mod.rs | 69 ++++--- integration_tests/tests/janus.rs | 22 +-- .../src/bin/janus_interop_aggregator.rs | 3 +- .../src/bin/janus_interop_client.rs | 7 +- interop_binaries/src/lib.rs | 4 +- 19 files changed, 273 insertions(+), 318 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 3747e2ec0..06dab5023 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -55,12 +55,14 @@ use opentelemetry::{ }; #[cfg(feature = "fpvec_bounded_l2")] use prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded; +#[cfg(feature = "test-util")] +use prio::vdaf::PrepareTransition; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, vdaf::{ self, prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded}, - PrepareTransition, VdafError, + VdafError, }, }; use regex::Regex; diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index becb81e05..61bb16fc4 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -893,7 +893,6 @@ mod tests { }, }; use rand::random; - use reqwest::Url; use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration}; #[tokio::test] @@ -917,10 +916,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -1140,10 +1136,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock @@ -1433,10 +1426,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - 
Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( @@ -1663,10 +1653,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let time = clock .now() @@ -1953,10 +1940,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let report_metadata = ReportMetadata::new( random(), @@ -2409,10 +2393,7 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index ad5297d72..cb9c6da37 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -499,7 +499,6 @@ mod tests { use prio::codec::{Decode, Encode}; use rand::random; use std::{str, sync::Arc, time::Duration as StdDuration}; - use url::Url; async fn setup_collection_job_test_case( server: &mut mockito::Server, @@ -513,10 +512,7 @@ mod tests { ) { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); @@ -640,10 +636,7 @@ mod tests { let time_precision = Duration::from_seconds(500); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) - .with_aggregator_endpoints(Vec::from([ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&server.url()).unwrap(), - ])) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) .with_time_precision(time_precision) .with_min_batch_size(10) .build(); diff --git a/aggregator/src/bin/janus_cli.rs b/aggregator/src/bin/janus_cli.rs index 6f342122c..be30d0e9d 100644 --- a/aggregator/src/bin/janus_cli.rs +++ b/aggregator/src/bin/janus_cli.rs @@ -778,9 +778,8 @@ mod tests { // YAML contains no task ID, VDAF verify keys, aggregator auth tokens, collector auth tokens // or HPKE keys. 
let serialized_task_yaml = r#" -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 @@ -801,9 +800,8 @@ mod tests { aggregator_auth_tokens: [] collector_auth_tokens: [] hpke_keys: [] -- aggregator_endpoints: - - https://leader - - https://helper +- leader_aggregator_endpoint: https://leader + helper_aggregator_endpoint: https://helper query_type: TimeInterval vdaf: !Prio3Sum bits: 2 diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 4ea41253c..bdca3886f 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -52,7 +52,6 @@ use std::{ use tokio::try_join; use tokio_postgres::{error::SqlState, row::RowIndex, IsolationLevel, Row, Statement, ToStatement}; use tracing::error; -use url::Url; #[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] @@ -314,20 +313,15 @@ impl Transaction<'_, C> { /// Writes a task into the datastore. #[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { - let endpoints: Vec<_> = task - .aggregator_endpoints() - .iter() - .map(Url::as_str) - .collect(); - // Main task insert. let stmt = self .prepare_cached( "INSERT INTO tasks ( - task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)", + task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", ) .await?; self.execute( @@ -335,7 +329,8 @@ impl Transaction<'_, C> { &[ /* task_id */ &task.id().as_ref(), /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, - /* aggregator_endpoints */ &endpoints, + /* leader_aggregator_endpoint */ &task.leader_aggregator_endpoint().as_str(), + /* helper_aggregator_endpoint */ &task.helper_aggregator_endpoint().as_str(), /* query_type */ &Json(task.query_type()), /* vdaf */ &Json(task.vdaf()), /* max_batch_query_count */ @@ -557,9 +552,9 @@ impl Transaction<'_, C> { let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()]; let stmt = self .prepare_cached( - "SELECT aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint, + query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age, + min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config FROM tasks WHERE task_id = $1", ) .await?; @@ -629,9 +624,10 @@ impl Transaction<'_, C> { pub async fn get_tasks(&self) -> Result, Error> { let stmt = self .prepare_cached( - "SELECT task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, - max_batch_query_count, task_expiration, report_expiry_age, min_batch_size, - time_precision, tolerable_clock_skew, collector_hpke_config + "SELECT task_id, aggregator_role, leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, 
max_batch_query_count, + task_expiration, report_expiry_age, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config FROM tasks", ) .await?; @@ -766,11 +762,10 @@ impl Transaction<'_, C> { ) -> Result { // Scalar task parameters. let aggregator_role: AggregatorRole = row.get("aggregator_role"); - let endpoints = row - .get::<_, Vec>("aggregator_endpoints") - .into_iter() - .map(|endpoint| Ok(Url::parse(&endpoint)?)) - .collect::>()?; + let leader_aggregator_endpoint = + row.get::<_, String>("leader_aggregator_endpoint").parse()?; + let helper_aggregator_endpoint = + row.get::<_, String>("helper_aggregator_endpoint").parse()?; let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?; @@ -855,7 +850,8 @@ impl Transaction<'_, C> { Ok(Task::new( *task_id, - endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, query_type, vdaf, aggregator_role.as_role(), diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 2f99ddb09..a45ef3608 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -14,11 +14,7 @@ use janus_messages::{ }; use rand::{distributions::Standard, random, thread_rng, Rng}; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; -use std::{ - array::TryFromSliceError, - collections::HashMap, - fmt::{self, Formatter}, -}; +use std::{array::TryFromSliceError, collections::HashMap}; use url::Url; /// Errors that methods and functions in this module may return. @@ -77,10 +73,12 @@ impl TryFrom<&SecretBytes> for VerifyKey { pub struct Task { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The query type this task uses to generate batches. query_type: QueryType, /// The VDAF this task executes. @@ -122,7 +120,8 @@ impl Task { #[allow(clippy::too_many_arguments)] pub fn new>( task_id: TaskId, - mut aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -138,13 +137,6 @@ impl Task { collector_auth_tokens: Vec, hpke_keys: I, ) -> Result { - // Ensure provided aggregator endpoints end with a slash, as we will be joining additional - // path segments into these endpoints & the Url::join implementation is persnickety about - // the slash at the end of the path. - for url in &mut aggregator_endpoints { - url_ensure_trailing_slash(url); - } - // Compute hpke_configs mapping cfg.id -> (cfg, key). 
let hpke_keys: HashMap = hpke_keys .into_iter() @@ -153,7 +145,8 @@ impl Task { let task = Self { task_id, - aggregator_endpoints, + leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint), + helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint), query_type, vdaf, role, @@ -174,10 +167,6 @@ impl Task { } fn validate(&self) -> Result<(), Error> { - // DAP currently only supports configurations of exactly two aggregators. - if self.aggregator_endpoints.len() != 2 { - return Err(Error::InvalidParameter("aggregator_endpoints")); - } if !self.role.is_aggregator() { return Err(Error::InvalidParameter("role")); } @@ -203,9 +192,14 @@ impl Task { &self.task_id } - /// Retrieves the aggregator endpoints associated with this task in natural order. - pub fn aggregator_endpoints(&self) -> &[Url] { - &self.aggregator_endpoints + /// Retrieves the Leader's aggregator endpoint associated with this task. + pub fn leader_aggregator_endpoint(&self) -> &Url { + &self.leader_aggregator_endpoint + } + + /// Retrieves the Helper's aggregator endpoint associated with this task. + pub fn helper_aggregator_endpoint(&self) -> &Url { + &self.helper_aggregator_endpoint } /// Retrieves the query type associated with this task. @@ -303,12 +297,6 @@ impl Task { } } - /// Returns the [`Url`] relative to which the server performing `role` serves its API. - pub fn aggregator_url(&self, role: &Role) -> Result<&Url, Error> { - let index = role.index().ok_or(Error::InvalidParameter(role.as_str()))?; - Ok(&self.aggregator_endpoints[index]) - } - /// Returns the [`AuthenticationToken`] currently used by this aggregator to authenticate itself /// to other aggregators. pub fn primary_aggregator_auth_token(&self) -> &AuthenticationToken { @@ -363,14 +351,14 @@ impl Task { /// Returns the URI at which reports may be uploaded for this task. pub fn report_upload_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Leader)? + .leader_aggregator_endpoint() .join(&format!("{}/reports", self.tasks_path()))?) } /// Returns the URI at which the helper resource for the specified aggregation job ID can be /// accessed. pub fn aggregation_job_uri(&self, aggregation_job_id: &AggregationJobId) -> Result { - Ok(self.aggregator_url(&Role::Helper)?.join(&format!( + Ok(self.helper_aggregator_endpoint().join(&format!( "{}/aggregation_jobs/{aggregation_job_id}", self.tasks_path() ))?) @@ -379,34 +367,27 @@ impl Task { /// Returns the URI at which the helper aggregate shares resource can be accessed. pub fn aggregate_shares_uri(&self) -> Result { Ok(self - .aggregator_url(&Role::Helper)? + .helper_aggregator_endpoint() .join(&format!("{}/aggregate_shares", self.tasks_path()))?) } /// Returns the URI at which the leader resource for the specified collection job ID can be /// accessed. pub fn collection_job_uri(&self, collection_job_id: &CollectionJobId) -> Result { - Ok(self.aggregator_url(&Role::Leader)?.join(&format!( + Ok(self.leader_aggregator_endpoint().join(&format!( "{}/collection_jobs/{collection_job_id}", self.tasks_path() ))?) } } -fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for url in urls { - list.entry(&format!("{url}")); - } - list.finish() -} - /// SerializedTask is an intermediate representation for tasks being serialized via the Serialize & /// Deserialize traits. 
#[derive(Clone, Serialize, Deserialize)] pub struct SerializedTask { task_id: Option, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, query_type: QueryType, vdaf: VdafInstance, role: Role, @@ -499,7 +480,8 @@ impl Serialize for Task { SerializedTask { task_id: Some(self.task_id), - aggregator_endpoints: self.aggregator_endpoints.clone(), + leader_aggregator_endpoint: self.leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint: self.helper_aggregator_endpoint.clone(), query_type: self.query_type, vdaf: self.vdaf.clone(), role: self.role, @@ -567,7 +549,8 @@ impl TryFrom for Task { Task::new( task_id, - serialized_task.aggregator_endpoints, + serialized_task.leader_aggregator_endpoint, + serialized_task.helper_aggregator_endpoint, serialized_task.query_type, serialized_task.vdaf, serialized_task.role, @@ -594,7 +577,6 @@ impl<'de> Deserialize<'de> for Task { } } -// This is public to allow use in integration tests. #[cfg(feature = "test-util")] #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub mod test_util { @@ -664,10 +646,8 @@ pub mod test_util { Self( Task::new( task_id, - Vec::from([ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), query_type, vdaf, role, @@ -692,17 +672,20 @@ pub mod test_util { Self(Task { task_id, ..self.0 }) } - /// Associates the eventual task with the given aggregator endpoints. - pub fn with_aggregator_endpoints(self, aggregator_endpoints: Vec) -> Self { + /// Associates the eventual task with the given aggregator endpoint for the Leader. + pub fn with_leader_aggregator_endpoint(self, leader_aggregator_endpoint: Url) -> Self { Self(Task { - aggregator_endpoints, + leader_aggregator_endpoint, ..self.0 }) } - /// Retrieves the aggregator endpoints associated with this task builder. - pub fn aggregator_endpoints(&self) -> &[Url] { - self.0.aggregator_endpoints() + /// Associates the eventual task with the given aggregator endpoint for the Helper. + pub fn with_helper_aggregator_endpoint(self, helper_aggregator_endpoint: Url) -> Self { + Self(Task { + helper_aggregator_endpoint, + ..self.0 + }) } /// Associates the eventual task with the given aggregator role. @@ -828,7 +811,6 @@ mod tests { }; use rand::random; use serde_test::{assert_tokens, Token}; - use url::Url; #[test] fn task_serialization() { @@ -852,10 +834,8 @@ mod tests { // As leader, we receive an error if no collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -876,10 +856,8 @@ mod tests { // As leader, we receive no error if a collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -900,10 +878,8 @@ mod tests { // As helper, we receive no error if no collector auth token is specified. 
Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, @@ -924,10 +900,8 @@ mod tests { // As helper, we receive an error if a collector auth token is specified. Task::new( random(), - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, @@ -950,10 +924,8 @@ mod tests { fn aggregator_endpoints_end_in_slash() { let task = Task::new( random(), - Vec::from([ - "http://leader_endpoint/foo/bar".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), + "http://leader_endpoint/foo/bar".parse().unwrap(), + "http://helper_endpoint".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -972,11 +944,12 @@ mod tests { .unwrap(); assert_eq!( - task.aggregator_endpoints, - Vec::from([ - "http://leader_endpoint/foo/bar/".parse().unwrap(), - "http://helper_endpoint/".parse().unwrap() - ]) + task.leader_aggregator_endpoint, + "http://leader_endpoint/foo/bar/".parse().unwrap(), + ); + assert_eq!( + task.helper_aggregator_endpoint, + "http://helper_endpoint/".parse().unwrap(), ); } @@ -999,10 +972,8 @@ mod tests { VdafInstance::Prio3Count, Role::Leader, ) - .with_aggregator_endpoints(Vec::from([ - Url::parse("https://leader.com/prefix/").unwrap(), - Url::parse("https://helper.com/prefix/").unwrap(), - ])) + .with_leader_aggregator_endpoint("https://leader.com/prefix/".parse().unwrap()) + .with_helper_aggregator_endpoint("https://helper.com/prefix/".parse().unwrap()) .build(), ), ] { @@ -1030,10 +1001,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([0; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, @@ -1068,16 +1037,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, Token::Str("query_type"), Token::UnitVariant { name: "QueryType", @@ -1193,10 +1161,8 @@ mod tests { assert_tokens( &Task::new( TaskId::from([255; 32]), - Vec::from([ - "https://example.com/".parse().unwrap(), - "https://example.net/".parse().unwrap(), - ]), + "https://example.com/".parse().unwrap(), + "https://example.net/".parse().unwrap(), QueryType::FixedSize { max_batch_size: 10 }, VdafInstance::Prio3CountVec { length: 8 }, Role::Helper, @@ -1231,16 +1197,15 @@ mod tests { &[ Token::Struct { name: "SerializedTask", - len: 16, + len: 17, }, Token::Str("task_id"), Token::Some, Token::Str("__________________________________________8"), - Token::Str("aggregator_endpoints"), - Token::Seq { len: Some(2) }, + Token::Str("leader_aggregator_endpoint"), Token::Str("https://example.com/"), + Token::Str("helper_aggregator_endpoint"), Token::Str("https://example.net/"), - Token::SeqEnd, 
Token::Str("query_type"), Token::StructVariant { name: "QueryType", @@ -1370,10 +1335,8 @@ mod tests { let bad_agg_auth_token = SerializedTask { task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), + leader_aggregator_endpoint: "https://www.example.com/".parse().unwrap(), + helper_aggregator_endpoint: "https://www.example.net/".parse().unwrap(), query_type: QueryType::TimeInterval, vdaf: VdafInstance::Prio3Count, role: Role::Helper, @@ -1396,10 +1359,8 @@ mod tests { let bad_collector_auth_token = SerializedTask { task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), + leader_aggregator_endpoint: "https://www.example.com/".parse().unwrap(), + helper_aggregator_endpoint: "https://www.example.net/".parse().unwrap(), query_type: QueryType::TimeInterval, vdaf: VdafInstance::Prio3Count, role: Role::Leader, diff --git a/client/src/lib.rs b/client/src/lib.rs index 7d0f5b773..eb508a5be 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -21,10 +21,7 @@ use prio::{ vdaf, }; use rand::random; -use std::{ - fmt::{self, Formatter}, - io::Cursor, -}; +use std::io::Cursor; use url::Url; #[derive(Debug, thiserror::Error)] @@ -61,10 +58,12 @@ static CLIENT_USER_AGENT: &str = concat!( pub struct ClientParameters { /// Unique identifier for the task. task_id: TaskId, - /// URLs relative to which aggregator API endpoints are found. The first - /// entry is the leader's. - #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - aggregator_endpoints: Vec, + /// URL relative to which the Leader's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + leader_aggregator_endpoint: Url, + /// URL relative to which the Helper's API endpoints are found. + #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] + helper_aggregator_endpoint: Url, /// The time precision of the task. This value is shared by all parties in the protocol, and is /// used to compute report timestamps. time_precision: Duration, @@ -74,10 +73,16 @@ pub struct ClientParameters { impl ClientParameters { /// Creates a new set of client task parameters. - pub fn new(task_id: TaskId, aggregator_endpoints: Vec, time_precision: Duration) -> Self { + pub fn new( + task_id: TaskId, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, + time_precision: Duration, + ) -> Self { Self::new_with_backoff( task_id, - aggregator_endpoints, + leader_aggregator_endpoint, + helper_aggregator_endpoint, time_precision, http_request_exponential_backoff(), ) @@ -86,35 +91,32 @@ impl ClientParameters { /// Creates a new set of client task parameters with non-default HTTP request retry parameters. pub fn new_with_backoff( task_id: TaskId, - mut aggregator_endpoints: Vec, + leader_aggregator_endpoint: Url, + helper_aggregator_endpoint: Url, time_precision: Duration, http_request_retry_parameters: ExponentialBackoff, ) -> Self { - // Ensure provided aggregator endpoints end with a slash, as we will be joining additional - // path segments into these endpoints & the Url::join implementation is persnickety about - // the slash at the end of the path. 
-        for url in &mut aggregator_endpoints {
-            url_ensure_trailing_slash(url);
-        }
-
         Self {
             task_id,
-            aggregator_endpoints,
+            leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint),
+            helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint),
             time_precision,
             http_request_retry_parameters,
         }
     }
 
-    /// The URL relative to which the API endpoints for the aggregator may be
-    /// found, if the role is an aggregator, or an error otherwise.
+    /// The URL relative to which the API endpoints for the aggregator may be found, if the role is
+    /// an aggregator, or an error otherwise.
     fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> {
-        Ok(&self.aggregator_endpoints[role
-            .index()
-            .ok_or(Error::InvalidParameter("role is not an aggregator"))?])
+        match role {
+            Role::Leader => Ok(&self.leader_aggregator_endpoint),
+            Role::Helper => Ok(&self.helper_aggregator_endpoint),
+            _ => Err(Error::InvalidParameter("role is not an aggregator")),
+        }
     }
 
-    /// URL from which the HPKE configuration for the server filling `role` may
-    /// be fetched per draft-gpew-priv-ppm §4.3.1
+    /// URL from which the HPKE configuration for the server filling `role` may be fetched per
+    /// draft-gpew-priv-ppm §4.3.1
     fn hpke_config_endpoint(&self, role: &Role) -> Result<Url, Error> {
         Ok(self.aggregator_endpoint(role)?.join("hpke_config")?)
     }
@@ -122,21 +124,13 @@ impl ClientParameters {
     // URI to which reports may be uploaded for the provided task.
     fn reports_resource_uri(&self, task_id: &TaskId) -> Result<Url, Error> {
         Ok(self
-            .aggregator_endpoint(&Role::Leader)?
+            .leader_aggregator_endpoint
             .join(&format!("tasks/{task_id}/reports"))?)
     }
 }
 
-fn fmt_vector_of_urls(urls: &Vec<Url>, f: &mut Formatter<'_>) -> fmt::Result {
-    let mut list = f.debug_list();
-    for url in urls {
-        list.entry(&format!("{url}"));
-    }
-    list.finish()
-}
-
-/// Fetches HPKE configuration from the specified aggregator using the
-/// aggregator endpoints in the provided [`ClientParameters`].
+/// Fetches HPKE configuration from the specified aggregator using the aggregator endpoints in the
+/// provided [`ClientParameters`].
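+///
+/// A minimal usage sketch. The hidden setup below is hypothetical, but the call
+/// shape mirrors the in-tree call sites (e.g. the in-process integration-test
+/// client):
+///
+/// ```no_run
+/// # use janus_client::{aggregator_hpke_config, default_http_client, ClientParameters, Error};
+/// # use janus_messages::{Role, TaskId};
+/// # async fn example(client_parameters: &ClientParameters, task_id: &TaskId) -> Result<(), Error> {
+/// let http_client = default_http_client()?;
+/// // Fetch the Leader's current HPKE configuration for this task.
+/// let leader_hpke_config =
+///     aggregator_hpke_config(client_parameters, &Role::Leader, task_id, &http_client).await?;
+/// # Ok(())
+/// # }
+/// ```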
 #[tracing::instrument(err)]
 pub async fn aggregator_hpke_config(
     client_parameters: &ClientParameters,
@@ -312,14 +306,15 @@ mod tests {
     use url::Url;
 
     fn setup_client<V: vdaf::Client<16>>(
-        server: &mut mockito::Server,
+        server: &mockito::Server,
         vdaf_client: V,
     ) -> Client<V, MockClock> {
         let server_url = Url::parse(&server.url()).unwrap();
         Client::new(
             ClientParameters::new_with_backoff(
                 random(),
-                Vec::from([server_url.clone(), server_url]),
+                server_url.clone(),
+                server_url,
                 Duration::from_seconds(1),
                 test_http_request_exponential_backoff(),
             ),
@@ -335,19 +330,18 @@
     fn aggregator_endpoints_end_in_slash() {
         let client_parameters = ClientParameters::new(
             random(),
-            Vec::from([
-                "http://leader_endpoint/foo/bar".parse().unwrap(),
-                "http://helper_endpoint".parse().unwrap(),
-            ]),
+            "http://leader_endpoint/foo/bar".parse().unwrap(),
+            "http://helper_endpoint".parse().unwrap(),
             Duration::from_seconds(1),
         );
 
         assert_eq!(
-            client_parameters.aggregator_endpoints,
-            Vec::from([
-                "http://leader_endpoint/foo/bar/".parse().unwrap(),
-                "http://helper_endpoint/".parse().unwrap()
-            ])
+            client_parameters.leader_aggregator_endpoint,
+            "http://leader_endpoint/foo/bar/".parse().unwrap()
+        );
+        assert_eq!(
+            client_parameters.helper_aggregator_endpoint,
+            "http://helper_endpoint/".parse().unwrap()
         );
     }
 
@@ -355,7 +349,7 @@ mod tests {
     async fn upload_prio3_count() {
         install_test_trace_subscriber();
         let mut server = mockito::Server::new_async().await;
-        let client = setup_client(&mut server, Prio3::new_count(2).unwrap());
+        let client = setup_client(&server, Prio3::new_count(2).unwrap());
 
         let mocked_upload = server
             .mock(
@@ -376,9 +370,9 @@ mod tests {
     #[tokio::test]
     async fn upload_prio3_invalid_measurement() {
         install_test_trace_subscriber();
-        let mut server = mockito::Server::new_async().await;
+        let server = mockito::Server::new_async().await;
         let vdaf = Prio3::new_sum(2, 16).unwrap();
-        let client = setup_client(&mut server, vdaf);
+        let client = setup_client(&server, vdaf);
 
         // 65536 is too big for a 16 bit sum and will be rejected by the VDAF.
// Make sure we get the right error variant but otherwise we aren't @@ -390,7 +384,7 @@ mod tests { async fn upload_prio3_http_status_code() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -417,7 +411,7 @@ mod tests { async fn upload_problem_details() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let client = setup_client(&mut server, Prio3::new_count(2).unwrap()); + let client = setup_client(&server, Prio3::new_count(2).unwrap()); let mocked_upload = server .mock( @@ -458,8 +452,12 @@ mod tests { async fn upload_bad_time_precision() { install_test_trace_subscriber(); - let client_parameters = - ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0)); + let client_parameters = ClientParameters::new( + random(), + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), + Duration::from_seconds(0), + ); let client = Client::new( client_parameters, Prio3::new_count(2).unwrap(), @@ -475,9 +473,9 @@ mod tests { #[test] fn report_timestamp() { install_test_trace_subscriber(); - let mut server = mockito::Server::new(); + let server = mockito::Server::new(); let vdaf = Prio3::new_count(2).unwrap(); - let mut client = setup_client(&mut server, vdaf); + let mut client = setup_client(&server, vdaf); client.parameters.time_precision = Duration::from_seconds(100); client.clock = MockClock::new(Time::from_seconds_since_epoch(101)); diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 6884bc328..a95df87ef 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -187,17 +187,14 @@ impl CollectorParameters { /// Creates a new set of collector task parameters. pub fn new( task_id: TaskId, - mut leader_endpoint: Url, + leader_endpoint: Url, authentication_token: AuthenticationToken, hpke_config: HpkeConfig, hpke_private_key: HpkePrivateKey, ) -> CollectorParameters { - // Ensure the provided leader endpoint ends with a slash. - url_ensure_trailing_slash(&mut leader_endpoint); - CollectorParameters { task_id, - leader_endpoint, + leader_endpoint: url_ensure_trailing_slash(leader_endpoint), authentication: Authentication::DapAuthToken(authentication_token), hpke_config, hpke_private_key, diff --git a/core/Cargo.toml b/core/Cargo.toml index 179330310..0a618ed2b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -57,6 +57,7 @@ tokio = { version = "1.26", features = ["macros", "net", "rt"] } tracing = "0.1.37" tracing-log = { version = "0.1.3", optional = true } tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"], optional = true } +url = "2" [dev-dependencies] fixed = "1.23" diff --git a/core/src/task.rs b/core/src/task.rs index ae23735c3..fe02d350e 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -1,8 +1,8 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use rand::{distributions::Standard, prelude::Distribution}; -use reqwest::Url; use ring::constant_time; use serde::{Deserialize, Serialize}; +use url::Url; /// HTTP header where auth tokens are provided in messages between participants. pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; @@ -568,15 +568,16 @@ impl Distribution for Standard { } } -/// Modifies a [`Url`] in place to ensure it ends with a slash. +/// Returns the given [`Url`], possibly modified to end with a slash. 
/// /// Aggregator endpoint URLs should end with a slash if they will be used with [`Url::join`], /// because that method will drop the last path component of the base URL if it does not end with a /// slash. -pub fn url_ensure_trailing_slash(url: &mut Url) { +pub fn url_ensure_trailing_slash(mut url: Url) -> Url { if !url.as_str().ends_with('/') { url.set_path(&format!("{}/", url.path())); } + url } #[cfg(test)] diff --git a/db/schema.sql b/db/schema.sql index faed1fc9d..3590dd1c5 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -11,19 +11,20 @@ CREATE TYPE AGGREGATOR_ROLE AS ENUM( -- Corresponds to a DAP task, containing static data associated with the task. CREATE TABLE tasks( - id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only - task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification - aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task - aggregator_endpoints TEXT[] NOT NULL, -- aggregator HTTPS endpoints, leader first - query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters - vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters - max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected - task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted - report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. - min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected - time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds - tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds - collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, -- artificial ID, internal-only + task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification + aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task + leader_aggregator_endpoint TEXT NOT NULL, -- Leader's API endpoint + helper_aggregator_endpoint TEXT NOT NULL, -- Helper's API endpoint + query_type JSONB NOT NULL, -- the query type in use for this task, along with its parameters + vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters + max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted + report_expiry_age BIGINT, -- the maximum age of a report before it is considered expired (and acceptable for garbage collection), in seconds. NULL means that GC is disabled. 
+    min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected
+    time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds
+    tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds
+    collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message)
 );
 
 -- The aggregator authentication tokens used by a given task.
diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml
index e9f0c801a..470e27f89 100644
--- a/docs/samples/tasks.yaml
+++ b/docs/samples/tasks.yaml
@@ -6,10 +6,9 @@
   # DAP's recommendation.
   task_id: "G9YKXjoEjfoU7M_fi_o2H0wmzavRb2sBFHeykeRhDMk"
 
-  # HTTPS endpoints of the leader and helper aggregators, in a list.
-  aggregator_endpoints:
-    - "https://example.com/"
-    - "https://example.net/"
+  # HTTPS endpoints of the leader and helper aggregators.
+  leader_aggregator_endpoint: "https://example.com/"
+  helper_aggregator_endpoint: "https://example.net/"
 
   # The DAP query type. See below for an example of a fixed-size task
   query_type: TimeInterval
@@ -94,9 +93,8 @@
     private_key: wFRYwiypcHC-mkGP1u3XQgIvtnlkQlUfZjgtM_zRsnI
 
 - task_id: "D-hCKPuqL2oTf7ZVRVyMP5VGt43EAEA8q34mDf6p1JE"
-  aggregator_endpoints:
-    - "https://example.org/"
-    - "https://example.com/"
+  leader_aggregator_endpoint: "https://example.org/"
+  helper_aggregator_endpoint: "https://example.com/"
   # For tasks using the fixed size query type, an additional `max_batch_size`
   # parameter must be provided.
   query_type: !FixedSize
diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs
index dc4a01ff2..1097f62ad 100644
--- a/integration_tests/src/client.rs
+++ b/integration_tests/src/client.rs
@@ -159,18 +159,22 @@ impl<'a> ClientBackend<'a> {
     pub async fn build<V>(
         &self,
         task: &Task,
-        aggregator_endpoints: Vec<Url>,
+        leader_aggregator_endpoint: Url,
+        helper_aggregator_endpoint: Url,
         vdaf: V,
     ) -> anyhow::Result<ClientImplementation<'a, V>>
     where
         V: vdaf::Client<16> + InteropClientEncoding,
     {
         match self {
-            ClientBackend::InProcess => {
-                ClientImplementation::new_in_process(task, aggregator_endpoints, vdaf)
-                    .await
-                    .map_err(Into::into)
-            }
+            ClientBackend::InProcess => ClientImplementation::new_in_process(
+                task,
+                leader_aggregator_endpoint,
+                helper_aggregator_endpoint,
+                vdaf,
+            )
+            .await
+            .map_err(Into::into),
             ClientBackend::Container {
                 container_client,
                 container_image,
@@ -217,11 +221,16 @@ where
 {
     pub async fn new_in_process(
         task: &Task,
-        aggregator_endpoints: Vec<Url>,
+        leader_aggregator_endpoint: Url,
+        helper_aggregator_endpoint: Url,
         vdaf: V,
     ) -> Result<ClientImplementation<'static, V>, janus_client::Error> {
-        let client_parameters =
-            ClientParameters::new(*task.id(), aggregator_endpoints, *task.time_precision());
+        let client_parameters = ClientParameters::new(
+            *task.id(),
+            leader_aggregator_endpoint,
+            helper_aggregator_endpoint,
+            *task.time_precision(),
+        );
         let http_client = default_http_client()?;
         let leader_config =
             aggregator_hpke_config(&client_parameters, &Role::Leader, task.id(), &http_client)
                 .await
@@ -259,8 +268,8 @@ where
         let http_client = reqwest::Client::new();
         ClientImplementation::Container(Box::new(ContainerClientImplementation {
             _container: container,
-            leader: task.aggregator_endpoints()[Role::Leader.index().unwrap()].clone(),
-            helper: task.aggregator_endpoints()[Role::Helper.index().unwrap()].clone(),
+            leader: task.leader_aggregator_endpoint().clone(),
+            helper: task.helper_aggregator_endpoint().clone(),
             task_id:
*task.id(), time_precision: *task.time_precision(), vdaf, diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs index 475b66d13..3622f453b 100644 --- a/integration_tests/src/janus.rs +++ b/integration_tests/src/janus.rs @@ -52,7 +52,11 @@ impl<'a> Janus<'a> { task: &Task, ) -> Janus<'a> { // Start the Janus interop aggregator container running. - let endpoint = task.aggregator_url(task.role()).unwrap(); + let endpoint = match task.role() { + Role::Leader => task.leader_aggregator_endpoint(), + Role::Helper => task.helper_aggregator_endpoint(), + _ => panic!("unexpected task role"), + }; let container = container_client.run( RunnableImage::from(Aggregator::default()) .with_network(network) diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index b83104ddd..376499918 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -30,10 +30,12 @@ pub fn test_task_builders( let endpoint_random_value = hex::encode(random::<[u8; 4]>()); let collector_keypair = generate_test_hpke_config_and_private_key(); let leader_task = TaskBuilder::new(query_type, vdaf, Role::Leader) - .with_aggregator_endpoints(Vec::from([ + .with_leader_aggregator_endpoint( Url::parse(&format!("http://leader-{endpoint_random_value}:8080/")).unwrap(), + ) + .with_helper_aggregator_endpoint( Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), - ])) + ) .with_min_batch_size(46) .with_collector_hpke_config(collector_keypair.config().clone()); let helper_task = leader_task @@ -106,7 +108,7 @@ where pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( vdaf: V, - aggregator_endpoints: Vec, + leader_aggregator_endpoint: &Url, leader_task: &'a Task, collector_private_key: &'a HpkePrivateKey, test_case: &'a AggregationTestCase, @@ -124,7 +126,7 @@ pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( let collector_params = CollectorParameters::new( *leader_task.id(), - aggregator_endpoints[Role::Leader.index().unwrap()].clone(), + leader_aggregator_endpoint.clone(), leader_task.primary_collector_auth_token().clone(), leader_task.collector_hpke_config().clone(), collector_private_key.clone(), @@ -143,9 +145,7 @@ pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( janus_collector::default_http_client().unwrap(), ); - let forwarded_port = aggregator_endpoints[Role::Leader.index().unwrap()] - .port() - .unwrap(); + let forwarded_port = leader_aggregator_endpoint.port().unwrap(); // Send a collect request and verify that we got the correct result. match leader_task.query_type() { @@ -218,12 +218,10 @@ pub async fn submit_measurements_and_verify_aggregate( client_backend: &ClientBackend<'_>, ) { // Translate aggregator endpoints for our perspective outside the container network. - let aggregator_endpoints: Vec<_> = leader_task - .aggregator_endpoints() - .iter() - .zip([leader_port, helper_port]) - .map(|(url, port)| translate_url_for_external_access(url, port)) - .collect(); + let leader_aggregator_endpoint = + translate_url_for_external_access(leader_task.leader_aggregator_endpoint(), leader_port); + let helper_aggregator_endpoint = + translate_url_for_external_access(leader_task.helper_aggregator_endpoint(), helper_port); // We generate exactly one batch's worth of measurement uploads to work around an issue in // Daphne at time of writing. 
@@ -247,13 +245,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -275,13 +278,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -315,13 +323,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -354,13 +367,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, @@ -394,13 +412,18 @@ pub async fn submit_measurements_and_verify_aggregate( }; let client_implementation = client_backend - .build(leader_task, aggregator_endpoints.clone(), vdaf.clone()) + .build( + leader_task, + leader_aggregator_endpoint.clone(), + helper_aggregator_endpoint, + vdaf.clone(), + ) .await .unwrap(); submit_measurements_and_verify_aggregate_generic( vdaf, - aggregator_endpoints, + &leader_aggregator_endpoint, leader_task, collector_private_key, &test_case, diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs index e0fac0c27..dcad841bc 100644 --- a/integration_tests/tests/janus.rs +++ b/integration_tests/tests/janus.rs @@ -73,15 +73,12 @@ impl<'a> JanusPair<'a> { // where "port" is whatever unused port we use with `kubectl port-forward`. But when // the aggregators talk to each other, they do it on the cluster's private network, // and so they need the in-cluster DNS name of the other aggregator. However, since - // aggregators use the endpoint URLs in the task to construct collection job URIs, we - // must only fix the _peer_ aggregator's endpoint. - let leader_endpoints = { - let mut endpoints = leader_task.aggregator_endpoints().to_vec(); - endpoints[1] = Self::in_cluster_aggregator_url(&helper_namespace); - endpoints - }; + // aggregators use the endpoint URLs in the task to construct collection job URIs, + // we must only fix the _peer_ aggregator's endpoint. 
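+                // For example, the leader's task below keeps its own local endpoint but is
+                // given the helper's in-cluster DNS name, and the helper's task gets the
+                // mirror-image treatment.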
                 let leader_task = leader_task
-                    .with_aggregator_endpoints(leader_endpoints)
+                    .with_helper_aggregator_endpoint(Self::in_cluster_aggregator_url(
+                        &helper_namespace,
+                    ))
                     .build();
                 let leader = Janus::new_with_kubernetes_cluster(
                     &kubeconfig_path,
@@ -91,13 +88,10 @@
                 )
                 .await;
 
-                let helper_endpoints = {
-                    let mut endpoints = helper_task.aggregator_endpoints().to_vec();
-                    endpoints[0] = Self::in_cluster_aggregator_url(&leader_namespace);
-                    endpoints
-                };
                 let helper_task = helper_task
-                    .with_aggregator_endpoints(helper_endpoints)
+                    .with_leader_aggregator_endpoint(Self::in_cluster_aggregator_url(
+                        &leader_namespace,
+                    ))
                     .build();
                 let helper = Janus::new_with_kubernetes_cluster(
                     &kubeconfig_path,
diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs
index 1d47a4591..e982d15d7 100644
--- a/interop_binaries/src/bin/janus_interop_aggregator.rs
+++ b/interop_binaries/src/bin/janus_interop_aggregator.rs
@@ -81,7 +81,8 @@ async fn handle_add_task(
 
     let task = Task::new(
         request.task_id,
-        Vec::from([request.leader, request.helper]),
+        request.leader,
+        request.helper,
         query_type,
         vdaf,
         request.role.into(),
diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs
index ff5486873..94681eabc 100644
--- a/interop_binaries/src/bin/janus_interop_client.rs
+++ b/interop_binaries/src/bin/janus_interop_client.rs
@@ -81,11 +81,8 @@ async fn handle_upload_generic<V: vdaf::Client<16>>(
         .context("invalid base64url content in \"task_id\"")?;
     let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?;
     let time_precision = Duration::from_seconds(request.time_precision);
-    let client_parameters = ClientParameters::new(
-        task_id,
-        Vec::<Url>::from([request.leader, request.helper]),
-        time_precision,
-    );
+    let client_parameters =
+        ClientParameters::new(task_id, request.leader, request.helper, time_precision);
 
     let leader_hpke_config = janus_client::aggregator_hpke_config(
         &client_parameters,
diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs
index 0c3b52d5b..eb32d6a4a 100644
--- a/interop_binaries/src/lib.rs
+++ b/interop_binaries/src/lib.rs
@@ -283,8 +283,8 @@ impl From<Task> for AggregatorAddTaskRequest {
         };
         Self {
             task_id: *task.id(),
-            leader: task.aggregator_url(&Role::Leader).unwrap().clone(),
-            helper: task.aggregator_url(&Role::Helper).unwrap().clone(),
+            leader: task.leader_aggregator_endpoint().clone(),
+            helper: task.helper_aggregator_endpoint().clone(),
             vdaf: task.vdaf().clone().into(),
             leader_authentication_token: String::from_utf8(
                 task.primary_aggregator_auth_token().as_bytes().to_vec(),

From 3c26e8e9acb3ce858b515e5c4cd57b81c1e351a1 Mon Sep 17 00:00:00 2001
From: Brandon Pitman
Date: Wed, 29 Mar 2023 18:05:09 -0700
Subject: [PATCH 4/8] Remove redundant encoding-length check.
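
The roundtrip helper already asserts that the encoded bytes equal a known-good
hex vector byte-for-byte, which pins down the encoded length as well, so the
separate `encoded_len()` assertion added no coverage. (Rationale inferred from
the assertion that remains visible in the diff below.)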
--- messages/src/lib.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 240857a09..231ce0f6e 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -2674,9 +2674,6 @@ mod tests { for (val, hex_encoding) in vals_and_encodings { let mut encoded_val = Vec::new(); val.encode(&mut encoded_val); - if let Some(want_encoded_len) = val.encoded_len() { - assert_eq!(want_encoded_len, encoded_val.len()); - } let encoding = hex::decode(hex_encoding).unwrap(); assert_eq!( encoding, encoded_val, From 70e389457c13f66caad0df4ef07faa127d15965f Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Thu, 30 Mar 2023 15:58:44 -0700 Subject: [PATCH 5/8] Implement "ping-pong" aggregation process. This required a few other changes that may be notable: * I enabled the use of Poplar1 to test aggregation-continuation behavior. Neither Prio3 nor the dummy VDAF would work for this, as they do not have enough rounds to trigger continuation. (In retrospect, it may have been wiser to extend the dummy VDAF to allow the number of rounds to be controlled.) * I refactored the replay logic to store the last responses directly, and build current-round responses directly from the same data that would be used to handle replay. Replay state is now attached to the report aggregation, rather than the report aggregation state. --- Cargo.lock | 11 +- Cargo.toml | 3 +- aggregator/src/aggregator.rs | 1550 +++++++++-------- .../src/aggregator/aggregate_init_tests.rs | 135 +- .../aggregator/aggregation_job_continue.rs | 524 ++++-- .../src/aggregator/aggregation_job_creator.rs | 44 +- .../src/aggregator/aggregation_job_driver.rs | 952 +++++----- .../src/aggregator/collection_job_driver.rs | 22 +- .../src/aggregator/collection_job_tests.rs | 3 +- .../src/aggregator/garbage_collector.rs | 4 + aggregator_core/src/datastore.rs | 477 +++-- aggregator_core/src/task.rs | 16 +- core/src/task.rs | 174 +- core/src/test_util/dummy_vdaf.rs | 72 +- core/src/test_util/mod.rs | 8 +- db/schema.sql | 5 +- interop_binaries/tests/end_to_end.rs | 4 +- messages/Cargo.toml | 2 +- messages/src/lib.rs | 675 +++++-- 19 files changed, 2784 insertions(+), 1897 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f52074e93..feb724cae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,9 +455,9 @@ dependencies = [ [[package]] name = "cmac" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "606383658416244b8dc4b36f864ec1f86cb922b95c41a908fd07aeb01cad06fa" +checksum = "8543454e3c3f5126effff9cd44d562af4e31fb8ce1cc0d3dcd8f084515dbc1aa" dependencies = [ "cipher", "dbl", @@ -964,9 +964,9 @@ dependencies = [ [[package]] name = "fiat-crypto" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93ace6ec7cc19c8ed33a32eaa9ea692d7faea05006b5356b9e2b668ec4bc3955" +checksum = "e825f6987101665dea6ec934c09ec6d721de7bc1bf92248e1d5810c8cd636b77" [[package]] name = "filetime" @@ -2809,8 +2809,7 @@ dependencies = [ [[package]] name = "prio" version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c2aa1f9faa3fab6f02b54025f411d6f4fcd31765765600db339280e3678ae20" +source = "git+https://github.com/divviup/libprio-rs.git?branch=bran/encode-poplar1-state#e0357075d880d209f1021b1e9282261533678ba7" dependencies = [ "aes", "base64 0.21.0", diff --git a/Cargo.toml b/Cargo.toml index c364b20b0..ba7e9539c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,8 @@ 
 janus_interop_binaries = { version = "0.4", path = "interop_binaries" }
 janus_messages = { version = "0.4", path = "messages" }
 k8s-openapi = { version = "0.16.0", features = ["v1_24"] }  # keep this version in sync with what is referenced by the indirect dependency via `kube`
 kube = { version = "0.75.0", default-features = false, features = ["client"] }
-prio = { version = "0.12.0", features = ["multithreaded"] }
+# prio = { version = "0.12.0", features = ["multithreaded"] } # XXX
+prio = { git = "https://github.com/divviup/libprio-rs.git", branch = "bran/encode-poplar1-state", features = ["multithreaded", "experimental"] }
 
 [profile.dev]
 # Disabling debug info improves build speeds & reduces build artifact sizes, which helps CI caching.
diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index 06dab5023..cc1783c44 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -23,8 +23,8 @@ use janus_aggregator_core::{
         self,
         models::{
             AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation,
-            CollectionJob, CollectionJobState, LeaderStoredReport, PrepareMessageOrShare,
-            ReportAggregation, ReportAggregationState,
+            CollectionJob, CollectionJobState, LeaderStoredReport, ReportAggregation,
+            ReportAggregationState,
         },
         Datastore, Transaction,
     },
@@ -36,7 +36,7 @@ use janus_core::test_util::dummy_vdaf;
 use janus_core::{
     hpke::{self, HpkeApplicationInfo, Label},
     http::response_to_problem_details,
-    task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER, PRIO3_VERIFY_KEY_LENGTH},
+    task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER, VERIFY_KEY_LEN},
     time::{Clock, DurationExt, IntervalExt, TimeExt},
 };
 use janus_messages::{
@@ -55,14 +55,14 @@ use opentelemetry::{
 };
 #[cfg(feature = "fpvec_bounded_l2")]
 use prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded;
-#[cfg(feature = "test-util")]
-use prio::vdaf::PrepareTransition;
 use prio::{
     codec::{Decode, Encode, ParameterizedDecode},
     vdaf::{
         self,
+        poplar1::Poplar1,
+        prg::PrgSha3,
         prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded},
-        VdafError,
+        PrepareTransition, VdafError,
     },
 };
 use regex::Regex;
@@ -318,12 +318,16 @@ pub(crate) fn aggregate_step_failure_counter(meter: &Meter) -> Counter<u64> {
         "decrypt_failure",
         "input_share_decode_failure",
         "public_share_decode_failure",
+        "prepare_message_decode_failure",
+        "leader_prep_share_decode_failure",
+        "helper_prep_share_decode_failure",
         "continue_mismatch",
         "accumulate_failure",
         "finish_mismatch",
         "helper_step_failure",
         "plaintext_input_share_decode_failure",
         "duplicate_extension",
+        "missing_prepare_message",
     ] {
         aggregate_step_failure_counter.add(
             &Context::current(),
@@ -743,6 +747,12 @@ impl TaskAggregator {
                 VdafOps::Prio3FixedPoint64BitBoundedL2VecSum(Arc::new(vdaf), verify_key)
             }
 
+            VdafInstance::Poplar1 { bits } => {
+                let vdaf = Poplar1::new_sha3(*bits);
+                let verify_key = task.primary_vdaf_verify_key()?;
+                VdafOps::Poplar1(Arc::new(vdaf), verify_key)
+            }
+
             #[cfg(feature = "test-util")]
             VdafInstance::Fake => VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new())),
 
@@ -758,7 +768,7 @@ impl TaskAggregator {
             #[cfg(feature = "test-util")]
             VdafInstance::FakeFailsPrepStep => {
                 VdafOps::Fake(Arc::new(dummy_vdaf::Vdaf::new().with_prep_step_fn(
-                    || -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> {
+                    |_| -> Result<PrepareTransition<dummy_vdaf::Vdaf, 0, 16>, VdafError> {
                         Err(VdafError::Uncategorized(
                             "FakeFailsPrepStep failed at prep_step".to_string(),
                         ))
@@ -903,33 +913,29 @@ impl TaskAggregator {
 /// VdafOps stores VDAF-specific operations for a TaskAggregator in a non-generic way.
 #[allow(clippy::enum_variant_names)]
 enum VdafOps {
-    Prio3Count(Arc<Prio3Count>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
-    Prio3CountVec(
-        Arc<Prio3SumVecMultithreaded>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
-    ),
-    Prio3Sum(Arc<Prio3Sum>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
-    Prio3SumVec(
-        Arc<Prio3SumVecMultithreaded>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
-    ),
-    Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<PRIO3_VERIFY_KEY_LENGTH>),
+    Prio3Count(Arc<Prio3Count>, VerifyKey<VERIFY_KEY_LEN>),
+    Prio3CountVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LEN>),
+    Prio3Sum(Arc<Prio3Sum>, VerifyKey<VERIFY_KEY_LEN>),
+    Prio3SumVec(Arc<Prio3SumVecMultithreaded>, VerifyKey<VERIFY_KEY_LEN>),
+    Prio3Histogram(Arc<Prio3Histogram>, VerifyKey<VERIFY_KEY_LEN>),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint16BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LEN>,
     ),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint32BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LEN>,
     ),
     #[cfg(feature = "fpvec_bounded_l2")]
     Prio3FixedPoint64BitBoundedL2VecSum(
         Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>,
-        VerifyKey<PRIO3_VERIFY_KEY_LENGTH>,
+        VerifyKey<VERIFY_KEY_LEN>,
     ),
+    Poplar1(Arc<Poplar1<PrgSha3, 16>>, VerifyKey<VERIFY_KEY_LEN>),
+
     #[cfg(feature = "test-util")]
     Fake(Arc<dummy_vdaf::Vdaf>),
 }
@@ -940,13 +946,13 @@ enum VdafOps {
 /// specify the VDAF's type, and the name of a const that will be set to the VDAF's verify key
 /// length, also for explicitly specifying type parameters.
 macro_rules! vdaf_ops_dispatch {
-    ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => {
+    ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => {
         match $vdaf_ops {
            crate::aggregator::VdafOps::Prio3Count(vdaf, verify_key) => {
                let $vdaf = vdaf;
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3Count;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -954,7 +960,7 @@
                let $vdaf = vdaf;
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -962,7 +968,7 @@
                let $vdaf = vdaf;
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3Sum;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -970,7 +976,7 @@
                let $vdaf = vdaf;
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -978,7 +984,7 @@
                let $vdaf = vdaf;
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -988,7 +994,7 @@
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -998,7 +1004,7 @@
                let $verify_key = verify_key;
                type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>;
-               const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
+               const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN;
                $body
            }

@@ -1008,7 +1014,15 @@ macro_rules!
vdaf_ops_dispatch { let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + crate::aggregator::VdafOps::Poplar1(vdaf, verify_key) => { + let $vdaf = vdaf; + let $verify_key = verify_key; + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -1017,7 +1031,7 @@ macro_rules! vdaf_ops_dispatch { let $vdaf = vdaf; let $verify_key = &VerifyKey::new([]); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } } @@ -1036,8 +1050,8 @@ impl VdafOps { ) -> Result<(), Arc> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_upload_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_upload_generic::( Arc::clone(vdaf), clock, upload_decrypt_failure_counter, @@ -1050,8 +1064,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_upload_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_upload_generic::( Arc::clone(vdaf), clock, upload_decrypt_failure_counter, @@ -1079,8 +1093,8 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_init_generic::( + vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_init_generic::( datastore, vdaf, aggregate_step_failure_counter, @@ -1094,8 +1108,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_init_generic::( + vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_init_generic::( datastore, vdaf, aggregate_step_failure_counter, @@ -1123,8 +1137,8 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_continue_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), aggregate_step_failure_counter, @@ -1138,8 +1152,8 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { - Self::handle_aggregate_continue_generic::( + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { + Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), aggregate_step_failure_counter, @@ -1316,26 +1330,16 @@ where { report_share: ReportShare, report_aggregation: ReportAggregation, - prep_result: PrepareStepResult, - existing_report_aggregation: bool, - conflicting_aggregate_share: bool, } impl> ReportShareData where A: vdaf::Aggregator, { - fn new( - report_share: ReportShare, - report_aggregation: ReportAggregation, - prep_result: PrepareStepResult, - ) -> Self { + fn new(report_share: ReportShare, report_aggregation: ReportAggregation) -> Self { Self { report_share, report_aggregation, - prep_result, - existing_report_aggregation: false, - conflicting_aggregate_share: false, } } } @@ -1388,16 +1392,6 @@ impl VdafOps { ) .await?; - // Filter out any report shares in the incoming message that wouldn't get written out: we - // don't expect to see those in the datastore. - let incoming_report_share_data: Vec<_> = incoming_report_share_data - .iter() - .filter(|report_share_data| { - !report_share_data.existing_report_aggregation - && !report_share_data.conflicting_aggregate_share - }) - .collect(); - if existing_report_aggregations.len() != incoming_report_share_data.len() { return Ok(false); } @@ -1409,11 +1403,11 @@ impl VdafOps { if incoming_report_share_data .iter() .zip(existing_report_aggregations) - .any(|(incoming_report_share, existing_report_share)| { - !existing_report_share + .any(|(incoming_report_share_data, existing_report_aggregation)| { + !existing_report_aggregation .report_metadata() - .eq(incoming_report_share.report_share.metadata()) - || !existing_report_share.eq(&incoming_report_share.report_aggregation) + .eq(incoming_report_share_data.report_share.metadata()) + || !existing_report_aggregation.eq(&incoming_report_share_data.report_aggregation) }) { return Ok(false); @@ -1450,9 +1444,9 @@ impl VdafOps { // If two ReportShare messages have the same report ID, then the helper MUST abort with // error "unrecognizedMessage". 
(§4.4.4.1) - let mut seen_report_ids = HashSet::with_capacity(req.report_shares().len()); - for share in req.report_shares() { - if !seen_report_ids.insert(share.metadata().id()) { + let mut seen_report_ids = HashSet::with_capacity(req.report_inits().len()); + for report_init in req.report_inits() { + if !seen_report_ids.insert(*report_init.report_share().metadata().id()) { return Err(Error::UnrecognizedMessage( Some(*task.id()), "aggregate request contains duplicate report IDs", @@ -1464,13 +1458,18 @@ impl VdafOps { let mut saw_continue = false; let mut report_share_data = Vec::new(); let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?; - for (ord, report_share) in req.report_shares().iter().enumerate() { + for (ord, report_init) in req.report_inits().iter().enumerate() { let hpke_keypair = task .hpke_keys() - .get(report_share.encrypted_input_share().config_id()) + .get( + report_init + .report_share() + .encrypted_input_share() + .config_id(), + ) .ok_or_else(|| { info!( - config_id = %report_share.encrypted_input_share().config_id(), + config_id = %report_init.report_share().encrypted_input_share().config_id(), "Helper encrypted input share references unknown HPKE config ID" ); aggregate_step_failure_counter.add( @@ -1487,18 +1486,18 @@ impl VdafOps { hpke_keypair.config(), hpke_keypair.private_key(), &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - report_share.encrypted_input_share(), + report_init.report_share().encrypted_input_share(), &InputShareAad::new( *task.id(), - report_share.metadata().clone(), - report_share.public_share().to_vec(), + report_init.report_share().metadata().clone(), + report_init.report_share().public_share().to_vec(), ) .get_encoded(), ) .map_err(|error| { info!( task_id = %task.id(), - metadata = ?report_share.metadata(), + metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decrypt helper's report share" ); @@ -1513,7 +1512,7 @@ impl VdafOps { let plaintext_input_share = plaintext.and_then(|plaintext| { let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext).map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's plaintext input share"); + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decode helper's plaintext input share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "plaintext_input_share_decode_failure")]); ReportShareError::UnrecognizedMessage })?; @@ -1523,7 +1522,7 @@ impl VdafOps { .extensions() .iter() .all(|extension| extension_types.insert(extension.extension_type())) { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), "Received report share with duplicate extensions"); + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), "Received report share with duplicate extensions"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "duplicate_extension")]); return Err(ReportShareError::UnrecognizedMessage) } @@ -1533,14 +1532,14 @@ impl VdafOps { let input_share = plaintext_input_share.and_then(|plaintext_input_share| { A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), plaintext_input_share.payload()) .map_err(|error| { - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode helper's input share"); + info!(task_id = %task.id(), metadata = 
?report_init.report_share().metadata(), ?error, "Couldn't decode helper's input share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "input_share_decode_failure")]); ReportShareError::UnrecognizedMessage }) }); - let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_share.public_share()).map_err(|error|{ - info!(task_id = %task.id(), metadata = ?report_share.metadata(), ?error, "Couldn't decode public share"); + let public_share = A::PublicShare::get_decoded_with_param(vdaf, report_init.report_share().public_share()).map_err(|error|{ + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?error, "Couldn't decode public share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "public_share_decode_failure")]); ReportShareError::UnrecognizedMessage }); @@ -1550,78 +1549,109 @@ impl VdafOps { // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF // associated with the task and computes the first state transition. [...] If either // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2) - let init_rslt = shares.and_then(|(public_share, input_share)| { + let init_state= shares.and_then(|(public_share, input_share)| { vdaf .prepare_init( verify_key.as_bytes(), Role::Helper.index().unwrap(), &agg_param, - report_share.metadata().id().as_ref(), + report_init.report_share().metadata().id().as_ref(), &public_share, &input_share, ) .map_err(|error| { - info!(task_id = %task.id(), report_id = %report_share.metadata().id(), ?error, "Couldn't prepare_init report share"); + info!(task_id = %task.id(), report_id = %report_init.report_share().metadata().id(), ?error, "Couldn't prepare_init report share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_init_failure")]); ReportShareError::VdafPrepError }) }); - report_share_data.push(match init_rslt { - Ok((prep_state, prep_share)) => { - saw_continue = true; + // Next, the Helper computes the next prepare message, then steps its own VDAF + // preparation state. 
+ let stepped = init_state.and_then(|(prep_state, helper_prep_share)| { + let leader_prep_share = A::PrepareShare::get_decoded_with_param( + &prep_state, + report_init.leader_prep_share(), + ).map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Couldn't decode Leader prepare share"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "leader_prep_share_decode_failure")]); + ReportShareError::UnrecognizedMessage + })?; - let encoded_prep_share = prep_share.get_encoded(); - ReportShareData::new( - report_share.clone(), - ReportAggregation::::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - ReportAggregationState::::Waiting( - prep_state, - PrepareMessageOrShare::Helper(prep_share), - ), - ), - PrepareStepResult::Continued(encoded_prep_share), - ) - } + let prep_msg = vdaf.prepare_preprocess([leader_prep_share, helper_prep_share]) + .map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Couldn't compute prepare message"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_message_failure")]); + ReportShareError::VdafPrepError + })?; - Err(err) => ReportShareData::new( - report_share.clone(), - ReportAggregation::::new( - *task.id(), - *aggregation_job_id, - *report_share.metadata().id(), - *report_share.metadata().time(), - ord.try_into()?, - ReportAggregationState::::Failed(err), + let prep_msg_encoded = prep_msg.get_encoded(); + let helper_next_transition = vdaf.prepare_step(prep_state, prep_msg) + .map_err(|err| { + info!(task_id = %task.id(), metadata = ?report_init.report_share().metadata(), ?err, "Prepare step failed"); + aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_step_failure")]); + ReportShareError::VdafPrepError + })?; + + Ok(match helper_next_transition { + PrepareTransition::Continue(prep_state, prep_share) => { + saw_continue = true; + let prep_share_encoded = prep_share.get_encoded(); + (ReportAggregationState::::Waiting(prep_state, None), + PrepareStepResult::Continued{ + prep_msg: prep_msg_encoded, + prep_share: prep_share_encoded, + }) + }, + PrepareTransition::Finish(out_share) => ( + ReportAggregationState::::Finished(out_share), + PrepareStepResult::Finished { prep_msg: prep_msg_encoded }, ), + }) + }); + + let (report_aggregation_state, prep_step_rslt) = stepped.unwrap_or_else(|err| { + ( + ReportAggregationState::::Failed(err), PrepareStepResult::Failed(err), - ), + ) }); + + report_share_data.push(ReportShareData::new( + report_init.report_share().clone(), + ReportAggregation::::new( + *task.id(), + *aggregation_job_id, + *report_init.report_share().metadata().id(), + *report_init.report_share().metadata().time(), + ord.try_into()?, + Some(PrepareStep::new( + *report_init.report_share().metadata().id(), + prep_step_rslt, + )), + report_aggregation_state, + ), + )); } // Store data to datastore. 
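+            // Each report aggregation records the prepare step computed above, so a
+            // replayed aggregation job request can later be answered from the datastore
+            // rather than by re-running preparation.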
let req = Arc::new(req); let min_client_timestamp = req - .report_shares() + .report_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|report_init| *report_init.report_share().metadata().time()) .min() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let max_client_timestamp = req - .report_shares() + .report_inits() .iter() - .map(|report_share| report_share.metadata().time()) + .map(|report_init| *report_init.report_share().metadata().time()) .max() .ok_or_else(|| Error::EmptyAggregation(*task.id()))?; let client_timestamp_interval = Interval::new( - *min_client_timestamp, + min_client_timestamp, max_client_timestamp - .difference(min_client_timestamp)? + .difference(&min_client_timestamp)? .add(&Duration::from_seconds(1))?, )?; let aggregation_job = Arc::new(AggregationJob::::new( @@ -1649,14 +1679,14 @@ impl VdafOps { ); Box::pin(async move { - for mut share_data in report_share_data.iter_mut() { + for mut report_share_data in &mut report_share_data { // Verify that we haven't seen this report ID and aggregation parameter // before in another aggregation job, and that the report isn't for a batch // interval that has already started collection. let (report_aggregation_exists, conflicting_aggregate_share_jobs) = try_join!( tx.check_other_report_aggregation_exists::( task.id(), - share_data.report_share.metadata().id(), + report_share_data.report_share.metadata().id(), aggregation_job.aggregation_parameter(), aggregation_job.id(), ), @@ -1665,12 +1695,31 @@ impl VdafOps { &vdaf, task.id(), req.batch_selector().batch_identifier(), - share_data.report_share.metadata() + report_share_data.report_share.metadata() ), )?; - share_data.existing_report_aggregation = report_aggregation_exists; - share_data.conflicting_aggregate_share = !conflicting_aggregate_share_jobs.is_empty(); + if report_aggregation_exists { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportReplayed)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::ReportReplayed)) + )); + } else if !conflicting_aggregate_share_jobs.is_empty() { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::BatchCollected)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected)) + )); + } } // Write aggregation job. @@ -1701,73 +1750,55 @@ impl VdafOps { Err(e) => return Err(e), }; - // Construct a response and write any new report shares and report aggregations - // as we go. 
- let mut accumulator = Accumulator::::new( - Arc::clone(&task), - batch_aggregation_shard_count, - aggregation_job.aggregation_parameter().clone(), - ); - - let mut prep_steps = Vec::new(); - for report_share_data in report_share_data - { - if report_share_data.existing_report_aggregation { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::ReportReplayed), - )); - continue; - } - if report_share_data.conflicting_aggregate_share { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), - )); - continue; - } + if !replayed_request { + let mut accumulator = Accumulator::::new( + Arc::clone(&task), + batch_aggregation_shard_count, + aggregation_job.aggregation_parameter().clone(), + ); - if !replayed_request { - // Write client report & report aggregation. - if let Err(error) = tx.put_report_share( - task.id(), - &report_share_data.report_share - ).await { - match error { - datastore::Error::MutationTargetAlreadyExists => { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::ReportReplayed), - )); - continue; + for report_share_data in &mut report_share_data + { + if !replayed_request { + // Write client report & report aggregation. + if let Err(err) = tx.put_report_share( + task.id(), + &report_share_data.report_share + ).await { + match err { + datastore::Error::MutationTargetAlreadyExists => { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportReplayed)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::ReportReplayed)) + )); + }, + err => return Err(err), } - e => return Err(e), } - } - tx.put_report_aggregation(&report_share_data.report_aggregation).await?; - } + tx.put_report_aggregation(&report_share_data.report_aggregation).await?; + + if let ReportAggregationState::::Finished(output_share) = + report_share_data.report_aggregation.state() + { + accumulator.update( + aggregation_job.partial_batch_identifier(), + report_share_data.report_share.metadata().id(), + report_share_data.report_share.metadata().time(), + output_share, + )?; + } - if let ReportAggregationState::::Finished(output_share) = - report_share_data.report_aggregation.state() - { - accumulator.update( - aggregation_job.partial_batch_identifier(), - report_share_data.report_share.metadata().id(), - report_share_data.report_share.metadata().time(), - output_share, - )?; + } } - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - report_share_data.prep_result.clone(), - )); - } - - if !replayed_request { accumulator.flush_to_datastore(tx, &vdaf).await?; } - Ok(prep_steps) + Ok(report_share_data.into_iter().map(|data| data.report_aggregation.last_prep_step().unwrap().clone()).collect()) }) }) .await?; @@ -1872,9 +1903,9 @@ impl VdafOps { } } } - return Self::replay_aggregation_job_round::( + return Ok(Self::aggregation_job_resp_for::( report_aggregations, - ); + )); } else if helper_aggregation_job.round().increment() != leader_aggregation_job.round() { @@ -1895,14 +1926,14 @@ impl VdafOps { // compute the next round of prepare messages and state. 
Self::step_aggregation_job( tx, - &task, - &vdaf, + task, + vdaf, batch_aggregation_shard_count, helper_aggregation_job, report_aggregations, - &leader_aggregation_job, + leader_aggregation_job, request_hash, - &aggregate_step_failure_counter, + aggregate_step_failure_counter, ) .await }) @@ -1920,9 +1951,9 @@ impl VdafOps { ) -> Result<(), Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_create_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -1931,9 +1962,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_create_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -2067,9 +2098,9 @@ impl VdafOps { ) -> Result>, Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_get_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -2078,9 +2109,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_get_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -2239,9 +2270,9 @@ impl VdafOps { ) -> Result<(), Error> { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_delete_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -2250,9 +2281,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_delete_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -2316,9 +2347,9 @@ impl VdafOps { ) -> Result { match task.query_type() { task::QueryType::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_aggregate_share_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, VdafType, _, @@ -2327,9 +2358,9 @@ impl VdafOps { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LEN) => { Self::handle_aggregate_share_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, VdafType, _, @@ -3238,8 +3269,7 @@ mod tests { datastore::{ models::{ AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation, - CollectionJob, CollectionJobState, PrepareMessageOrShare, ReportAggregation, - ReportAggregationState, + CollectionJob, CollectionJobState, ReportAggregation, ReportAggregationState, }, test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, @@ -3252,8 +3282,8 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, report_id::ReportIdChecksumExt, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, - test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LEN}, + test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf, VdafTranscript}, time::{Clock, DurationExt, IntervalExt, MockClock, RealClock, TimeExt}, }; use janus_messages::{ @@ -3265,16 +3295,19 @@ mod tests { CollectionReq, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, Query, Report, ReportId, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError, Role, + TaskId, Time, }; use opentelemetry::global::meter; use prio::{ codec::{Decode, Encode}, - field::Field64, + idpf::IdpfInput, vdaf::{ self, - prio3::{Prio3, Prio3Count}, - AggregateShare, Aggregator as _, Client as VdafClient, OutputShare, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + prio3::Prio3Count, + Aggregator as _, Client as VdafClient, }, }; use rand::random; @@ -3288,7 +3321,7 @@ mod tests { Filter, Rejection, }; - const DUMMY_VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; + const DUMMY_VERIFY_KEY_LEN: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LEN; pub(crate) fn default_aggregator_config() -> Config { // Enable upload write batching & batch aggregation sharding by default, in hopes that we @@ -4101,17 +4134,15 @@ mod tests { .run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.put_collection_job(&CollectionJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - batch_interval, - (), - CollectionJobState::Start, - )) + tx.put_collection_job( + &CollectionJob::::new( + *task.id(), + random(), + batch_interval, + (), + CollectionJobState::Start, + ), + ) .await }) }) @@ -4309,7 +4340,7 @@ mod tests { let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap(); let hpke_key = task.current_hpke_key(); - // report_share_0 is a "happy path" report. + // report_init_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), clock @@ -4322,18 +4353,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_0.id(), - &(), + &0, ); - let report_share_0 = generate_helper_report_share::( + let report_init_0 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_0, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_1 fails decryption. 
+ // report_init_1 fails decryption. let report_metadata_1 = ReportMetadata::new( random(), clock @@ -4346,17 +4376,16 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_1.id(), - &(), + &0, ); - let report_share_1 = generate_helper_report_share::( + let report_init_1 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_1.clone(), hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - let encrypted_input_share = report_share_1.encrypted_input_share(); + let encrypted_input_share = report_init_1.report_share().encrypted_input_share(); let mut corrupted_payload = encrypted_input_share.payload().to_vec(); corrupted_payload[0] ^= 0xFF; let corrupted_input_share = HpkeCiphertext::new( @@ -4364,14 +4393,15 @@ mod tests { encrypted_input_share.encapsulated_key().to_vec(), corrupted_payload, ); - let encoded_public_share = transcript.public_share.get_encoded(); let report_share_1 = ReportShare::new( report_metadata_1, - encoded_public_share.clone(), + transcript.public_share.get_encoded(), corrupted_input_share, ); + let report_init_1 = + ReportPrepInit::new(report_share_1, report_init_1.leader_prep_share().to_vec()); - // report_share_2 fails decoding due to an issue with the input share. + // report_init_2 fails decoding due to an issue with the input share. let report_metadata_2 = ReportMetadata::new( random(), clock @@ -4384,19 +4414,25 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_2.id(), - &(), + &0, ); let mut input_share_bytes = transcript.input_shares[1].get_encoded(); input_share_bytes.push(0); // can no longer be decoded. - let report_share_2 = generate_helper_report_share_for_plaintext( + let report_init_2 = generate_helper_report_init_for_plaintext( report_metadata_2.clone(), hpke_key.config(), - encoded_public_share.clone(), + transcript.public_share.get_encoded(), &input_share_bytes, - &InputShareAad::new(*task.id(), report_metadata_2, encoded_public_share).get_encoded(), + &InputShareAad::new( + *task.id(), + report_metadata_2, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ); - // report_share_3 has an unknown HPKE config ID. + // report_init_3 has an unknown HPKE config ID. let report_metadata_3 = ReportMetadata::new( random(), clock @@ -4409,7 +4445,7 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_3.id(), - &(), + &0, ); let wrong_hpke_config = loop { let hpke_config = generate_test_hpke_config_and_private_key().config().clone(); @@ -4418,16 +4454,15 @@ mod tests { } break hpke_config; }; - let report_share_3 = generate_helper_report_share::( + let report_init_3 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_3, &wrong_hpke_config, - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_4 has already been aggregated in another aggregation job, with the same + // report_init_4 has already been aggregated in another aggregation job, with the same // aggregation parameter. 
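// Replay detection keys on the pair (report ID, aggregation parameter): the
// same report may be aggregated under several parameters, but only once per
// parameter. A sketch of the rule that check_other_report_aggregation_exists
// enforces, over a hypothetical in-memory set instead of the datastore
// (assumes the ID and parameter types implement Hash + Eq):
use std::collections::HashSet;

fn is_replay<P: std::hash::Hash + Eq>(
    seen: &HashSet<(ReportId, P)>,
    report_id: ReportId,
    param: P,
) -> bool {
    // report_init_4 below trips this check (same parameter as the new job);
    // report_init_8 does not (different parameter).
    seen.contains(&(report_id, param))
}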
let report_metadata_4 = ReportMetadata::new( random(), @@ -4441,18 +4476,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_4.id(), - &(), + &0, ); - let report_share_4 = generate_helper_report_share::( + let report_init_4 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_4, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_5 falls into a batch that has already been collected. + // report_init_5 falls into a batch that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -4468,18 +4502,17 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_5.id(), - &(), + &0, ); - let report_share_5 = generate_helper_report_share::( + let report_init_5 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_5, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - // report_share_6 fails decoding due to an issue with the public share. + // report_init_6 fails decoding due to an issue with the public share. let public_share_6 = Vec::from([0]); let report_metadata_6 = ReportMetadata::new( random(), @@ -4493,17 +4526,18 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_6.id(), - &(), + &0, ); - let report_share_6 = generate_helper_report_share_for_plaintext( + let report_init_6 = generate_helper_report_init_for_plaintext( report_metadata_6.clone(), hpke_key.config(), public_share_6.clone(), &transcript.input_shares[1].get_encoded(), &InputShareAad::new(*task.id(), report_metadata_6, public_share_6).get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ); - // report_share_7 fails due to having repeated extensions. + // report_init_7 fails due to having repeated extensions. let report_metadata_7 = ReportMetadata::new( random(), clock @@ -4516,21 +4550,20 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(0), report_metadata_7.id(), - &(), + &0, ); - let report_share_7 = generate_helper_report_share::( + let report_init_7 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_7, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::from([ Extension::new(ExtensionType::Tbd, Vec::new()), Extension::new(ExtensionType::Tbd, Vec::new()), ]), - &transcript.input_shares[0], ); - // report_share_8 has already been aggregated in another aggregation job, with a different + // report_init_8 has already been aggregated in another aggregation job, with a different // aggregation parameter. let report_metadata_8 = ReportMetadata::new( random(), @@ -4544,30 +4577,60 @@ mod tests { verify_key.as_bytes(), &dummy_vdaf::AggregationParam(1), report_metadata_8.id(), - &(), + &0, ); - let report_share_8 = generate_helper_report_share::( + let report_init_8 = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), report_metadata_8, hpke_key.config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], + ); + + // report_init_9 fails decoding due to an issue with the leader preparation share. 
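// ReportPrepInit, introduced by this patch, pairs the helper's encrypted
// report share with the leader's first-round prepare share so the helper can
// begin preparation immediately. A simplified sketch of the shape implied by
// the constructor and accessors used in these tests (the real janus_messages
// type additionally implements the wire-codec traits):
pub struct ReportPrepInit {
    report_share: ReportShare,
    leader_prep_share: Vec<u8>,
}

impl ReportPrepInit {
    pub fn new(report_share: ReportShare, leader_prep_share: Vec<u8>) -> Self {
        Self { report_share, leader_prep_share }
    }

    pub fn report_share(&self) -> &ReportShare {
        &self.report_share
    }

    pub fn leader_prep_share(&self) -> &[u8] {
        &self.leader_prep_share
    }
}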
+ let report_metadata_9 = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &vdaf, + verify_key.as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata_9.id(), + &0, + ); + let report_init_9 = generate_helper_report_init_for_plaintext( + report_metadata_9.clone(), + hpke_key.config(), + transcript.public_share.get_encoded(), + &transcript.input_shares[1].get_encoded(), + &InputShareAad::new( + *task.id(), + report_metadata_9, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + Vec::from([0]), ); let (conflicting_aggregation_job, non_conflicting_aggregation_job) = datastore .run_tx(|tx| { - let (task, report_share_4, report_share_8) = - (task.clone(), report_share_4.clone(), report_share_8.clone()); + let (task, report_init_4, report_init_8) = + (task.clone(), report_init_4.clone(), report_init_8.clone()); Box::pin(async move { tx.put_task(&task).await?; - // report_share_4 and report_share_8 are already in the datastore as they were + // report_init_4 and report_init_8 are already in the datastore as they were // referenced by existing aggregation jobs. - tx.put_report_share(task.id(), &report_share_4).await?; - tx.put_report_share(task.id(), &report_share_8).await?; + tx.put_report_share(task.id(), report_init_4.report_share()) + .await?; + tx.put_report_share(task.id(), report_init_8.report_share()) + .await?; - // Put in an aggregation job and report aggregation for report_share_4. It uses + // Put in an aggregation job and report aggregation for report_init_4. It uses // the same aggregation parameter as the aggregation job this test will later // add and so should cause report_share_4 to fail to prepare. let conflicting_aggregation_job = AggregationJob::new( @@ -4588,9 +4651,10 @@ mod tests { tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( *task.id(), *conflicting_aggregation_job.id(), - *report_share_4.metadata().id(), - *report_share_4.metadata().time(), + *report_init_4.report_share().metadata().id(), + *report_init_4.report_share().metadata().time(), 0, + None, ReportAggregationState::Start, )) .await @@ -4618,9 +4682,10 @@ mod tests { tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( *task.id(), *non_conflicting_aggregation_job.id(), - *report_share_8.metadata().id(), - *report_share_8.metadata().time(), + *report_init_8.report_share().metadata().id(), + *report_init_8.report_share().metadata().time(), 0, + None, ReportAggregationState::Start, )) .await @@ -4654,15 +4719,16 @@ mod tests { dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), Vec::from([ - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - report_share_6.clone(), - report_share_7.clone(), - report_share_8.clone(), + report_init_0.clone(), + report_init_1.clone(), + report_init_2.clone(), + report_init_3.clone(), + report_init_4.clone(), + report_init_5.clone(), + report_init_6.clone(), + report_init_7.clone(), + report_init_8.clone(), + report_init_9.clone(), ]), ); @@ -4684,64 +4750,101 @@ mod tests { let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap(); // Validate response. 
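// The assertions below exercise the reshaped PrepareStepResult. A sketch of
// its variants as used throughout this patch (the real type lives in
// janus_messages and carries codec impls; comments describe observed usage):
pub enum PrepareStepResult {
    // Preparation finished for this report; prep_msg carries any final
    // prepare message (empty when there is nothing left to send).
    Finished { prep_msg: Vec<u8> },
    // Another round is needed; a prepare message and the sender's next-round
    // prepare share travel together.
    Continued { prep_msg: Vec<u8>, prep_share: Vec<u8> },
    // Preparation failed for this report with a protocol-level error.
    Failed(ReportShareError),
}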
- assert_eq!(aggregate_resp.prepare_steps().len(), 9); + assert_eq!(aggregate_resp.prepare_steps().len(), 10); let prepare_step_0 = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step_0.report_id(), report_share_0.metadata().id()); - assert_matches!(prepare_step_0.result(), &PrepareStepResult::Continued(..)); + assert_eq!( + prepare_step_0.report_id(), + report_init_0.report_share().metadata().id() + ); + assert_matches!(prepare_step_0.result(), &PrepareStepResult::Finished { .. }); let prepare_step_1 = aggregate_resp.prepare_steps().get(1).unwrap(); - assert_eq!(prepare_step_1.report_id(), report_share_1.metadata().id()); + assert_eq!( + prepare_step_1.report_id(), + report_init_1.report_share().metadata().id() + ); assert_matches!( prepare_step_1.result(), &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) ); let prepare_step_2 = aggregate_resp.prepare_steps().get(2).unwrap(); - assert_eq!(prepare_step_2.report_id(), report_share_2.metadata().id()); + assert_eq!( + prepare_step_2.report_id(), + report_init_2.report_share().metadata().id() + ); assert_matches!( prepare_step_2.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) ); let prepare_step_3 = aggregate_resp.prepare_steps().get(3).unwrap(); - assert_eq!(prepare_step_3.report_id(), report_share_3.metadata().id()); + assert_eq!( + prepare_step_3.report_id(), + report_init_3.report_share().metadata().id() + ); assert_matches!( prepare_step_3.result(), &PrepareStepResult::Failed(ReportShareError::HpkeUnknownConfigId) ); let prepare_step_4 = aggregate_resp.prepare_steps().get(4).unwrap(); - assert_eq!(prepare_step_4.report_id(), report_share_4.metadata().id()); + assert_eq!( + prepare_step_4.report_id(), + report_init_4.report_share().metadata().id() + ); assert_eq!( prepare_step_4.result(), &PrepareStepResult::Failed(ReportShareError::ReportReplayed) ); let prepare_step_5 = aggregate_resp.prepare_steps().get(5).unwrap(); - assert_eq!(prepare_step_5.report_id(), report_share_5.metadata().id()); + assert_eq!( + prepare_step_5.report_id(), + report_init_5.report_share().metadata().id() + ); assert_eq!( prepare_step_5.result(), &PrepareStepResult::Failed(ReportShareError::BatchCollected) ); let prepare_step_6 = aggregate_resp.prepare_steps().get(6).unwrap(); - assert_eq!(prepare_step_6.report_id(), report_share_6.metadata().id()); + assert_eq!( + prepare_step_6.report_id(), + report_init_6.report_share().metadata().id() + ); assert_eq!( prepare_step_6.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), ); let prepare_step_7 = aggregate_resp.prepare_steps().get(7).unwrap(); - assert_eq!(prepare_step_7.report_id(), report_share_7.metadata().id()); + assert_eq!( + prepare_step_7.report_id(), + report_init_7.report_share().metadata().id() + ); assert_eq!( prepare_step_7.result(), &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), ); let prepare_step_8 = aggregate_resp.prepare_steps().get(8).unwrap(); - assert_eq!(prepare_step_8.report_id(), report_share_8.metadata().id()); - assert_matches!(prepare_step_8.result(), &PrepareStepResult::Continued(..)); + assert_eq!( + prepare_step_8.report_id(), + report_init_8.report_share().metadata().id() + ); + assert_matches!(prepare_step_8.result(), &PrepareStepResult::Finished { .. 
}); + + let prepare_step_9 = aggregate_resp.prepare_steps().get(9).unwrap(); + assert_eq!( + prepare_step_9.report_id(), + report_init_9.report_share().metadata().id() + ); + assert_matches!( + prepare_step_9.result(), + &PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage) + ); // Check aggregation job in datastore. let aggregation_jobs = datastore @@ -4770,7 +4873,7 @@ mod tests { } else if aggregation_job.task_id().eq(task.id()) && aggregation_job.id().eq(&aggregation_job_id) && aggregation_job.partial_batch_identifier().eq(&()) - && aggregation_job.state().eq(&AggregationJobState::InProgress) + && aggregation_job.state().eq(&AggregationJobState::Finished) { saw_new_aggregation_job = true; } @@ -4792,16 +4895,16 @@ mod tests { // This report has the same ID as the previous one, but a different timestamp. let mutated_timestamp_report_metadata = ReportMetadata::new( - *test_case.report_shares[0].metadata().id(), + *test_case.report_inits[0].report_share().metadata().id(), test_case .clock .now() .add(test_case.task.time_precision()) .unwrap(), ); - let mutated_timestamp_report_share = test_case - .report_share_generator - .next_with_metadata(mutated_timestamp_report_metadata) + let mutated_timestamp_report_init = test_case + .report_init_generator + .next_with_metadata(mutated_timestamp_report_metadata, &0) .0; // Send another aggregate job re-using the same report ID but with a different timestamp. It @@ -4809,7 +4912,7 @@ mod tests { let request = AggregationJobInitializeReq::new( other_aggregation_parameter.get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([mutated_timestamp_report_share.clone()]), + Vec::from([mutated_timestamp_report_init.clone()]), ); let mut response = @@ -4823,7 +4926,7 @@ mod tests { let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); assert_eq!( prepare_step.report_id(), - mutated_timestamp_report_share.metadata().id() + mutated_timestamp_report_init.report_share().metadata().id() ); assert_matches!( prepare_step.result(), @@ -4845,8 +4948,14 @@ mod tests { .await .unwrap(); assert_eq!(client_reports.len(), 2); - assert_eq!(&client_reports[0], test_case.report_shares[0].metadata()); - assert_eq!(&client_reports[1], test_case.report_shares[1].metadata()); + assert_eq!( + &client_reports[0], + test_case.report_inits[0].report_share().metadata() + ); + assert_eq!( + &client_reports[1], + test_case.report_inits[1].report_share().metadata() + ); } #[tokio::test] @@ -4863,28 +4972,35 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = ephemeral_datastore.datastore(clock.clone()); - let hpke_key = task.current_hpke_key(); datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), + report_metadata, + task.current_hpke_key().config(), + &transcript, Vec::new(), - &dummy_vdaf::InputShare::default(), ); + let request = AggregationJobInitializeReq::new( 
dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([report_init.clone()]), ); // Create aggregator filter, send request, and parse response. @@ -4905,7 +5021,10 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + assert_eq!( + prepare_step.report_id(), + report_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -4926,28 +5045,35 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = ephemeral_datastore.datastore(clock.clone()); - let hpke_key = task.current_hpke_key(); datastore.put_task(&task).await.unwrap(); - let report_share = generate_helper_report_share::( + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( *task.id(), - ReportMetadata::new( - random(), - clock - .now() - .to_batch_interval_start(task.time_precision()) - .unwrap(), - ), - hpke_key.config(), - &(), + report_metadata, + task.current_hpke_key().config(), + &transcript, Vec::new(), - &dummy_vdaf::InputShare::default(), ); + let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone()]), + Vec::from([report_init.clone()]), ); // Create aggregator filter, send request, and parse response. 
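// With the leader's prepare share delivered at init time, the helper decodes
// and combines it during the same request. An undecodable share surfaces as
// UnrecognizedMessage (report_init_9 in the earlier test), while a share that
// decodes but fails VDAF preparation surfaces as VdafPrepError (the tests on
// either side of this hunk). A sketch of that mapping; the two closures are
// hypothetical stand-ins for the prio decode/prepare calls:
fn leader_share_failure<S, E>(
    bytes: &[u8],
    decode_prep_share: impl Fn(&[u8]) -> Result<S, E>,
    combine_prep_shares: impl Fn(S) -> Result<S, E>,
) -> Option<ReportShareError> {
    let share = match decode_prep_share(bytes) {
        Ok(share) => share,
        Err(_) => return Some(ReportShareError::UnrecognizedMessage),
    };
    match combine_prep_shares(share) {
        // Preparation proceeds; the report gets Finished/Continued instead.
        Ok(_) => None,
        Err(_) => Some(ReportShareError::VdafPrepError),
    }
}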
@@ -4968,7 +5094,10 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!(prepare_step.report_id(), report_share.metadata().id()); + assert_eq!( + prepare_step.report_id(), + report_init.report_share().metadata().id() + ); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -4991,24 +5120,32 @@ mod tests { datastore.put_task(&task).await.unwrap(); - let report_share = ReportShare::new( - ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - Vec::from("PUBLIC"), - HpkeCiphertext::new( - // bogus, but we never get far enough to notice - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + ); + let transcript = run_vdaf( + &dummy_vdaf::Vdaf::new(), + task.primary_vdaf_verify_key().unwrap().as_bytes(), + &dummy_vdaf::AggregationParam(0), + report_metadata.id(), + &0, + ); + let report_init = generate_helper_report_init::<0, dummy_vdaf::Vdaf>( + *task.id(), + report_metadata, + task.current_hpke_key().config(), + &transcript, + Vec::new(), ); let request = AggregationJobInitializeReq::new( dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([report_share.clone(), report_share]), + Vec::from([report_init.clone(), report_init]), ); let filter = @@ -5042,7 +5179,7 @@ mod tests { let aggregation_job_id = random(); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -5050,12 +5187,16 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); - let hpke_key = task.current_hpke_key(); + let vdaf = Arc::new(Poplar1::new_sha3(1)); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[false]); - // report_share_0 is a "happy path" report. + // report_init_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), clock @@ -5066,22 +5207,21 @@ mod tests { let transcript_0 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, prep_share_0) = transcript_0.helper_prep_state(0); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let (prep_state_0, _) = transcript_0.helper_prep_state(1); + let prep_msg_0 = transcript_0.prepare_messages[1].clone(); + let report_init_0 = generate_helper_report_init::>( *task.id(), report_metadata_0.clone(), - hpke_key.config(), - &transcript_0.public_share, + task.current_hpke_key().config(), + &transcript_0, Vec::new(), - &transcript_0.input_shares[1], ); - // report_share_1 is omitted by the leader's request. + // report_init_1 is omitted by the leader's request. 
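// Waiting report aggregations that the leader's continue request omits are
// failed with ReportDropped, as the datastore assertions for this report
// check further down. A sketch of the rule, with a hypothetical `mentioned`
// set of report IDs taken from the request and a simplified, non-generic
// ReportAggregation:
fn drop_if_omitted(
    mentioned: &std::collections::HashSet<ReportId>,
    report_id: ReportId,
    report_aggregation: ReportAggregation,
) -> ReportAggregation {
    if mentioned.contains(&report_id) {
        report_aggregation
    } else {
        report_aggregation
            .with_state(ReportAggregationState::Failed(ReportShareError::ReportDropped))
    }
}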
let report_metadata_1 = ReportMetadata::new( random(), clock @@ -5092,22 +5232,20 @@ mod tests { let transcript_1 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - - let (prep_state_1, prep_share_1) = transcript_1.helper_prep_state(0); - let report_share_1 = generate_helper_report_share::( + let (prep_state_1, _) = transcript_1.helper_prep_state(1); + let report_init_1 = generate_helper_report_init::>( *task.id(), report_metadata_1.clone(), - hpke_key.config(), - &transcript_1.public_share, + task.current_hpke_key().config(), + &transcript_1, Vec::new(), - &transcript_1.input_shares[1], ); - // report_share_2 falls into a batch that has already been collected. + // report_init_2 falls into a batch that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( task.time_precision().as_seconds() / 2, )); @@ -5121,44 +5259,44 @@ mod tests { let transcript_2 = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, prep_share_2) = transcript_2.helper_prep_state(0); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let (prep_state_2, _) = transcript_2.helper_prep_state(1); + let prep_msg_2 = transcript_2.prepare_messages[1].clone(); + let report_init_2 = generate_helper_report_init::>( *task.id(), report_metadata_2.clone(), - hpke_key.config(), - &transcript_2.public_share, + task.current_hpke_key().config(), + &transcript_2, Vec::new(), - &transcript_2.input_shares[1], ); + let aggregate_share = vdaf + .aggregate( + &aggregation_param, + [ + transcript_0.output_share(Role::Helper).clone(), + transcript_1.output_share(Role::Helper).clone(), + transcript_2.output_share(Role::Helper).clone(), + ], + ) + .unwrap(); + datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param, aggregate_share) = (task.clone(), aggregation_param.clone(), aggregate_share.clone()); let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - ); - let (prep_share_0, prep_share_1, prep_share_2) = ( - prep_share_0.clone(), - prep_share_1.clone(), - prep_share_2.clone(), + report_init_0.report_share().clone(), + report_init_1.report_share().clone(), + report_init_2.report_share().clone(), ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), prep_state_2.clone(), ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); Box::pin(async move { tx.put_task(&task).await?; @@ -5168,13 +5306,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -5183,50 +5321,44 @@ mod tests { )) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_0.id(), - *report_metadata_0.time(), + *report_share_0.metadata().id(), + *report_share_0.metadata().time(), 0, - ReportAggregationState::Waiting( - prep_state_0, - 
PrepareMessageOrShare::Helper(prep_share_0), - ), + None, + ReportAggregationState::Waiting(prep_state_0, None), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_1.id(), - *report_metadata_1.time(), + *report_share_1.metadata().id(), + *report_share_1.metadata().time(), 1, - ReportAggregationState::Waiting( - prep_state_1, - PrepareMessageOrShare::Helper(prep_share_1), - ), + None, + ReportAggregationState::Waiting(prep_state_1, None), ), ) .await?; - tx.put_report_aggregation::( + tx.put_report_aggregation::>( &ReportAggregation::new( *task.id(), aggregation_job_id, - *report_metadata_2.id(), - *report_metadata_2.time(), + *report_share_2.metadata().id(), + *report_share_2.metadata().time(), 2, - ReportAggregationState::Waiting( - prep_state_2, - PrepareMessageOrShare::Helper(prep_share_2), - ), + None, + ReportAggregationState::Waiting(prep_state_2, None), ), ) .await?; - tx.put_aggregate_share_job::( + tx.put_aggregate_share_job::>( &AggregateShareJob::new( *task.id(), Interval::new( @@ -5234,8 +5366,8 @@ mod tests { *task.time_precision(), ) .unwrap(), - (), - AggregateShare::from(OutputShare::from(Vec::from([Field64::from(7)]))), + aggregation_param, + aggregate_share, 0, ReportIdChecksum::default(), ), @@ -5251,11 +5383,15 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_0.get_encoded(), + }, ), PrepareStep::new( *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_2.get_encoded(), + }, ), ]), ); @@ -5271,7 +5407,12 @@ mod tests { assert_eq!( aggregate_resp, AggregationJobResp::new(Vec::from([ - PrepareStep::new(*report_metadata_0.id(), PrepareStepResult::Finished), + PrepareStep::new( + *report_metadata_0.id(), + PrepareStepResult::Finished { + prep_msg: Vec::new() + } + ), PrepareStep::new( *report_metadata_2.id(), PrepareStepResult::Failed(ReportShareError::BatchCollected), @@ -5280,38 +5421,39 @@ mod tests { ); // Validate datastore. 
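// Each failure is recorded twice: as the aggregation state and as the last
// prepare step, so a replayed request can echo the identical result. A sketch
// of the pattern repeated throughout this patch, assuming ReportShareError:
// Clone and a simplified, non-generic ReportAggregation:
fn record_failure(
    report_aggregation: ReportAggregation,
    report_id: ReportId,
    err: ReportShareError,
) -> ReportAggregation {
    report_aggregation
        .with_state(ReportAggregationState::Failed(err.clone()))
        .with_last_prep_step(Some(PrepareStep::new(
            report_id,
            PrepareStepResult::Failed(err),
        )))
}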
- let (aggregation_job, report_aggregations) = - datastore - .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); - Box::pin(async move { - let aggregation_job = tx - .get_aggregation_job::( + let (aggregation_job, report_aggregations) = datastore + .run_tx(|tx| { + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) - .await.unwrap().unwrap(); - let report_aggregations = tx - .get_report_aggregations_for_aggregation_job( - vdaf.as_ref(), - &Role::Helper, - task.id(), - &aggregation_job_id, - ) - .await - .unwrap(); - Ok((aggregation_job, report_aggregations)) - }) + .await + .unwrap() + .unwrap(); + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Helper, + task.id(), + &aggregation_job_id, + ) + .await + .unwrap(); + Ok((aggregation_job, report_aggregations)) }) - .await - .unwrap(); + }) + .await + .unwrap(); assert_eq!( aggregation_job, AggregationJob::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -5329,6 +5471,12 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, + Some(PrepareStep::new( + *report_metadata_0.id(), + PrepareStepResult::Finished { + prep_msg: Vec::new() + } + )), ReportAggregationState::Finished( transcript_0.output_share(Role::Helper).clone() ), @@ -5339,6 +5487,7 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, + None, ReportAggregationState::Failed(ReportShareError::ReportDropped), ), ReportAggregation::new( @@ -5347,6 +5496,10 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, + Some(PrepareStep::new( + *report_metadata_2.id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected) + )), ReportAggregationState::Failed(ReportShareError::BatchCollected), ) ]) @@ -5359,7 +5512,7 @@ mod tests { let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Helper, ) .build(); @@ -5375,12 +5528,16 @@ mod tests { .unwrap(), ); - let vdaf = Prio3::new_count(2).unwrap(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); - let hpke_key = task.current_hpke_key(); + let vdaf = Poplar1::new_sha3(1); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); - // report_share_0 is a "happy path" report. + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let measurement = IdpfInput::from_bools(&[false]); + + // report_init_0 is a "happy path" report. 
let report_metadata_0 = ReportMetadata::new( random(), first_batch_interval_clock @@ -5391,23 +5548,22 @@ mod tests { let transcript_0 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_0.id(), - &0, + &measurement, ); - let (prep_state_0, prep_share_0) = transcript_0.helper_prep_state(0); + let (prep_state_0, _) = transcript_0.helper_prep_state(1); let out_share_0 = transcript_0.output_share(Role::Helper); - let prep_msg_0 = transcript_0.prepare_messages[0].clone(); - let report_share_0 = generate_helper_report_share::( + let prep_msg_0 = transcript_0.prepare_messages[1].clone(); + let report_init_0 = generate_helper_report_init::>( *task.id(), report_metadata_0.clone(), - hpke_key.config(), - &transcript_0.public_share, + task.current_hpke_key().config(), + &transcript_0, Vec::new(), - &transcript_0.input_shares[1], ); - // report_share_1 is another "happy path" report to exercise in-memory accumulation of + // report_init_1 is another "happy path" report to exercise in-memory accumulation of // output shares let report_metadata_1 = ReportMetadata::new( random(), @@ -5419,23 +5575,22 @@ mod tests { let transcript_1 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_1.id(), - &0, + &measurement, ); - let (prep_state_1, prep_share_1) = transcript_1.helper_prep_state(0); + let (prep_state_1, _) = transcript_1.helper_prep_state(1); let out_share_1 = transcript_1.output_share(Role::Helper); - let prep_msg_1 = transcript_1.prepare_messages[0].clone(); - let report_share_1 = generate_helper_report_share::( + let prep_msg_1 = transcript_1.prepare_messages[1].clone(); + let report_init_1 = generate_helper_report_init::>( *task.id(), report_metadata_1.clone(), - hpke_key.config(), - &transcript_1.public_share, + task.current_hpke_key().config(), + &transcript_1, Vec::new(), - &transcript_1.input_shares[1], ); - // report share 2 aggregates successfully, but into a distinct batch aggregation. + // report_init_2 aggregates successfully, but into a distinct batch aggregation. 
let report_metadata_2 = ReportMetadata::new( random(), second_batch_interval_clock @@ -5446,45 +5601,34 @@ mod tests { let transcript_2 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_2.id(), - &0, + &measurement, ); - let (prep_state_2, prep_share_2) = transcript_2.helper_prep_state(0); + let (prep_state_2, _) = transcript_2.helper_prep_state(1); let out_share_2 = transcript_2.output_share(Role::Helper); - let prep_msg_2 = transcript_2.prepare_messages[0].clone(); - let report_share_2 = generate_helper_report_share::( + let prep_msg_2 = transcript_2.prepare_messages[1].clone(); + let report_init_2 = generate_helper_report_init::>( *task.id(), report_metadata_2.clone(), - hpke_key.config(), - &transcript_2.public_share, + task.current_hpke_key().config(), + &transcript_2, Vec::new(), - &transcript_2.input_shares[1], ); datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param) = (task.clone(), aggregation_param.clone()); let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - ); - let (prep_share_0, prep_share_1, prep_share_2) = ( - prep_share_0.clone(), - prep_share_1.clone(), - prep_share_2.clone(), + report_init_0.report_share().clone(), + report_init_1.report_share().clone(), + report_init_2.report_share().clone(), ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), prep_state_2.clone(), ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); Box::pin(async move { tx.put_task(&task).await?; @@ -5494,13 +5638,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -5510,48 +5654,42 @@ mod tests { .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_0.id(), - *report_metadata_0.time(), + *report_share_0.metadata().id(), + *report_share_0.metadata().time(), 0, - ReportAggregationState::Waiting( - prep_state_0, - PrepareMessageOrShare::Helper(prep_share_0), - ), + None, + ReportAggregationState::Waiting(prep_state_0, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_1.id(), - *report_metadata_1.time(), + *report_share_1.metadata().id(), + *report_share_1.metadata().time(), 1, - ReportAggregationState::Waiting( - prep_state_1, - PrepareMessageOrShare::Helper(prep_share_1), - ), + None, + ReportAggregationState::Waiting(prep_state_1, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_0, - *report_metadata_2.id(), - *report_metadata_2.time(), + *report_share_2.metadata().id(), + *report_share_2.metadata().time(), 2, - ReportAggregationState::Waiting( - prep_state_2, - PrepareMessageOrShare::Helper(prep_share_2), - ), + None, + ReportAggregationState::Waiting(prep_state_2, None), )) .await?; @@ 
-5566,15 +5704,21 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(prep_msg_0.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_0.get_encoded(), + }, ), PrepareStep::new( *report_metadata_1.id(), - PrepareStepResult::Continued(prep_msg_1.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_1.get_encoded(), + }, ), PrepareStep::new( *report_metadata_2.id(), - PrepareStepResult::Continued(prep_msg_2.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_2.get_encoded(), + }, ), ]), ); @@ -5593,12 +5737,16 @@ mod tests { // Map the batch aggregation ordinal value to 0, as it may vary due to sharding. let batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -5613,7 +5761,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -5622,10 +5770,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -5635,12 +5783,6 @@ mod tests { }) .collect(); - let aggregate_share = vdaf - .aggregate(&(), [out_share_0.clone(), out_share_1.clone()]) - .unwrap(); - let checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) - .updated_with(report_metadata_1.id()); - assert_eq!( batch_aggregations, Vec::from([ @@ -5654,12 +5796,17 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, - aggregate_share, + vdaf.aggregate( + &aggregation_param.clone(), + [out_share_0.clone(), out_share_1.clone()], + ) + .unwrap(), 2, Interval::from_time(report_metadata_0.time()).unwrap(), - checksum, + ReportIdChecksum::for_report_id(report_metadata_0.id()) + .updated_with(report_metadata_1.id()), ), BatchAggregation::new( *task.id(), @@ -5671,9 +5818,10 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, - AggregateShare::from(out_share_2.clone()), + vdaf.aggregate(&aggregation_param, [out_share_2.clone()]) + .unwrap(), 1, Interval::from_time(report_metadata_2.time()).unwrap(), ReportIdChecksum::for_report_id(report_metadata_2.id()), @@ -5683,7 +5831,7 @@ mod tests { // Aggregate some more reports, which should get accumulated into the // batch_aggregations rows created earlier. - // report_share_3 gets aggreated into the first batch interval. + // report_init_3 gets aggregated into the first batch interval. 
let report_metadata_3 = ReportMetadata::new( random(), first_batch_interval_clock @@ -5694,23 +5842,22 @@ mod tests { let transcript_3 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_3.id(), - &0, + &measurement, ); - let (prep_state_3, prep_share_3) = transcript_3.helper_prep_state(0); + let (prep_state_3, _) = transcript_3.helper_prep_state(1); let out_share_3 = transcript_3.output_share(Role::Helper); - let prep_msg_3 = transcript_3.prepare_messages[0].clone(); - let report_share_3 = generate_helper_report_share::( + let prep_msg_3 = transcript_3.prepare_messages[1].clone(); + let report_init_3 = generate_helper_report_init::>( *task.id(), report_metadata_3.clone(), - hpke_key.config(), - &transcript_3.public_share, + task.current_hpke_key().config(), + &transcript_3, Vec::new(), - &transcript_3.input_shares[1], ); - // report_share_4 gets aggregated into the second batch interval + // report_init_4 gets aggregated into the second batch interval. let report_metadata_4 = ReportMetadata::new( random(), second_batch_interval_clock @@ -5721,23 +5868,22 @@ mod tests { let transcript_4 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_4.id(), - &0, + &measurement, ); - let (prep_state_4, prep_share_4) = transcript_4.helper_prep_state(0); + let (prep_state_4, _) = transcript_4.helper_prep_state(1); let out_share_4 = transcript_4.output_share(Role::Helper); - let prep_msg_4 = transcript_4.prepare_messages[0].clone(); - let report_share_4 = generate_helper_report_share::( + let prep_msg_4 = transcript_4.prepare_messages[1].clone(); + let report_init_4 = generate_helper_report_init::>( *task.id(), report_metadata_4.clone(), - hpke_key.config(), - &transcript_4.public_share, + task.current_hpke_key().config(), + &transcript_4, Vec::new(), - &transcript_4.input_shares[1], ); - // report share 5 also gets aggregated into the second batch interval + // report_init_5 also gets aggregated into the second batch interval. 
let report_metadata_5 = ReportMetadata::new( random(), second_batch_interval_clock @@ -5748,45 +5894,34 @@ mod tests { let transcript_5 = run_vdaf( &vdaf, verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata_5.id(), - &0, + &measurement, ); - let (prep_state_5, prep_share_5) = transcript_5.helper_prep_state(0); + let (prep_state_5, _) = transcript_5.helper_prep_state(1); let out_share_5 = transcript_5.output_share(Role::Helper); - let prep_msg_5 = transcript_5.prepare_messages[0].clone(); - let report_share_5 = generate_helper_report_share::( + let prep_msg_5 = transcript_5.prepare_messages[1].clone(); + let report_init_5 = generate_helper_report_init::>( *task.id(), report_metadata_5.clone(), - hpke_key.config(), - &transcript_5.public_share, + task.current_hpke_key().config(), + &transcript_5, Vec::new(), - &transcript_5.input_shares[1], ); datastore .run_tx(|tx| { - let task = task.clone(); + let (task, aggregation_param) = (task.clone(), aggregation_param.clone()); let (report_share_3, report_share_4, report_share_5) = ( - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - ); - let (prep_share_3, prep_share_4, prep_share_5) = ( - prep_share_3.clone(), - prep_share_4.clone(), - prep_share_5.clone(), + report_init_3.report_share().clone(), + report_init_4.report_share().clone(), + report_init_5.report_share().clone(), ); let (prep_state_3, prep_state_4, prep_state_5) = ( prep_state_3.clone(), prep_state_4.clone(), prep_state_5.clone(), ); - let (report_metadata_3, report_metadata_4, report_metadata_5) = ( - report_metadata_3.clone(), - report_metadata_4.clone(), - report_metadata_5.clone(), - ); Box::pin(async move { tx.put_report_share(task.id(), &report_share_3).await?; @@ -5794,13 +5929,13 @@ mod tests { tx.put_report_share(task.id(), &report_share_5).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -5810,48 +5945,42 @@ mod tests { .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_3.id(), - *report_metadata_3.time(), + *report_share_3.metadata().id(), + *report_share_3.metadata().time(), 3, - ReportAggregationState::Waiting( - prep_state_3, - PrepareMessageOrShare::Helper(prep_share_3), - ), + None, + ReportAggregationState::Waiting(prep_state_3, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_4.id(), - *report_metadata_4.time(), + *report_share_4.metadata().id(), + *report_share_4.metadata().time(), 4, - ReportAggregationState::Waiting( - prep_state_4, - PrepareMessageOrShare::Helper(prep_share_4), - ), + None, + ReportAggregationState::Waiting(prep_state_4, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id_1, - *report_metadata_5.id(), - *report_metadata_5.time(), + *report_share_5.metadata().id(), + *report_share_5.metadata().time(), 5, - ReportAggregationState::Waiting( - prep_state_5, - PrepareMessageOrShare::Helper(prep_share_5), - ), + None, + 
ReportAggregationState::Waiting(prep_state_5, None), )) .await?; @@ -5866,15 +5995,21 @@ mod tests { Vec::from([ PrepareStep::new( *report_metadata_3.id(), - PrepareStepResult::Continued(prep_msg_3.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_3.get_encoded(), + }, ), PrepareStep::new( *report_metadata_4.id(), - PrepareStepResult::Continued(prep_msg_4.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_4.get_encoded(), + }, ), PrepareStep::new( *report_metadata_5.id(), - PrepareStepResult::Continued(prep_msg_5.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg_5.get_encoded(), + }, ), ]), ); @@ -5895,12 +6030,16 @@ mod tests { // be the same) let mut batch_aggregations: Vec<_> = datastore .run_tx(|tx| { - let (task, vdaf, report_metadata_0) = - (task.clone(), vdaf.clone(), report_metadata_0.clone()); + let (task, vdaf, report_metadata_0, aggregation_param) = ( + task.clone(), + vdaf.clone(), + report_metadata_0.clone(), + aggregation_param.clone(), + ); Box::pin(async move { TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -5915,7 +6054,7 @@ mod tests { Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), - &(), + &aggregation_param, ) .await }) @@ -5924,10 +6063,10 @@ mod tests { .unwrap() .into_iter() .map(|agg| { - BatchAggregation::::new( + BatchAggregation::>::new( *agg.task_id(), *agg.batch_identifier(), - (), + agg.aggregation_parameter().clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -5943,7 +6082,7 @@ mod tests { let first_aggregate_share = vdaf .aggregate( - &(), + &aggregation_param, [out_share_0, out_share_1, out_share_3].into_iter().cloned(), ) .unwrap(); @@ -5953,7 +6092,7 @@ mod tests { let second_aggregate_share = vdaf .aggregate( - &(), + &aggregation_param, [out_share_2, out_share_4, out_share_5].into_iter().cloned(), ) .unwrap(); @@ -5974,7 +6113,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param.clone(), 0, first_aggregate_share, 3, @@ -5991,7 +6130,7 @@ mod tests { *task.time_precision() ) .unwrap(), - (), + aggregation_param, 0, second_aggregate_share, 3, @@ -6003,7 +6142,7 @@ mod tests { } #[tokio::test] - async fn aggregate_continue_leader_sends_non_continue_transition() { + async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { // Prepare datastore & request. 
install_test_trace_subscriber(); @@ -6040,7 +6179,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -6055,7 +6194,7 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6063,10 +6202,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -6079,7 +6216,7 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report_metadata.id(), - PrepareStepResult::Finished, + PrepareStepResult::Failed(ReportShareError::UnrecognizedMessage), )]), ); @@ -6140,7 +6277,7 @@ mod tests { ) .await?; tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -6155,7 +6292,7 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6163,10 +6300,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -6179,7 +6314,10 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report_metadata.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -6202,7 +6340,7 @@ mod tests { let (task, report_metadata) = (task.clone(), report_metadata.clone()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -6244,6 +6382,10 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, + Some(PrepareStep::new( + *report_metadata.id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError) + )), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ) ); @@ -6287,7 +6429,7 @@ mod tests { ) .await?; tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -6302,7 +6444,7 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6310,10 +6452,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -6328,7 +6468,10 @@ mod tests { ReportId::from( [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], // not the same as above ), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -6409,7 +6552,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -6425,7 +6568,7 @@ mod tests { .await?; tx.put_report_aggregation(&ReportAggregation::< - 
DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6433,14 +6576,12 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6448,10 +6589,8 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -6466,11 +6605,17 @@ mod tests { // Report IDs are in opposite order to what was stored in the datastore. PrepareStep::new( *report_metadata_1.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), PrepareStep::new( *report_metadata_0.id(), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, ), ]), ); @@ -6528,7 +6673,7 @@ mod tests { ) .await?; tx.put_aggregation_job(&AggregationJob::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -6543,7 +6688,7 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, dummy_vdaf::Vdaf, >::new( *task.id(), @@ -6551,6 +6696,7 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, + None, ReportAggregationState::Invalid, )) .await @@ -6564,7 +6710,10 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - PrepareStepResult::Continued(Vec::new()), + PrepareStepResult::Continued { + prep_msg: Vec::new(), + prep_share: Vec::new(), + }, )]), ); @@ -7133,7 +7282,7 @@ mod tests { let task = test_case.task.clone(); Box::pin(async move { tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7202,7 +7351,7 @@ mod tests { let task = test_case.task.clone(); Box::pin(async move { tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7555,7 +7704,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7576,7 +7725,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7597,7 +7746,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7618,7 +7767,7 @@ mod tests { ) .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - DUMMY_VERIFY_KEY_LENGTH, + DUMMY_VERIFY_KEY_LEN, TimeInterval, dummy_vdaf::Vdaf, >::new( @@ -7927,40 +8076,53 @@ mod tests { } } - pub(crate) fn generate_helper_report_share>( + pub(crate) fn generate_helper_report_init< + const SEED_SIZE: usize, + V: vdaf::Client<16> + vdaf::Aggregator, + >( task_id: TaskId, report_metadata: 
ReportMetadata, cfg: &HpkeConfig, - public_share: &V::PublicShare, + transcript: &VdafTranscript, extensions: Vec, - input_share: &V::InputShare, - ) -> ReportShare { - generate_helper_report_share_for_plaintext( + ) -> ReportPrepInit { + generate_helper_report_init_for_plaintext( report_metadata.clone(), cfg, - public_share.get_encoded(), - &PlaintextInputShare::new(extensions, input_share.get_encoded()).get_encoded(), - &InputShareAad::new(task_id, report_metadata, public_share.get_encoded()).get_encoded(), + transcript.public_share.get_encoded(), + &PlaintextInputShare::new(extensions, transcript.input_shares[1].get_encoded()) + .get_encoded(), + &InputShareAad::new( + task_id, + report_metadata, + transcript.public_share.get_encoded(), + ) + .get_encoded(), + transcript.leader_prep_state(0).1.get_encoded(), ) } - fn generate_helper_report_share_for_plaintext( + fn generate_helper_report_init_for_plaintext( metadata: ReportMetadata, cfg: &HpkeConfig, encoded_public_share: Vec, plaintext: &[u8], associated_data: &[u8], - ) -> ReportShare { - ReportShare::new( - metadata, - encoded_public_share, - hpke::seal( - cfg, - &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), - plaintext, - associated_data, - ) - .unwrap(), + leader_prep_share: Vec, + ) -> ReportPrepInit { + ReportPrepInit::new( + ReportShare::new( + metadata, + encoded_public_share, + hpke::seal( + cfg, + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), + plaintext, + associated_data, + ) + .unwrap(), + ), + leader_prep_share, ) } diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 0a46a7b27..757428718 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -1,4 +1,4 @@ -use crate::aggregator::{aggregator_filter, tests::generate_helper_report_share, Config}; +use crate::aggregator::{aggregator_filter, tests::generate_helper_report_init, Config}; use http::{header::CONTENT_TYPE, StatusCode}; use janus_aggregator_core::{ datastore::{ @@ -14,87 +14,94 @@ use janus_core::{ }; use janus_messages::{ query_type::TimeInterval, AggregationJobId, AggregationJobInitializeReq, PartialBatchSelector, - ReportMetadata, ReportShare, Role, + ReportMetadata, ReportPrepInit, Role, }; -use prio::codec::Encode; +use prio::{codec::Encode, vdaf}; use rand::random; use std::sync::Arc; use warp::{filters::BoxedFilter, reply::Response, Reply}; -pub(super) struct ReportShareGenerator { +pub(super) struct ReportInitGenerator +where + V: vdaf::Vdaf, +{ clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, - vdaf: dummy_vdaf::Vdaf, + vdaf: V, + aggregation_param: V::AggregationParam, } -impl ReportShareGenerator { +impl ReportInitGenerator +where + V: vdaf::Vdaf + vdaf::Aggregator + vdaf::Client<16>, +{ pub(super) fn new( clock: MockClock, task: Task, - aggregation_param: dummy_vdaf::AggregationParam, + vdaf: V, + aggregation_param: V::AggregationParam, ) -> Self { Self { clock, task, + vdaf, aggregation_param, - vdaf: dummy_vdaf::Vdaf::new(), } } - fn with_vdaf(mut self, vdaf: dummy_vdaf::Vdaf) -> Self { - self.vdaf = vdaf; - self - } - - pub(super) fn next(&self) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { - self.next_with_metadata(ReportMetadata::new( - random(), - self.clock - .now() - .to_batch_interval_start(self.task.time_precision()) - .unwrap(), - )) + pub(super) fn next( + &self, + measurement: &V::Measurement, + ) -> 
(ReportPrepInit, VdafTranscript) { + self.next_with_metadata( + ReportMetadata::new( + random(), + self.clock + .now() + .to_batch_interval_start(self.task.time_precision()) + .unwrap(), + ), + measurement, + ) } pub(super) fn next_with_metadata( &self, report_metadata: ReportMetadata, - ) -> (ReportShare, VdafTranscript<0, dummy_vdaf::Vdaf>) { + measurement: &V::Measurement, + ) -> (ReportPrepInit, VdafTranscript) { let transcript = run_vdaf( &self.vdaf, self.task.primary_vdaf_verify_key().unwrap().as_bytes(), &self.aggregation_param, report_metadata.id(), - &(), + measurement, ); - let report_share = generate_helper_report_share::( + let report_init = generate_helper_report_init::( *self.task.id(), report_metadata, self.task.current_hpke_key().config(), - &transcript.public_share, + &transcript, Vec::new(), - &transcript.input_shares[1], ); - - (report_share, transcript) + (report_init, transcript) } } -pub(super) struct AggregationJobInitTestCase { +pub(super) struct AggregationJobInitTestCase { pub(super) clock: MockClock, pub(super) task: Task, - pub(super) report_share_generator: ReportShareGenerator, - pub(super) report_shares: Vec, + pub(super) report_init_generator: ReportInitGenerator, + pub(super) report_inits: Vec, pub(super) aggregation_job_id: AggregationJobId, - pub(super) aggregation_param: dummy_vdaf::AggregationParam, + pub(super) aggregation_param: V::AggregationParam, pub(super) filter: BoxedFilter<(R,)>, pub(super) datastore: Arc>, _ephemeral_datastore: EphemeralDatastore, } -pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase -{ +pub(super) async fn setup_aggregate_init_test( +) -> AggregationJobInitTestCase<0, impl Reply + 'static, dummy_vdaf::Vdaf> { install_test_trace_subscriber(); let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); @@ -109,19 +116,23 @@ pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase AggregationJobInitTestCase( tx: &Transaction<'_, C>, - task: &Arc, - vdaf: &Arc, + task: Arc, + vdaf: Arc, batch_aggregation_shard_count: u64, helper_aggregation_job: AggregationJob, - report_aggregations: Vec>, - leader_aggregation_job: &Arc, + mut report_aggregations: Vec>, + leader_aggregation_job: Arc, request_hash: [u8; 32], - aggregate_step_failure_counter: &Counter, + aggregate_step_failure_counter: Counter, ) -> Result where C: Clock, @@ -47,11 +46,9 @@ impl VdafOps { for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, { // Handle each transition in the request. - let mut report_aggregations = report_aggregations.into_iter(); - let (mut saw_continue, mut saw_finish) = (false, false); - let mut response_prep_steps = Vec::new(); + let mut report_aggregations_iter = report_aggregations.iter_mut(); let mut accumulator = Accumulator::::new( - Arc::clone(task), + Arc::clone(&task), batch_aggregation_shard_count, helper_aggregation_job.aggregation_parameter().clone(), ); @@ -60,7 +57,7 @@ impl VdafOps { // Match preparation step received from leader to stored report aggregation, and extract // the stored preparation step. let report_aggregation = loop { - let report_agg = report_aggregations.next().ok_or_else(|| { + let report_agg = report_aggregations_iter.next().ok_or_else(|| { datastore::Error::User( Error::UnrecognizedMessage( Some(*task.id()), @@ -73,10 +70,12 @@ impl VdafOps { // This report was omitted by the leader because of a prior failure. 
Note that // the report was dropped (if it's not already in an error state) and continue. if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { - tx.update_report_aggregation(&report_agg.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )) - .await?; + *report_agg = report_agg + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportDropped, + )) + .with_last_prep_step(None); } continue; } @@ -86,20 +85,21 @@ impl VdafOps { // Make sure this report isn't in an interval that has already started collection. let conflicting_aggregate_share_jobs = tx .get_aggregate_share_jobs_including_time::( - vdaf, + &vdaf, task.id(), report_aggregation.time(), ) .await?; if !conflicting_aggregate_share_jobs.is_empty() { - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), - )); - tx.update_report_aggregation(&report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::BatchCollected), - )) - .await?; + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::BatchCollected, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected), + ))); continue; } @@ -116,87 +116,241 @@ impl VdafOps { } }; - // Parse preparation message out of prepare step received from leader. - let prep_msg = match prep_step.result() { - PrepareStepResult::Continued(payload) => A::PrepareMessage::decode_with_param( - prep_state, - &mut Cursor::new(payload.as_ref()), - )?, - _ => { - return Err(datastore::Error::User( - Error::UnrecognizedMessage( - Some(*task.id()), - "leader sent non-Continued prepare step", - ) - .into(), - )); + // Parse preparation message out of prepare step received from Leader. + let prep_msg = A::PrepareMessage::decode_with_param( + prep_state, + &mut Cursor::new(match prep_step.result() { + PrepareStepResult::Continued { prep_msg, .. } => prep_msg, + PrepareStepResult::Finished { prep_msg } => prep_msg, + _ => { + return Err(datastore::Error::User( + Error::UnrecognizedMessage( + Some(*task.id()), + "leader sent non-Continued/Finished prepare step", + ) + .into(), + )); + } + }), + )?; + + // Compute the next transition; if we're finished, we terminate here. Otherwise, + // retrieve our updated state as well as the leader & helper prepare shares. + let (prep_state, leader_prep_share, helper_prep_share) = match vdaf + .prepare_step(prep_state.clone(), prep_msg) + { + Ok(PrepareTransition::Continue(prep_state, helper_prep_share)) => { + if let PrepareStepResult::Continued { prep_share, .. 
} = prep_step.result() { + let leader_prep_share = match A::PrepareShare::get_decoded_with_param( + &prep_state, + prep_share, + ) { + Ok(leader_prep_share) => leader_prep_share, + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Couldn't parse Leader's prepare share" + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } + }; + (prep_state, leader_prep_share, helper_prep_share) + } else { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + "Leader finished but Helper did not", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } } - }; - // Compute the next transition, prepare to respond & update DB. - let next_state = match vdaf.prepare_step(prep_state.clone(), prep_msg) { - Ok(PrepareTransition::Continue(prep_state, prep_share)) => { - saw_continue = true; - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Continued(prep_share.get_encoded()), - )); - ReportAggregationState::Waiting( - prep_state, - PrepareMessageOrShare::Helper(prep_share), - ) - } + Ok(PrepareTransition::Finish(out_share)) => { + // If we finished but the Leader didn't, fail out. + if !matches!(prep_step.result(), PrepareStepResult::Finished { .. }) { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + "Helper finished but Leader did not", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } - Ok(PrepareTransition::Finish(output_share)) => { - saw_finish = true; + // If both aggregators finished here, record our output share & respond with + // a finished message. 
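+ // The Finished response below carries an empty prep_msg: with both aggregators
+ // done, there is no further prepare message to relay to the Leader.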
accumulator.update( helper_aggregation_job.partial_batch_identifier(), prep_step.report_id(), report_aggregation.time(), - &output_share, + &out_share, )?; - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Finished, - )); - ReportAggregationState::Finished(output_share) + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Finished(out_share)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, + ))); + continue; } - Err(error) => { + Err(err) => { info!( task_id = %task.id(), job_id = %helper_aggregation_job.id(), report_id = %prep_step.report_id(), - ?error, "Prepare step failed", + ?err, "Prepare step failed", ); aggregate_step_failure_counter.add( &Context::current(), 1, &[KeyValue::new("type", "prepare_step_failure")], ); - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError), - )); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; } }; - tx.update_report_aggregation(&report_aggregation.with_state(next_state)) - .await?; + // Merge the leader & helper prepare shares into the next message. + let prep_msg = match vdaf.prepare_preprocess([leader_prep_share, helper_prep_share]) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Couldn't compute prepare message", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + continue; + } + }; + + // Compute the next step based on the merged message. 
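+ // The merged message is echoed back to the Leader either way: Continued pairs it
+ // with the Helper's next prepare share, while Finished carries it alone so the
+ // Leader can take its own final preparation step.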
+ let encoded_prep_msg = prep_msg.get_encoded(); + match vdaf.prepare_step(prep_state, prep_msg) { + Ok(PrepareTransition::Continue(prep_state, helper_prep_share)) => { + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Waiting(prep_state, None)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Continued { + prep_msg: encoded_prep_msg, + prep_share: helper_prep_share.get_encoded(), + }, + ))); + } + + Ok(PrepareTransition::Finish(out_share)) => { + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Finished(out_share)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Finished { + prep_msg: encoded_prep_msg, + }, + ))); + } + + Err(err) => { + info!( + task_id = %task.id(), + job_id = %helper_aggregation_job.id(), + report_id = %prep_step.report_id(), + ?err, + "Prepare step failed", + ); + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))); + } + } } - for report_agg in report_aggregations { + for report_agg in report_aggregations_iter { // This report was omitted by the leader because of a prior failure. Note that the // report was dropped (if it's not already in an error state) and continue. if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { - tx.update_report_aggregation(&report_agg.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )) - .await?; + *report_agg = report_agg + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportDropped, + )) + .with_last_prep_step(None); } } + let saw_continue = report_aggregations.iter().any(|report_agg| { + matches!( + report_agg.last_prep_step().map(PrepareStep::result), + Some(PrepareStepResult::Continued { .. }) + ) + }); + let saw_finish = report_aggregations.iter().any(|report_agg| { + matches!( + report_agg.last_prep_step().map(PrepareStep::result), + Some(PrepareStepResult::Finished { .. }) + ) + }); let helper_aggregation_job = helper_aggregation_job // Advance the job to the leader's round .with_round(leader_aggregation_job.round()) @@ -215,53 +369,36 @@ impl VdafOps { } }) .with_last_continue_request_hash(request_hash); - tx.update_aggregation_job(&helper_aggregation_job).await?; - - accumulator.flush_to_datastore(tx, vdaf).await?; - Ok(AggregationJobResp::new(response_prep_steps)) + try_join!( + tx.update_aggregation_job(&helper_aggregation_job), + try_join_all( + report_aggregations + .iter() + .map(|report_agg| tx.update_report_aggregation(report_agg)), + ), + accumulator.flush_to_datastore(tx, &vdaf) + )?; + + Ok(Self::aggregation_job_resp_for::( + report_aggregations, + )) } - /// Fetch previously-computed prepare message shares and replay them back to the leader. - pub(super) fn replay_aggregation_job_round( + /// Construct an AggregationJobResp from a given set of Helper report aggregations. 
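+ ///
+ /// Report aggregations without a recorded last prepare step (e.g. reports dropped
+ /// by the Leader after a prior failure) are omitted from the response.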
+ pub(super) fn aggregation_job_resp_for( report_aggregations: Vec>, - ) -> Result + ) -> AggregationJobResp where - C: Clock, - Q: AccumulableQueryType, - A: vdaf::Aggregator + 'static + Send + Sync, - for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A: vdaf::Aggregator, { - let response_prep_steps = report_aggregations - .iter() - .map(|report_aggregation| { - let prepare_step_state = match report_aggregation.state() { - ReportAggregationState::Waiting(_, prep_msg) => PrepareStepResult::Continued( - prep_msg.get_helper_prepare_share()?.get_encoded(), - ), - ReportAggregationState::Finished(_) => PrepareStepResult::Finished, - ReportAggregationState::Failed(report_share_error) => { - PrepareStepResult::Failed(*report_share_error) - } - state => { - return Err(datastore::Error::User( - Error::Internal(format!( - "report aggregation {} unexpectedly in state {state:?}", - report_aggregation.report_id() - )) - .into(), - )); - } - }; - - Ok(PrepareStep::new( - *report_aggregation.report_id(), - prepare_step_state, - )) - }) - .collect::>()?; - - Ok(AggregationJobResp::new(response_prep_steps)) + AggregationJobResp::new( + report_aggregations + .iter() + .filter_map(ReportAggregation::last_prep_step) + .cloned() + .collect(), + ) } } @@ -365,7 +502,7 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - aggregate_init_tests::ReportShareGenerator, + aggregate_init_tests::ReportInitGenerator, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, post_aggregation_job_expecting_status, @@ -377,8 +514,7 @@ mod tests { use janus_aggregator_core::{ datastore::{ models::{ - AggregationJob, AggregationJobState, PrepareMessageOrShare, ReportAggregation, - ReportAggregationState, + AggregationJob, AggregationJobState, ReportAggregation, ReportAggregationState, }, test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, @@ -386,23 +522,34 @@ mod tests { task::{test_util::TaskBuilder, QueryType, Task}, }; use janus_core::{ - task::VdafInstance, - test_util::{dummy_vdaf, install_test_trace_subscriber}, + task::{VdafInstance, VERIFY_KEY_LEN}, + test_util::install_test_trace_subscriber, time::{IntervalExt, MockClock}, }; use janus_messages::{ query_type::TimeInterval, AggregationJobContinueReq, AggregationJobId, AggregationJobResp, AggregationJobRound, Interval, PrepareStep, PrepareStepResult, Role, }; - use prio::codec::Encode; + use prio::{ + codec::Encode, + idpf::IdpfInput, + vdaf::{ + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, + Vdaf, + }, + }; use rand::random; use std::sync::Arc; use warp::{filters::BoxedFilter, Reply}; - struct AggregationJobContinueTestCase { + struct AggregationJobContinueTestCase + where + V: Vdaf, + { task: Task, datastore: Arc>, - report_generator: ReportShareGenerator, + report_generator: ReportInitGenerator, aggregation_job_id: AggregationJobId, first_continue_request: AggregationJobContinueReq, first_continue_response: Option, @@ -413,59 +560,77 @@ mod tests { /// Set up a helper with an aggregation job in round 0 #[allow(clippy::unit_arg)] async fn setup_aggregation_job_continue_test( - ) -> AggregationJobContinueTestCase { + ) -> AggregationJobContinueTestCase> { // Prepare datastore & request. 
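+ // Note: these tests use Poplar1, a multi-round VDAF, so a round-1 continuation
+ // request still has preparation work to perform; under the new message flow, a
+ // one-round VDAF such as Prio3Count would already be finished after initialization.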
install_test_trace_subscriber(); let aggregation_job_id = random(); - let task = - TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 1 }, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let report_generator = ReportShareGenerator::new( + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); + let report_generator = ReportInitGenerator::new( clock.clone(), task.clone(), - dummy_vdaf::AggregationParam::default(), + Poplar1::new_sha3(1), + aggregation_param.clone(), ); - let report = report_generator.next(); + let (report_init, transcript) = report_generator.next(&IdpfInput::from_bools(&[true])); datastore .run_tx(|tx| { - let (task, report) = (task.clone(), report.clone()); + let (task, aggregation_param, report_init, transcript) = ( + task.clone(), + aggregation_param.clone(), + report_init.clone(), + transcript.clone(), + ); Box::pin(async move { tx.put_task(&task).await.unwrap(); - tx.put_report_share(task.id(), &report.0).await.unwrap(); + tx.put_report_share(task.id(), report_init.report_share()) + .await + .unwrap(); - tx.put_aggregation_job( - &AggregationJob::<0, TimeInterval, dummy_vdaf::Vdaf>::new( - *task.id(), - aggregation_job_id, - dummy_vdaf::AggregationParam::default(), - (), - Interval::from_time(report.0.metadata().time()).unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - ), - ) + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LEN, + TimeInterval, + Poplar1, + >::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::from_time(report_init.report_share().metadata().time()).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + )) .await .unwrap(); - let (prep_state, prep_share) = report.1.helper_prep_state(0); - tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( - *task.id(), - aggregation_job_id, - *report.0.metadata().id(), - *report.0.metadata().time(), - 0, - ReportAggregationState::Waiting( - *prep_state, - PrepareMessageOrShare::Helper(*prep_share), + let (prep_state, _) = transcript.helper_prep_state(1); + tx.put_report_aggregation::>( + &ReportAggregation::new( + *task.id(), + aggregation_job_id, + *report_init.report_share().metadata().id(), + *report_init.report_share().metadata().time(), + 0, + None, + ReportAggregationState::Waiting(prep_state.clone(), None), ), - )) + ) .await .unwrap(); @@ -478,8 +643,10 @@ mod tests { let first_continue_request = AggregationJobContinueReq::new( AggregationJobRound::from(1), Vec::from([PrepareStep::new( - *report.0.metadata().id(), - PrepareStepResult::Continued(report.1.prepare_messages[0].get_encoded()), + *report_init.report_share().metadata().id(), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[1].get_encoded(), + }, )]), ); @@ -499,10 +666,10 @@ mod tests { } } - /// Set up a helper with an aggregation job in round 1 + /// Set up a helper with an aggregation job in round 1. 
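+ /// Builds on `setup_aggregation_job_continue_test`, driving the job through its
+ /// first continuation request so the round-recovery tests can replay or mutate it.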
#[allow(clippy::unit_arg)] async fn setup_aggregation_job_continue_round_recovery_test( - ) -> AggregationJobContinueTestCase { + ) -> AggregationJobContinueTestCase> { let mut test_case = setup_aggregation_job_continue_test().await; let first_continue_response = post_aggregation_job_and_decode( @@ -521,7 +688,12 @@ mod tests { .first_continue_request .prepare_steps() .iter() - .map(|step| PrepareStep::new(*step.report_id(), PrepareStepResult::Finished)) + .map(|step| PrepareStep::new( + *step.report_id(), + PrepareStepResult::Finished { + prep_msg: Vec::new() + } + )) .collect() ) ); @@ -577,23 +749,25 @@ mod tests { async fn aggregation_job_continue_round_recovery_mutate_continue_request() { let test_case = setup_aggregation_job_continue_round_recovery_test().await; - let unrelated_report = test_case.report_generator.next(); + let (unrelated_report_init, unrelated_transcript) = test_case + .report_generator + .next(&IdpfInput::from_bools(&[false])); let (before_aggregation_job, before_report_aggregations) = test_case .datastore .run_tx(|tx| { - let (task_id, unrelated_report, aggregation_job_id) = ( + let (task_id, unrelated_report_init, aggregation_job_id) = ( *test_case.task.id(), - unrelated_report.clone(), + unrelated_report_init.clone(), test_case.aggregation_job_id, ); Box::pin(async move { - tx.put_report_share(&task_id, &unrelated_report.0) + tx.put_report_share(&task_id, unrelated_report_init.report_share()) .await .unwrap(); let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -601,8 +775,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_sha3(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -621,8 +795,10 @@ mod tests { let modified_request = AggregationJobContinueReq::new( test_case.first_continue_request.round(), Vec::from([PrepareStep::new( - *unrelated_report.0.metadata().id(), - PrepareStepResult::Continued(unrelated_report.1.prepare_messages[0].get_encoded()), + *unrelated_report_init.report_share().metadata().id(), + PrepareStepResult::Finished { + prep_msg: unrelated_transcript.prepare_messages[1].get_encoded(), + }, )]), ); @@ -643,7 +819,7 @@ mod tests { (*test_case.task.id(), test_case.aggregation_job_id); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) @@ -651,8 +827,8 @@ mod tests { .unwrap(); let report_aggregations = tx - .get_report_aggregations_for_aggregation_job::<0, dummy_vdaf::Vdaf>( - &dummy_vdaf::Vdaf::new(), + .get_report_aggregations_for_aggregation_job::>( + &Poplar1::new_sha3(1), &Role::Helper, &task_id, &aggregation_job_id, @@ -685,7 +861,7 @@ mod tests { // round mismatch error instead of tripping the check for a request to continue // to round 0. 
let aggregation_job = tx - .get_aggregation_job::<0, TimeInterval, dummy_vdaf::Vdaf>( + .get_aggregation_job::>( &task_id, &aggregation_job_id, ) diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index a61779595..b141d3f80 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -12,7 +12,7 @@ use janus_aggregator_core::{ task::{self, Task}, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, time::{Clock, DurationExt as _, TimeExt as _}, }; use janus_messages::{ @@ -232,102 +232,102 @@ impl AggregationJobCreator { ) -> anyhow::Result { match (task.query_type(), task.vdaf()) { (task::QueryType::TimeInterval, VdafInstance::Prio3Count) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3CountVec { .. }) => { self.create_aggregation_jobs_for_time_interval_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVecMultithreaded >(task).await } (task::QueryType::TimeInterval, VdafInstance::Prio3Sum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3SumVec { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } (task::QueryType::TimeInterval, VdafInstance::Prio3Histogram { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::TimeInterval, VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. }) => { - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task) .await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Count) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3CountVec { .. }) => { let max_batch_size = *max_batch_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVecMultithreaded >(task, max_batch_size).await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Sum { .. 
}) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } (task::QueryType::FixedSize { max_batch_size }, VdafInstance::Prio3SumVec { .. }) => { let max_batch_size = *max_batch_size; self.create_aggregation_jobs_for_fixed_size_task_no_param::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, Prio3SumVec, >(task, max_batch_size).await } (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3Histogram { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } #[cfg(feature = "fpvec_bounded_l2")] (task::QueryType::FixedSize{max_batch_size}, VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. }) => { let max_batch_size = *max_batch_size; - self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) + self.create_aggregation_jobs_for_fixed_size_task_no_param::>>(task, max_batch_size) .await } @@ -415,6 +415,7 @@ impl AggregationJobCreator { *report_id, *time, ord.try_into()?, + None, ReportAggregationState::Start, )); } @@ -571,6 +572,7 @@ impl AggregationJobCreator { report_id, client_timestamp, ord.try_into()?, + None, ReportAggregationState::Start, )) }) @@ -732,6 +734,7 @@ impl AggregationJobCreator { *report_id, *time, ord.try_into()?, + None, ReportAggregationState::Start, )); } @@ -772,7 +775,7 @@ mod tests { task::{test_util::TaskBuilder, QueryType as TaskQueryType}, }; use janus_core::{ - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, @@ -1466,10 +1469,9 @@ mod tests { >( tx: &Transaction<'_, C>, task_id: &TaskId, - ) -> HashMap, T)> - { + ) -> HashMap, T)> { let vdaf = Prio3::new_count(2).unwrap(); - read_aggregate_jobs_for_task_generic::( + read_aggregate_jobs_for_task_generic::( tx, task_id, &vdaf, ) .await diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 61bb16fc4..1b8173570 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -9,7 +9,7 @@ use janus_aggregator_core::{ self, models::{ AcquiredAggregationJob, AggregationJob, AggregationJobState, LeaderStoredReport, Lease, - PrepareMessageOrShare, ReportAggregation, ReportAggregationState, + ReportAggregation, ReportAggregationState, }, Datastore, }, @@ -20,7 +20,8 @@ use janus_core::{time::Clock, vdaf_dispatch}; use janus_messages::{ query_type::{FixedSize, TimeInterval}, 
AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, - PartialBatchSelector, PrepareStep, PrepareStepResult, ReportShare, ReportShareError, Role, + PartialBatchSelector, PrepareStep, PrepareStepResult, ReportPrepInit, ReportShare, + ReportShareError, Role, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -86,13 +87,13 @@ impl AggregationJobDriver { ) -> Result<()> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await }) } task::QueryType::FixedSize { .. } => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await }) } } @@ -314,7 +315,7 @@ impl AggregationJobDriver { // Compute report shares to send to helper, and decrypt our input shares & initialize // preparation state. let mut report_aggregations_to_write = Vec::new(); - let mut report_shares = Vec::new(); + let mut report_inits = Vec::new(); let mut stepped_aggregations = Vec::new(); for (report_aggregation, report) in reports { // Check for repeated extensions. @@ -360,10 +361,13 @@ impl AggregationJobDriver { } }; - report_shares.push(ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + report_inits.push(ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + prep_share.get_encoded(), )); stepped_aggregations.push(SteppedAggregation { report_aggregation, @@ -377,7 +381,7 @@ impl AggregationJobDriver { let req = AggregationJobInitializeReq::::new( aggregation_job.aggregation_parameter().get_encoded(), PartialBatchSelector::new(aggregation_job.partial_batch_identifier().clone()), - report_shares, + report_inits, ); let resp_bytes = send_request_to_helper( @@ -398,7 +402,7 @@ impl AggregationJobDriver { lease, task, aggregation_job, - &stepped_aggregations, + stepped_aggregations, report_aggregations_to_write, resp.prepare_steps(), ) @@ -437,9 +441,25 @@ impl AggregationJobDriver { if let ReportAggregationState::Waiting(prep_state, prep_msg) = report_aggregation.state() { - let prep_msg = prep_msg - .get_leader_prepare_message() - .context("report aggregation missing prepare message")?; + let prep_msg = match prep_msg.as_ref() { + Some(prep_msg) => prep_msg, + None => { + // This error indicates programmer/system error (i.e. it cannot possibly be + // the fault of our co-aggregator). We still record this failure against a + // single report, rather than failing the entire request, to (safely) + // minimize impact if we ever encounter this bug. 
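+ // The distinct "missing_prepare_message" counter label below keeps this internal
+ // error separable from ordinary VDAF preparation failures in our metrics.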
+ info!(report_id = %report_aggregation.report_id(), "Report aggregation is missing prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "missing_prepare_message")], + ); + report_aggregations_to_write.push(report_aggregation.with_state( + ReportAggregationState::Failed(ReportShareError::VdafPrepError), + )); + continue; + } + }; // Step our own state. let leader_transition = match vdaf @@ -460,9 +480,19 @@ impl AggregationJobDriver { } }; + let prepare_step_result = match &leader_transition { + PrepareTransition::Continue(_, prep_share) => PrepareStepResult::Continued { + prep_msg: prep_msg.get_encoded(), + prep_share: prep_share.get_encoded(), + }, + PrepareTransition::Finish(_) => PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, + }; + prepare_steps.push(PrepareStep::new( *report_aggregation.report_id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + prepare_step_result, )); stepped_aggregations.push(SteppedAggregation { report_aggregation, @@ -494,7 +524,7 @@ impl AggregationJobDriver { lease, task, aggregation_job, - &stepped_aggregations, + stepped_aggregations, report_aggregations_to_write, resp.prepare_steps(), ) @@ -514,7 +544,7 @@ impl AggregationJobDriver { lease: Arc>, task: Arc, leader_aggregation_job: AggregationJob, - stepped_aggregations: &[SteppedAggregation], + stepped_aggregations: Vec>, mut report_aggregations_to_write: Vec>, helper_prep_steps: &[PrepareStep], ) -> Result<()> @@ -539,11 +569,11 @@ impl AggregationJobDriver { leader_aggregation_job.aggregation_parameter().clone(), ); for (stepped_aggregation, helper_prep_step) in - stepped_aggregations.iter().zip(helper_prep_steps) + stepped_aggregations.into_iter().zip(helper_prep_steps) { let (report_aggregation, leader_transition) = ( - &stepped_aggregation.report_aggregation, - &stepped_aggregation.leader_transition, + stepped_aggregation.report_aggregation, + stepped_aggregation.leader_transition, ); if helper_prep_step.report_id() != report_aggregation.report_id() { return Err(anyhow!( @@ -551,81 +581,213 @@ impl AggregationJobDriver { )); } - let new_state = match helper_prep_step.result() { - PrepareStepResult::Continued(payload) => { - // If the leader continued too, combine the leader's prepare share with the + let new_state = (|| match helper_prep_step.result() { + PrepareStepResult::Continued { + prep_msg, + prep_share, + } => { + // If the Leader continued too, combine the Leader's prepare share with the // helper's to compute next round's prepare message. Prepare to store the // leader's new state & the prepare message. If the leader didn't continue, // transition to INVALID. 
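+ // Concretely: decode the Helper's merged prepare message and its next prepare
+ // share against our stored state, step our own preparation with the message, then
+ // combine our new share with the Helper's into the following round's prepare
+ // message.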
- if let PrepareTransition::Continue(leader_prep_state, leader_prep_share) = - leader_transition + let leader_prep_state = if let PrepareTransition::Continue( + leader_prep_state, + _, + ) = leader_transition { - let leader_prep_state = leader_prep_state.clone(); - let helper_prep_share = - A::PrepareShare::get_decoded_with_param(&leader_prep_state, payload) - .context("couldn't decode helper's prepare message"); - let prep_msg = helper_prep_share.and_then(|helper_prep_share| { - vdaf.prepare_preprocess([leader_prep_share.clone(), helper_prep_share]) - .context( - "couldn't preprocess leader & helper prepare shares into \ - prepare message", - ) - }); - match prep_msg { - Ok(prep_msg) => ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), - Err(error) => { - info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't compute prepare message"); - self.aggregate_step_failure_counter.add( - &Context::current(), - 1, - &[KeyValue::new("type", "prepare_message_failure")], - ); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) - } - } + leader_prep_state } else { - warn!(report_id = %report_aggregation.report_id(), "Helper continued but leader did not"); + warn!(report_id = %report_aggregation.report_id(), "Helper continued but Leader did not"); self.aggregate_step_failure_counter.add( &Context::current(), 1, &[KeyValue::new("type", "continue_mismatch")], ); - ReportAggregationState::Invalid - } + return ReportAggregationState::Invalid; + }; + + let prep_msg = match A::PrepareMessage::get_decoded_with_param( + &leader_prep_state, + prep_msg, + ) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + let helper_prep_share = match A::PrepareShare::get_decoded_with_param( + &leader_prep_state, + prep_share, + ) { + Ok(helper_prep_share) => helper_prep_share, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode Helper prepare share"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "helper_prep_share_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + + let (leader_prep_state, leader_prep_share) = match vdaf + .prepare_step(leader_prep_state, prep_msg) + { + Ok(PrepareTransition::Continue(leader_prep_state, leader_prep_share)) => { + (leader_prep_state, leader_prep_share) + } + Ok(_) => { + warn!(report_id = %report_aggregation.report_id(), "Helper continued but Leader did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "continue_mismatch")], + ); + return ReportAggregationState::Invalid; + } + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Prepare step failed"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_step_failure")], + ); + return ReportAggregationState::Failed(ReportShareError::VdafPrepError); + } + }; + + let prep_msg = match vdaf + .prepare_preprocess([leader_prep_share, helper_prep_share]) + { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, 
"Couldn't compute prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_failure")], + ); + return ReportAggregationState::Failed(ReportShareError::VdafPrepError); + } + }; + + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) } - PrepareStepResult::Finished => { - // If the leader finished too, we are done; prepare to store the output share. - // If the leader didn't finish too, we transition to INVALID. - if let PrepareTransition::Finish(out_share) = leader_transition { - match accumulator.update( - leader_aggregation_job.partial_batch_identifier(), - report_aggregation.report_id(), - report_aggregation.time(), - out_share, - ) { - Ok(_) => ReportAggregationState::Finished(out_share.clone()), - Err(error) => { - warn!(report_id = %report_aggregation.report_id(), ?error, "Could not update batch aggregation"); + PrepareStepResult::Finished { prep_msg } => { + match leader_transition { + PrepareTransition::Continue(leader_prep_state, _) => { + // The Helper is finished. If the Leader is ready to continue & also + // finishes, we are done; prepare to store the output share. If the + // Leader doesn't finish too, we transition to INVALID. + + let prep_msg = match A::PrepareMessage::get_decoded_with_param( + &leader_prep_state, + prep_msg, + ) { + Ok(prep_msg) => prep_msg, + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Couldn't decode prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_message_decode_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::UnrecognizedMessage, + ); + } + }; + + let out_share = match vdaf.prepare_step(leader_prep_state, prep_msg) { + Ok(PrepareTransition::Finish(out_share)) => out_share, + Ok(_) => { + warn!(report_id = %report_aggregation.report_id(), "Helper finished but Leader did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "finish_mismatch")], + ); + return ReportAggregationState::Invalid; + } + Err(err) => { + info!(report_id = %report_aggregation.report_id(), ?err, "Prepare step failed"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "prepare_step_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); + } + }; + + if let Err(err) = accumulator.update( + leader_aggregation_job.partial_batch_identifier(), + report_aggregation.report_id(), + report_aggregation.time(), + &out_share, + ) { + warn!(report_id = %report_aggregation.report_id(), ?err, "Could not update batch aggregation"); self.aggregate_step_failure_counter.add( &Context::current(), 1, &[KeyValue::new("type", "accumulate_failure")], ); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); } + + ReportAggregationState::Finished(out_share) + } + + PrepareTransition::Finish(out_share) => { + // If the Leader is already done, check that the Helper properly + // finished too by transmitting a Finished message with an empty + // prep_msg field. If so, we are done. Otherwise, we transition to + // INVALID. 
+ if !prep_msg.is_empty() { + warn!(report_id = %report_aggregation.report_id(), "Leader finished but Helper did not"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "finish_mismatch")], + ); + return ReportAggregationState::Invalid; + } + + if let Err(err) = accumulator.update( + leader_aggregation_job.partial_batch_identifier(), + report_aggregation.report_id(), + report_aggregation.time(), + &out_share, + ) { + warn!(report_id = %report_aggregation.report_id(), ?err, "Could not update batch aggregation"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "accumulate_failure")], + ); + return ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + ); + } + + ReportAggregationState::Finished(out_share) } - } else { - warn!(report_id = %report_aggregation.report_id(), "Helper finished but leader did not"); - self.aggregate_step_failure_counter.add( - &Context::current(), - 1, - &[KeyValue::new("type", "finish_mismatch")], - ); - ReportAggregationState::Invalid } } @@ -640,7 +802,7 @@ impl AggregationJobDriver { ); ReportAggregationState::Failed(*err) } - }; + })(); report_aggregations_to_write.push(report_aggregation.clone().with_state(new_state)); } @@ -650,18 +812,12 @@ impl AggregationJobDriver { let aggregation_job_is_finished = report_aggregations_to_write .iter() .all(|ra| !matches!(ra.state(), ReportAggregationState::Waiting(_, _))); - let (next_round, next_state) = if aggregation_job_is_finished { - ( - leader_aggregation_job.round(), - AggregationJobState::Finished, - ) + let next_state = if aggregation_job_is_finished { + AggregationJobState::Finished } else { - ( - // Advance self to next VDAF preparation round - leader_aggregation_job.round().increment(), - AggregationJobState::InProgress, - ) + AggregationJobState::InProgress }; + let next_round = leader_aggregation_job.round().increment(); let aggregation_job_to_write = leader_aggregation_job .with_round(next_round) @@ -713,9 +869,9 @@ impl AggregationJobDriver { ) -> Result<()> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LEN) => { self.cancel_aggregation_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, TimeInterval, VdafType, @@ -724,9 +880,9 @@ impl AggregationJobDriver { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (_, VdafType, VERIFY_KEY_LEN) => { self.cancel_aggregation_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, FixedSize, VdafType, @@ -859,7 +1015,7 @@ mod tests { datastore::{ models::{ AggregationJob, AggregationJobState, BatchAggregation, LeaderStoredReport, - PrepareMessageOrShare, ReportAggregation, ReportAggregationState, + ReportAggregation, ReportAggregationState, }, test_util::ephemeral_datastore, }, @@ -871,7 +1027,7 @@ mod tests { self, test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, report_id::ReportIdChecksumExt, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntimeManager}, time::{Clock, IntervalExt, MockClock, TimeExt}, Runtime, @@ -881,13 +1037,17 @@ mod tests { AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound, Duration, Extension, ExtensionType, HpkeConfig, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, - ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, + ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError, Role, + TaskId, Time, }; use opentelemetry::global::meter; use prio::{ codec::Encode, + idpf::IdpfInput, vdaf::{ self, + poplar1::{Poplar1, Poplar1AggregationParam}, + prg::PrgSha3, prio3::{Prio3, Prio3Count}, Aggregator, }, @@ -924,8 +1084,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -937,7 +1096,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token().clone(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -954,31 +1113,28 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) - .await?; - tx.put_report_aggregation( - &ReportAggregation::::new( + tx.put_aggregation_job( + &AggregationJob::::new( *task.id(), aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Start, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), ), ) + .await?; + tx.put_report_aggregation(&ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) .await }) }) @@ -986,30 +1142,19 @@ mod tests { .unwrap(); // Setup: prepare mocked HTTP responses. 
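+ // Prio3Count is a one-round VDAF, so under the updated message flow the Leader
+ // completes this job with a single initialization request; no continue request
+ // needs to be mocked.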
- let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); - let helper_responses = Vec::from([ - ( - "PUT", - AggregationJobInitializeReq::::MEDIA_TYPE, - AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( - *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), - )])) - .get_encoded(), - ), - ( - "POST", - AggregationJobContinueReq::MEDIA_TYPE, - AggregationJobResp::MEDIA_TYPE, - AggregationJobResp::new(Vec::from([PrepareStep::new( - *report.metadata().id(), - PrepareStepResult::Finished, - )])) - .get_encoded(), - ), - ]); - let mocked_aggregates = join_all(helper_responses.into_iter().map( + let helper_responses = Vec::from([( + "PUT", + AggregationJobInitializeReq::::MEDIA_TYPE, + AggregationJobResp::MEDIA_TYPE, + AggregationJobResp::new(Vec::from([PrepareStep::new( + *report.metadata().id(), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, + )])) + .get_encoded(), + )]); + let mocked_aggregates = join_all(helper_responses.iter().map( |(req_method, req_content_type, resp_content_type, resp_body)| { server .mock( @@ -1022,7 +1167,7 @@ mod tests { "DAP-Auth-Token", str::from_utf8(agg_auth_token.as_bytes()).unwrap(), ) - .match_header(CONTENT_TYPE.as_str(), req_content_type) + .match_header(CONTENT_TYPE.as_str(), *req_content_type) .with_status(200) .with_header(CONTENT_TYPE.as_str(), resp_content_type) .with_body(resp_body) @@ -1060,7 +1205,9 @@ mod tests { tracing::info!("awaiting stepper tasks"); // Wait for all of the aggregate job stepper tasks to complete. - runtime_manager.wait_for_completed_tasks("stepper", 2).await; + runtime_manager + .wait_for_completed_tasks("stepper", helper_responses.len()) + .await; // Stop the aggregate job driver task. 
task_handle.abort(); @@ -1069,23 +1216,22 @@ mod tests { mocked_aggregate.assert_async().await; } - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); @@ -1095,7 +1241,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1144,8 +1290,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1157,7 +1302,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1165,7 +1310,7 @@ mod tests { Vec::new(), transcript.input_shares.clone(), ); - let repeated_extension_report = generate_report::( + let repeated_extension_report = generate_report::( *task.id(), ReportMetadata::new(random(), time), helper_hpke_keypair.config(), @@ -1193,7 +1338,7 @@ mod tests { .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -1207,29 +1352,29 @@ mod tests { AggregationJobRound::from(0), )) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *repeated_extension_report.metadata().id(), - *repeated_extension_report.metadata().time(), - 1, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *repeated_extension_report.metadata().id(), + *repeated_extension_report.metadata().time(), + 1, + None, + ReportAggregationState::Start, + ), + ) .await?; Ok(tx @@ -1247,19 +1392,24 @@ mod tests { // (This is fragile in that it expects the leader request to be deterministically encoded. 
// It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) + let (_, leader_prep_share) = transcript.leader_prep_state(0); let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_time_interval(), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + leader_prep_share.get_encoded(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1319,37 +1469,32 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let leader_prep_state = transcript.leader_prep_state(0).clone(); - let prep_msg = transcript.prepare_messages[0].clone(); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); let want_repeated_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_extension_report.metadata().id(), *repeated_extension_report.metadata().time(), 1, + None, ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), ); @@ -1367,7 +1512,7 @@ mod tests { ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1436,8 +1581,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1449,7 +1593,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1467,32 +1611,33 @@ mod tests { tx.put_task(&task).await?; tx.put_client_report(vdaf.borrow(), &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - FixedSize, - Prio3Count, - >::new( 
- *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new( + Time::from_seconds_since_epoch(0), + Duration::from_seconds(1), + ) .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await?; - tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Start, - )) + tx.put_report_aggregation( + &ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + ), + ) .await?; Ok(tx @@ -1510,19 +1655,24 @@ mod tests { // (This is fragile in that it expects the leader request to be deterministically encoded. // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) + let (_, leader_prep_share) = transcript.leader_prep_state(0); let leader_request = AggregationJobInitializeReq::new( ().get_encoded(), PartialBatchSelector::new_fixed_size(batch_id), - Vec::from([ReportShare::new( - report.metadata().clone(), - report.public_share().get_encoded(), - report.helper_encrypted_input_share().clone(), + Vec::from([ReportPrepInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded(), + report.helper_encrypted_input_share().clone(), + ), + leader_prep_share.get_encoded(), )]), ); - let (_, helper_vdaf_msg) = transcript.helper_prep_state(0); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: transcript.prepare_messages[0].get_encoded(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1582,27 +1732,23 @@ mod tests { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(1), - ); - let want_report_aggregation = ReportAggregation::::new( + let want_aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - transcript.leader_prep_state(0).clone(), - PrepareMessageOrShare::Leader(transcript.prepare_messages[0].clone()), - ), + None, + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); let (got_aggregation_job, got_report_aggregation) = ds @@ -1611,7 +1757,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), 
&aggregation_job_id, ) @@ -1646,11 +1792,11 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::TimeInterval, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) @@ -1660,20 +1806,23 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1683,17 +1832,21 @@ mod tests { ); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); + let (leader_prep_state, _) = transcript.leader_prep_state(1); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate( + &aggregation_param, + [transcript.output_share(Role::Leader).clone()], + ) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; + let prep_msg = &transcript.prepare_messages[1]; let lease = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, aggregation_param, report, leader_prep_state, prep_msg) = ( vdaf.clone(), task.clone(), + aggregation_param.clone(), report.clone(), leader_prep_state.clone(), prep_msg.clone(), @@ -1703,13 +1856,13 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -1718,18 +1871,16 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), )) .await?; @@ -1752,12 +1903,16 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, )]), ); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Finished, + PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -1815,37 +1970,39 @@ mod tests { 
mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, - AggregationJobRound::from(1), + AggregationJobRound::from(2), + ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), - ); let batch_interval_start = report .metadata() .time() .to_batch_interval_start(task.time_precision()) .unwrap(); let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, - Prio3Count, + Poplar1, >::new( *task.id(), Interval::new(batch_interval_start, *task.time_precision()).unwrap(), - (), + aggregation_param.clone(), 0, leader_aggregate_share, 1, @@ -1855,11 +2012,15 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds .run_tx(|tx| { - let (vdaf, task, report_metadata) = - (Arc::clone(&vdaf), task.clone(), report.metadata().clone()); + let (vdaf, task, report_metadata, aggregation_param) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -1877,8 +2038,8 @@ mod tests { .unwrap(); let batch_aggregations = TimeInterval::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, >( tx, @@ -1892,7 +2053,7 @@ mod tests { *task.time_precision(), ) .unwrap(), - &(), + &aggregation_param, ) .await .unwrap(); @@ -1909,7 +2070,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -1933,11 +2094,11 @@ mod tests { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone())); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let vdaf = Arc::new(Poplar1::new_sha3(1)); let task = TaskBuilder::new( QueryType::FixedSize { max_batch_size: 10 }, - VdafInstance::Prio3Count, + VdafInstance::Poplar1 { bits: 1 }, Role::Leader, ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) @@ -1949,20 +2110,23 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(), ); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); + let aggregation_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + ])) + .unwrap(); let transcript = run_vdaf( vdaf.as_ref(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &0, + &IdpfInput::from_bools(&[true]), ); let agg_auth_token = task.primary_aggregator_auth_token(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let 
report = generate_report::>( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -1972,18 +2136,22 @@ mod tests { ); let batch_id = random(); let aggregation_job_id = random(); - let leader_prep_state = transcript.leader_prep_state(0); + let (leader_prep_state, _) = transcript.leader_prep_state(1); let leader_aggregate_share = vdaf - .aggregate(&(), [transcript.output_share(Role::Leader).clone()]) + .aggregate( + &aggregation_param, + [transcript.output_share(Role::Leader).clone()], + ) .unwrap(); - let prep_msg = &transcript.prepare_messages[0]; + let prep_msg = &transcript.prepare_messages[1]; let lease = ds .run_tx(|tx| { - let (vdaf, task, report, leader_prep_state, prep_msg) = ( + let (vdaf, task, report, aggregation_param, leader_prep_state, prep_msg) = ( vdaf.clone(), task.clone(), report.clone(), + aggregation_param.clone(), leader_prep_state.clone(), prep_msg.clone(), ); @@ -1992,13 +2160,13 @@ mod tests { tx.put_client_report(vdaf.borrow(), &report).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), aggregation_job_id, - (), + aggregation_param, batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), @@ -2007,18 +2175,16 @@ mod tests { )) .await?; tx.put_report_aggregation(&ReportAggregation::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, >::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), )) .await?; @@ -2041,12 +2207,16 @@ mod tests { AggregationJobRound::from(1), Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Continued(prep_msg.get_encoded()), + PrepareStepResult::Finished { + prep_msg: prep_msg.get_encoded(), + }, )]), ); let helper_response = AggregationJobResp::new(Vec::from([PrepareStep::new( *report.metadata().id(), - PrepareStepResult::Finished, + PrepareStepResult::Finished { + prep_msg: Vec::new(), + }, )])); let mocked_aggregate_failure = server .mock( @@ -2104,33 +2274,35 @@ mod tests { mocked_aggregate_success.assert_async().await; let want_aggregation_job = - AggregationJob::::new( + AggregationJob::>::new( *task.id(), aggregation_job_id, - (), + aggregation_param.clone(), batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), AggregationJobState::Finished, - AggregationJobRound::from(1), + AggregationJobRound::from(2), ); let leader_output_share = transcript.output_share(Role::Leader); - let want_report_aggregation = ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Finished(leader_output_share.clone()), - ); + let want_report_aggregation = + ReportAggregation::>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished(leader_output_share.clone()), + ); let want_batch_aggregations = Vec::from([BatchAggregation::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, FixedSize, - Prio3Count, + Poplar1, >::new( *task.id(), batch_id, - (), + aggregation_param.clone(), 0, leader_aggregate_share, 1, @@ -2140,11 +2312,15 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds .run_tx(|tx| { - 
let (vdaf, task, report_metadata) = - (Arc::clone(&vdaf), task.clone(), report.metadata().clone()); + let (vdaf, task, report_metadata, aggregation_param) = ( + Arc::clone(&vdaf), + task.clone(), + report.metadata().clone(), + aggregation_param.clone(), + ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::>( task.id(), &aggregation_job_id, ) @@ -2162,10 +2338,10 @@ mod tests { .unwrap(); let batch_aggregations = FixedSize::get_batch_aggregations_for_collect_identifier::< - PRIO3_VERIFY_KEY_LENGTH, - Prio3Count, + VERIFY_KEY_LEN, + Poplar1, _, - >(tx, &task, &vdaf, &batch_id, &()) + >(tx, &task, &vdaf, &batch_id, &aggregation_param) .await?; Ok((aggregation_job, report_aggregation, batch_aggregations)) }) @@ -2180,7 +2356,7 @@ mod tests { BatchAggregation::new( *agg.task_id(), *agg.batch_identifier(), - (), + aggregation_param.clone(), 0, agg.aggregate_share().clone(), agg.report_count(), @@ -2215,8 +2391,7 @@ mod tests { .to_batch_interval_start(task.time_precision()) .unwrap(); let report_metadata = ReportMetadata::new(random(), time); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -2227,7 +2402,7 @@ mod tests { ); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2237,23 +2412,22 @@ mod tests { ); let aggregation_job_id = random(); - let aggregation_job = - AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - ); - let report_aggregation = ReportAggregation::::new( + let aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ); + let report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, ); @@ -2304,7 +2478,7 @@ mod tests { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2397,8 +2571,7 @@ mod tests { .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); - let verify_key: VerifyKey = - task.primary_vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); let helper_hpke_keypair = generate_test_hpke_config_and_private_key(); @@ -2409,7 +2582,7 @@ mod tests { .unwrap(); let report_metadata = ReportMetadata::new(random(), time); let transcript = run_vdaf(&vdaf, verify_key.as_bytes(), &(), report_metadata.id(), &0); - let report = generate_report::( + let report = generate_report::( *task.id(), report_metadata, helper_hpke_keypair.config(), @@ -2430,34 +2603,31 @@ mod tests { // run through initial VDAF preparation before sending a request to the helper. 
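// Reviewer note (sketch): "initial VDAF preparation" here means the leader now runs
// prepare_init locally and ships its round-0 prepare share to the helper inside the new
// ReportPrepInit wrapper, instead of waiting for a Continued round trip. An illustrative
// construction of the init request, assuming the test transcript helpers used above;
// the example_ binding is hypothetical:
let (_leader_prep_state, leader_prep_share) = transcript.leader_prep_state(0);
let example_init_req = AggregationJobInitializeReq::new(
    ().get_encoded(), // encoded aggregation parameter (unit for Prio3Count)
    PartialBatchSelector::new_time_interval(),
    Vec::from([ReportPrepInit::new(
        ReportShare::new(
            report.metadata().clone(),
            report.public_share().get_encoded(),
            report.helper_encrypted_input_share().clone(),
        ),
        leader_prep_share.get_encoded(), // leader's round-0 prepare share
    )]),
);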
tx.put_client_report(&vdaf, &report).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) - .await?; - - tx.put_report_aggregation( - &ReportAggregation::::new( + tx.put_aggregation_job( + &AggregationJob::::new( *task.id(), aggregation_job_id, - *report.metadata().id(), - *report.metadata().time(), - 0, - ReportAggregationState::Start, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), ), ) .await?; + tx.put_report_aggregation(&ReportAggregation::::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Start, + )) + .await?; + Ok(()) }) }) @@ -2556,7 +2726,7 @@ mod tests { .run_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -2568,7 +2738,7 @@ mod tests { .unwrap(); assert_eq!( aggregation_job_after, - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index cb9c6da37..df435cdbe 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -71,9 +71,9 @@ impl CollectionJobDriver { ) -> Result<(), Error> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { self.step_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, TimeInterval, VdafType @@ -82,9 +82,9 @@ impl CollectionJobDriver { }) } task::QueryType::FixedSize { .. } => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { self.step_collection_job_generic::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, C, FixedSize, VdafType @@ -261,8 +261,8 @@ impl CollectionJobDriver { ) -> Result<(), Error> { match lease.leased().query_type() { task::QueryType::TimeInterval => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.abandon_collection_job_generic::( + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.abandon_collection_job_generic::( datastore, Arc::new(vdaf), lease, @@ -271,8 +271,8 @@ impl CollectionJobDriver { }) } task::QueryType::FixedSize { .. 
} => { - vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.abandon_collection_job_generic::( + vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LEN) => { + self.abandon_collection_job_generic::( datastore, Arc::new(vdaf), lease, @@ -566,7 +566,8 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Finished(OutputShare()), + None, + ReportAggregationState::Finished(OutputShare(0)), )) .await?; @@ -691,7 +692,8 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Finished(OutputShare()), + None, + ReportAggregationState::Finished(OutputShare(0)), )) .await?; diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 75ef2f32e..abeaa32e1 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -200,7 +200,8 @@ async fn setup_fixed_size_current_batch_collection_job_test_case() -> ( *report.metadata().id(), time, ord, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), + None, + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), )) .await .unwrap(); diff --git a/aggregator/src/aggregator/garbage_collector.rs b/aggregator/src/aggregator/garbage_collector.rs index db8808f35..b90e5fb26 100644 --- a/aggregator/src/aggregator/garbage_collector.rs +++ b/aggregator/src/aggregator/garbage_collector.rs @@ -159,6 +159,7 @@ mod tests { *report.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -320,6 +321,7 @@ mod tests { *report_share.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -477,6 +479,7 @@ mod tests { *report.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -642,6 +645,7 @@ mod tests { *report_share.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index bdca3886f..0fe845368 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -3,8 +3,8 @@ use self::models::{ AcquiredAggregationJob, AcquiredCollectionJob, AggregateShareJob, AggregationJob, AggregatorRole, BatchAggregation, CollectionJob, CollectionJobState, CollectionJobStateCode, - LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, PrepareMessageOrShare, - ReportAggregation, ReportAggregationState, ReportAggregationStateCode, SqlInterval, + LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, ReportAggregation, + ReportAggregationState, ReportAggregationStateCode, SqlInterval, }; #[cfg(feature = "test-util")] use crate::VdafHasAggregationParameter; @@ -24,7 +24,8 @@ use janus_core::{ use janus_messages::{ query_type::{QueryType, TimeInterval}, AggregationJobId, BatchId, CollectionJobId, Duration, Extension, HpkeCiphertext, HpkeConfig, - Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, + Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, + Time, }; use opentelemetry::{ metrics::{Counter, Histogram}, @@ -1719,8 +1720,9 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT client_reports.report_id, client_reports.client_timestamp, 
report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + report_aggregations.prep_msg, report_aggregations.last_prep_step, + report_aggregations.out_share, report_aggregations.error_code, + aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1772,8 +1774,9 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT client_reports.report_id, client_reports.client_timestamp, report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + report_aggregations.prep_msg, report_aggregations.last_prep_step, + report_aggregations.out_share, report_aggregations.error_code, + aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1826,8 +1829,9 @@ impl Transaction<'_, C> { aggregation_jobs.aggregation_job_id, client_reports.report_id, client_reports.client_timestamp, report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + report_aggregations.prep_msg, report_aggregations.last_prep_step, + report_aggregations.out_share, report_aggregations.error_code, + aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1867,6 +1871,7 @@ impl Transaction<'_, C> { let state: ReportAggregationStateCode = row.get("state"); let prep_state_bytes: Option> = row.get("prep_state"); let prep_msg_bytes: Option> = row.get("prep_msg"); + let last_prep_step_bytes: Option> = row.get("last_prep_step"); let out_share_bytes: Option> = row.get("out_share"); let error_code: Option = row.get("error_code"); let aggregation_param_bytes = row.get("aggregation_param"); @@ -1883,6 +1888,10 @@ impl Transaction<'_, C> { None => None, }; + let last_prep_step = last_prep_step_bytes + .map(|bytes| PrepareStep::get_decoded(&bytes)) + .transpose()?; + let agg_state = match state { ReportAggregationStateCode::Start => ReportAggregationState::Start, ReportAggregationStateCode::Waiting => { @@ -1898,26 +1907,14 @@ impl Transaction<'_, C> { ) })?, )?; - let prep_msg_bytes = prep_msg_bytes.ok_or_else(|| { - Error::DbState( - "report aggregation in state WAITING but prep_msg is NULL".to_string(), - ) - })?; - let prep_msg = match role { - Role::Leader => PrepareMessageOrShare::Leader( - A::PrepareMessage::get_decoded_with_param(&prep_state, &prep_msg_bytes)?, - ), - Role::Helper => PrepareMessageOrShare::Helper( - A::PrepareShare::get_decoded_with_param(&prep_state, &prep_msg_bytes)?, - ), - _ => return Err(Error::DbState(format!("unexpected role {role}"))), - }; - + let prep_msg = prep_msg_bytes + .map(|bytes| A::PrepareMessage::get_decoded_with_param(&prep_state, &bytes)) + .transpose()?; ReportAggregationState::Waiting(prep_state, prep_msg) } ReportAggregationStateCode::Finished => { let aggregation_param = A::AggregationParam::get_decoded(aggregation_param_bytes)?; - ReportAggregationState::Finished(A::OutputShare::get_decoded_with_param( + let 
output_share = A::OutputShare::get_decoded_with_param( &(vdaf, &aggregation_param), &out_share_bytes.ok_or_else(|| { Error::DbState( @@ -1925,7 +1922,8 @@ impl Transaction<'_, C> { .to_string(), ) })?, - )?) + )?; + ReportAggregationState::Finished(output_share) } ReportAggregationStateCode::Failed => { ReportAggregationState::Failed(error_code.ok_or_else(|| { @@ -1943,6 +1941,7 @@ impl Transaction<'_, C> { *report_id, time, ord, + last_prep_step, agg_state, )) } @@ -1960,17 +1959,20 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); + let encoded_last_prep_step = report_aggregation + .last_prep_step() + .map(PrepareStep::get_encoded); let stmt = self .prepare_cached( "INSERT INTO report_aggregations (aggregation_job_id, client_report_id, ord, state, prep_state, prep_msg, out_share, - error_code) + error_code, last_prep_step) VALUES ((SELECT id FROM aggregation_jobs WHERE aggregation_job_id = $1), (SELECT id FROM client_reports WHERE task_id = (SELECT id FROM tasks WHERE task_id = $2) AND report_id = $3), - $4, $5, $6, $7, $8, $9)", + $4, $5, $6, $7, $8, $9, $10)", ) .await?; self.execute( @@ -1986,6 +1988,7 @@ impl Transaction<'_, C> { /* prep_msg */ &encoded_state_values.prep_msg, /* out_share */ &encoded_state_values.output_share, /* error_code */ &encoded_state_values.report_share_err, + /* last_prep_step */ &encoded_last_prep_step, ], ) .await?; @@ -2004,16 +2007,19 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); + let encoded_last_prep_step = report_aggregation + .last_prep_step() + .map(PrepareStep::get_encoded); let stmt = self .prepare_cached( "UPDATE report_aggregations SET ord = $1, state = $2, prep_state = $3, - prep_msg = $4, out_share = $5, error_code = $6 + prep_msg = $4, out_share = $5, error_code = $6, last_prep_step = $7 WHERE aggregation_job_id = (SELECT id FROM aggregation_jobs WHERE - aggregation_job_id = $7) + aggregation_job_id = $8) AND client_report_id = (SELECT id FROM client_reports - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $8) - AND report_id = $9)", + WHERE task_id = (SELECT id FROM tasks WHERE task_id = $9) + AND report_id = $10)", ) .await?; check_single_row_mutation( @@ -2026,6 +2032,7 @@ impl Transaction<'_, C> { /* prep_msg */ &encoded_state_values.prep_msg, /* out_share */ &encoded_state_values.output_share, /* error_code */ &encoded_state_values.report_share_err, + /* last_prep_step */ &encoded_last_prep_step, /* aggregation_job_id */ &report_aggregation.aggregation_job_id().as_ref(), /* task_id */ &report_aggregation.task_id().as_ref(), @@ -3843,8 +3850,8 @@ pub mod models { use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, AggregationJobRound, BatchId, CollectionJobId, Duration, Extension, - HpkeCiphertext, Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShareError, - Role, TaskId, Time, + HpkeCiphertext, Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, + ReportShareError, Role, TaskId, Time, }; use postgres_protocol::types::{ range_from_sql, range_to_sql, timestamp_from_sql, timestamp_to_sql, Range, RangeBound, @@ -4384,6 +4391,7 @@ pub mod models { time: Time, ord: u64, state: ReportAggregationState, + last_prep_step: Option, } impl> ReportAggregation { @@ -4394,6 +4402,7 @@ pub mod models { report_id: ReportId, time: Time, ord: u64, + last_prep_step: Option, state: ReportAggregationState, ) 
-> Self { Self { @@ -4402,6 +4411,7 @@ pub mod models { report_id, time, ord, + last_prep_step, state, } } @@ -4436,6 +4446,17 @@ pub mod models { self.ord } + pub fn last_prep_step(&self) -> Option<&PrepareStep> { + self.last_prep_step.as_ref() + } + + pub fn with_last_prep_step(self, last_prep_step: Option) -> Self { + Self { + last_prep_step, + ..self + } + } + /// Returns the state of the report aggregation. pub fn state(&self) -> &ReportAggregationState { &self.state @@ -4462,6 +4483,7 @@ pub mod models { && self.report_id == other.report_id && self.time == other.time && self.ord == other.ord + && self.last_prep_step == other.last_prep_step && self.state == other.state } } @@ -4476,76 +4498,6 @@ pub mod models { { } - /// Represents either a preprocessed VDAF preparation message (for the leader) or a VDAF - /// preparation message share (for the helper). - #[derive(Clone, Derivative)] - #[derivative(Debug)] - pub enum PrepareMessageOrShare> { - /// The helper stores a prepare message share - Helper(#[derivative(Debug = "ignore")] A::PrepareShare), - /// The leader stores a combined prepare message - Leader(#[derivative(Debug = "ignore")] A::PrepareMessage), - } - - impl> - PrepareMessageOrShare - { - /// Get the leader's preprocessed prepare message, or an error if this is a helper's prepare - /// share. - pub fn get_leader_prepare_message(&self) -> Result<&A::PrepareMessage, Error> - where - A::PrepareMessage: Encode, - { - if let Self::Leader(prep_msg) = self { - Ok(prep_msg) - } else { - Err(Error::InvalidParameter( - "does not contain a prepare message", - )) - } - } - - /// Get the helper's prepare share, or an error if this is a leader's preprocessed prepare - /// message. - pub fn get_helper_prepare_share(&self) -> Result<&A::PrepareShare, Error> - where - A::PrepareShare: Encode, - { - if let Self::Helper(prep_share) = self { - Ok(prep_share) - } else { - Err(Error::InvalidParameter("does not contain a prepare share")) - } - } - } - - impl PartialEq for PrepareMessageOrShare - where - A: vdaf::Aggregator, - A::PrepareShare: PartialEq, - A::PrepareMessage: PartialEq, - { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Helper(self_prep_share), Self::Helper(other_prep_share)) => { - self_prep_share.eq(other_prep_share) - } - (Self::Leader(self_prep_msg), Self::Leader(other_prep_msg)) => { - self_prep_msg.eq(other_prep_msg) - } - _ => false, - } - } - } - - impl Eq for PrepareMessageOrShare - where - A: vdaf::Aggregator, - A::PrepareShare: Eq, - A::PrepareMessage: Eq, - { - } - /// ReportAggregationState represents the state of a single report aggregation. It corresponds /// to the REPORT_AGGREGATION_STATE enum in the schema, along with the state-specific data. #[derive(Clone, Derivative)] @@ -4554,7 +4506,7 @@ pub mod models { Start, Waiting( #[derivative(Debug = "ignore")] A::PrepareState, - #[derivative(Debug = "ignore")] PrepareMessageOrShare, + #[derivative(Debug = "ignore")] Option, ), Finished(#[derivative(Debug = "ignore")] A::OutputShare), Failed(ReportShareError), @@ -4567,9 +4519,9 @@ pub mod models { pub fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::Start => ReportAggregationStateCode::Start, - ReportAggregationState::Waiting(_, _) => ReportAggregationStateCode::Waiting, - ReportAggregationState::Finished(_) => ReportAggregationStateCode::Finished, - ReportAggregationState::Failed(_) => ReportAggregationStateCode::Failed, + ReportAggregationState::Waiting { .. 
} => ReportAggregationStateCode::Waiting, + ReportAggregationState::Finished { .. } => ReportAggregationStateCode::Finished, + ReportAggregationState::Failed { .. } => ReportAggregationStateCode::Failed, ReportAggregationState::Invalid => ReportAggregationStateCode::Invalid, } } @@ -4581,37 +4533,33 @@ pub mod models { where A::PrepareState: Encode, { - let (prep_state, prep_msg, output_share, report_share_err) = match self { - ReportAggregationState::Start => (None, None, None, None), + match self { + ReportAggregationState::Start => EncodedReportAggregationStateValues::default(), ReportAggregationState::Waiting(prep_state, prep_msg) => { - let encoded_msg = match prep_msg { - PrepareMessageOrShare::Leader(prep_msg) => prep_msg.get_encoded(), - PrepareMessageOrShare::Helper(prep_share) => prep_share.get_encoded(), - }; - ( - Some(prep_state.get_encoded()), - Some(encoded_msg), - None, - None, - ) + EncodedReportAggregationStateValues { + prep_state: Some(prep_state.get_encoded()), + prep_msg: prep_msg.as_ref().map(Encode::get_encoded), + ..Default::default() + } } ReportAggregationState::Finished(output_share) => { - (None, None, Some(output_share.get_encoded()), None) + EncodedReportAggregationStateValues { + output_share: Some(output_share.get_encoded()), + ..Default::default() + } } ReportAggregationState::Failed(report_share_err) => { - (None, None, None, Some(*report_share_err as i16)) + EncodedReportAggregationStateValues { + report_share_err: Some(*report_share_err as i16), + ..Default::default() + } } - ReportAggregationState::Invalid => (None, None, None, None), - }; - EncodedReportAggregationStateValues { - prep_state, - prep_msg, - output_share, - report_share_err, + ReportAggregationState::Invalid => EncodedReportAggregationStateValues::default(), } } } + #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { pub(super) prep_state: Option>, pub(super) prep_msg: Option>, @@ -5367,8 +5315,8 @@ mod tests { models::{ AcquiredAggregationJob, AcquiredCollectionJob, AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation, CollectionJob, CollectionJobState, - LeaderStoredReport, Lease, OutstandingBatch, PrepareMessageOrShare, - ReportAggregation, ReportAggregationState, SqlInterval, + LeaderStoredReport, Lease, OutstandingBatch, ReportAggregation, + ReportAggregationState, SqlInterval, }, test_util::{ephemeral_datastore, generate_aead_key}, Crypter, Datastore, Error, Transaction, @@ -5382,7 +5330,7 @@ mod tests { use futures::future::try_join_all; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, - task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{VdafInstance, VERIFY_KEY_LEN}, test_util::{ dummy_vdaf::{self, AggregateShare, AggregationParam}, install_test_trace_subscriber, run_vdaf, @@ -5393,8 +5341,8 @@ mod tests { query_type::{FixedSize, QueryType, TimeInterval}, AggregateShareAad, AggregationJobId, AggregationJobRound, BatchId, BatchSelector, CollectionJobId, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, - Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, - TaskId, Time, + Interval, PrepareStep, PrepareStepResult, ReportId, ReportIdChecksum, ReportMetadata, + ReportShare, ReportShareError, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, @@ -5963,6 +5911,7 @@ mod tests { aggregated_report_id, aggregated_report_time, 0, + None, ReportAggregationState::Start, )) .await @@ -6233,6 +6182,7 @@ mod tests { *report_0.metadata().id(), *report_0.metadata().time(), 0, 
+ None, ReportAggregationState::Start, ); let aggregation_job_0_report_aggregation_1 = @@ -6242,6 +6192,7 @@ mod tests { *report_1.metadata().id(), *report_1.metadata().time(), 1, + None, ReportAggregationState::Start, ); @@ -6262,6 +6213,7 @@ mod tests { *report_0.metadata().id(), *report_0.metadata().time(), 0, + None, ReportAggregationState::Start, ); let aggregation_job_1_report_aggregation_1 = @@ -6271,6 +6223,7 @@ mod tests { *report_1.metadata().id(), *report_1.metadata().time(), 1, + None, ReportAggregationState::Start, ); @@ -6614,7 +6567,7 @@ mod tests { tx.put_task(&task).await?; for aggregation_job_id in aggregation_job_ids { tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -6631,20 +6584,18 @@ mod tests { } // Write an aggregation job that is finished. We don't want to retrieve this one. - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobRound::from(1), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Finished, + AggregationJobRound::from(1), + ), + ) .await?; // Write an aggregation job for a task that we are taking on the helper role for. @@ -6656,20 +6607,18 @@ mod tests { ) .build(); tx.put_task(&helper_task).await?; - tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *helper_task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobRound::from(0), - )) + tx.put_aggregation_job( + &AggregationJob::::new( + *helper_task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobRound::from(0), + ), + ) .await }) }) @@ -6882,7 +6831,7 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.get_aggregation_job::( + tx.get_aggregation_job::( &random(), &random(), ) @@ -6896,7 +6845,7 @@ mod tests { let rslt = ds .run_tx(|tx| { Box::pin(async move { - tx.update_aggregation_job::( + tx.update_aggregation_job::( &AggregationJob::new( random(), random(), @@ -7027,66 +6976,49 @@ mod tests { let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LEN] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - let (helper_prep_state, helper_prep_share) = vdaf_transcript.helper_prep_state(0); - let leader_prep_state = vdaf_transcript.leader_prep_state(0); + let (leader_prep_state, _) = vdaf_transcript.leader_prep_state(0); - for (ord, (role, state)) in [ - ( - Role::Leader, - ReportAggregationState::::Start, - ), - ( - Role::Helper, - ReportAggregationState::Waiting( - helper_prep_state.clone(), - PrepareMessageOrShare::Helper(helper_prep_share.clone()), - ), - ), - ( - Role::Leader, - ReportAggregationState::Waiting( - leader_prep_state.clone(), - 
PrepareMessageOrShare::Leader(vdaf_transcript.prepare_messages[0].clone()), - ), - ), - ( - Role::Leader, - ReportAggregationState::Finished( - vdaf_transcript.output_share(Role::Leader).clone(), - ), - ), - ( - Role::Leader, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), + for (ord, state) in [ + ReportAggregationState::::Start, + ReportAggregationState::Waiting( + leader_prep_state.clone(), + Some(vdaf_transcript.prepare_messages[0].clone()), ), - (Role::Leader, ReportAggregationState::Invalid), + ReportAggregationState::Waiting(leader_prep_state.clone(), None), + ReportAggregationState::Finished(vdaf_transcript.output_share(Role::Leader).clone()), + ReportAggregationState::Finished(vdaf_transcript.output_share(Role::Helper).clone()), + ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Invalid, ] .into_iter() .enumerate() { - let task = TaskBuilder::new( - task::QueryType::TimeInterval, - VdafInstance::Prio3Count, - role, - ) - .build(); + let task_id = random(); let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); - let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); - let report_aggregation = ds + let want_report_aggregation = ds .run_tx(|tx| { - let (task, state) = (task.clone(), state.clone()); + let state = state.clone(); Box::pin(async move { - tx.put_task(&task).await?; + tx.put_task( + &TaskBuilder::new( + task::QueryType::TimeInterval, + VdafInstance::Prio3Count, + Role::Leader, + ) + .with_id(task_id) + .build(), + ) + .await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( - *task.id(), + task_id, aggregation_job_id, (), (), @@ -7100,7 +7032,7 @@ mod tests { )) .await?; tx.put_report_share( - task.id(), + &task_id, &ReportShare::new( ReportMetadata::new(report_id, time), Vec::from("public_share"), @@ -7114,11 +7046,18 @@ mod tests { .await?; let report_aggregation = ReportAggregation::new( - *task.id(), + task_id, aggregation_job_id, report_id, time, - ord.try_into().unwrap(), + 0, + Some(PrepareStep::new( + report_id, + PrepareStepResult::Continued { + prep_msg: format!("prep_msg_{ord}").into(), + prep_share: format!("prep_share_{ord}").into(), + }, + )), state, ); tx.put_report_aggregation(&report_aggregation).await?; @@ -7130,12 +7069,12 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &role, - task.id(), + &Role::Leader, + &task_id, &aggregation_job_id, &report_id, ) @@ -7145,45 +7084,34 @@ mod tests { .await .unwrap() .unwrap(); - assert_eq!(report_aggregation, got_report_aggregation); - - if let ReportAggregationState::Waiting(_, message) = got_report_aggregation.state() { - match role { - Role::Leader => { - assert!(message.get_leader_prepare_message().is_ok()); - assert!(message.get_helper_prepare_share().is_err()); - } - Role::Helper => { - assert!(message.get_helper_prepare_share().is_ok()); - assert!(message.get_leader_prepare_message().is_err()); - } - _ => panic!("unexpected role"), - } - } - - let new_report_aggregation = ReportAggregation::new( - *report_aggregation.task_id(), - *report_aggregation.aggregation_job_id(), - *report_aggregation.report_id(), - *report_aggregation.time(), - report_aggregation.ord() + 10, - report_aggregation.state().clone(), + assert_eq!(want_report_aggregation, 
got_report_aggregation); + + let want_report_aggregation = ReportAggregation::new( + *want_report_aggregation.task_id(), + *want_report_aggregation.aggregation_job_id(), + *want_report_aggregation.report_id(), + *want_report_aggregation.time(), + want_report_aggregation.ord() + 10, + want_report_aggregation.last_prep_step().cloned(), + want_report_aggregation.state().clone(), ); ds.run_tx(|tx| { - let new_report_aggregation = new_report_aggregation.clone(); - Box::pin(async move { tx.update_report_aggregation(&new_report_aggregation).await }) + let want_report_aggregation = want_report_aggregation.clone(); + Box::pin( + async move { tx.update_report_aggregation(&want_report_aggregation).await }, + ) }) .await .unwrap(); let got_report_aggregation = ds .run_tx(|tx| { - let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + let vdaf = Arc::clone(&vdaf); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &role, - task.id(), + &Role::Leader, + &task_id, &aggregation_job_id, &report_id, ) @@ -7192,7 +7120,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(Some(new_report_aggregation), got_report_aggregation); + assert_eq!(Some(want_report_aggregation), got_report_aggregation); } } @@ -7247,6 +7175,7 @@ mod tests { report_id, Time::from_seconds_since_epoch(12345), 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation).await?; @@ -7355,6 +7284,7 @@ mod tests { ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(12345), 0, + None, ReportAggregationState::Invalid, )) .await @@ -7372,7 +7302,7 @@ mod tests { let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); + let verify_key: [u8; VERIFY_KEY_LEN] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); let task = TaskBuilder::new( @@ -7384,18 +7314,18 @@ mod tests { let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); - let report_aggregations = ds + let want_report_aggregations = ds .run_tx(|tx| { let (task, prep_msg, prep_state, output_share) = ( task.clone(), vdaf_transcript.prepare_messages[0].clone(), - vdaf_transcript.leader_prep_state(0).clone(), + vdaf_transcript.leader_prep_state(0).0.clone(), vdaf_transcript.output_share(Role::Leader).clone(), ); Box::pin(async move { tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< - PRIO3_VERIFY_KEY_LENGTH, + VERIFY_KEY_LEN, TimeInterval, Prio3Count, >::new( @@ -7410,13 +7340,12 @@ mod tests { )) .await?; - let mut report_aggregations = Vec::new(); + let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::::Start, - ReportAggregationState::Waiting( - prep_state.clone(), - PrepareMessageOrShare::Leader(prep_msg), - ), + ReportAggregationState::::Start, + ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), + ReportAggregationState::Waiting(prep_state.clone(), None), + ReportAggregationState::Finished(output_share.clone()), ReportAggregationState::Finished(output_share), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ReportAggregationState::Invalid, @@ -7445,12 +7374,18 @@ mod tests { report_id, time, ord.try_into().unwrap(), + Some(PrepareStep::new( + report_id, + PrepareStepResult::Finished { + prep_msg: format!("prep_msg_{ord}").into(), + }, + )), state.clone(), ); 
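// Reviewer note (sketch): the new report_aggregations.last_prep_step column stores the
// encoded PrepareStep verbatim, and the roundtrip tests above depend on it surviving an
// encode/decode cycle. A minimal check of that property, using the Decode/Encode codec
// impls already imported in this module (function name is illustrative):
fn example_last_prep_step_roundtrip(report_id: ReportId) {
    let step = PrepareStep::new(
        report_id,
        PrepareStepResult::Finished {
            prep_msg: Vec::from("prep_msg"),
        },
    );
    let bytes = step.get_encoded(); // what is written to last_prep_step
    assert_eq!(PrepareStep::get_decoded(&bytes).unwrap(), step);
}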
tx.put_report_aggregation(&report_aggregation).await?; - report_aggregations.push(report_aggregation); + want_report_aggregations.push(report_aggregation); } - Ok(report_aggregations) + Ok(want_report_aggregations) }) }) .await @@ -7471,7 +7406,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(report_aggregations, got_report_aggregations); + assert_eq!(want_report_aggregations, got_report_aggregations); } #[tokio::test] @@ -7927,6 +7862,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Doesn't matter what state the report aggregation is in )]); @@ -8057,6 +7993,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Doesn't matter what state the report aggregation is in )]); @@ -8293,6 +8230,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Shouldn't matter what state the report aggregation is in )]); @@ -8357,6 +8295,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, )]); @@ -8434,6 +8373,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8442,6 +8382,7 @@ mod tests { *reports[1].metadata().id(), *reports[1].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -8516,6 +8457,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8524,6 +8466,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -8679,6 +8622,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8687,6 +8631,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8695,6 +8640,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -9253,6 +9199,7 @@ mod tests { random(), clock.now(), 0, + None, ReportAggregationState::Start, // Counted among max_size. ); let report_aggregation_0_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9261,9 +9208,10 @@ mod tests { random(), clock.now(), 1, + None, ReportAggregationState::Waiting( dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Leader(()), + Some(()), ), // Counted among max_size. ); let report_aggregation_0_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9272,6 +9220,7 @@ mod tests { random(), clock.now(), 2, + None, ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. ); let report_aggregation_0_3 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9280,6 +9229,7 @@ mod tests { random(), clock.now(), 3, + None, ReportAggregationState::Invalid, // Not counted among min_size or max_size. ); @@ -9299,7 +9249,8 @@ mod tests { random(), clock.now(), 0, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. + None, + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), // Counted among min_size and max_size. 
); let report_aggregation_1_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), @@ -9307,7 +9258,8 @@ mod tests { random(), clock.now(), 1, - ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. + None, + ReportAggregationState::Finished(dummy_vdaf::OutputShare(0)), // Counted among min_size and max_size. ); let report_aggregation_1_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( *task.id(), @@ -9315,6 +9267,7 @@ mod tests { random(), clock.now(), 2, + None, ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. ); let report_aggregation_1_3 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9323,6 +9276,7 @@ mod tests { random(), clock.now(), 3, + None, ReportAggregationState::Invalid, // Not counted among min_size or max_size. ); @@ -9570,6 +9524,7 @@ mod tests { *attached_report.metadata().id(), *attached_report.metadata().time(), 0, + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); @@ -9698,6 +9653,7 @@ mod tests { *report_id, *client_timestamp, ord.try_into().unwrap(), + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -10408,6 +10364,7 @@ mod tests { *report.metadata().id(), *client_timestamp, 0, + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); tx.put_report_aggregation(&report_aggregation) diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index a45ef3608..23168368c 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -586,7 +586,7 @@ pub mod test_util { }; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair}, - task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, VERIFY_KEY_LEN}, time::DurationExt, }; use janus_messages::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId, Time}; @@ -602,7 +602,7 @@ pub mod test_util { // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's // not yet done being specified, so choosing 16 bytes is fine for testing.) 
- _ => PRIO3_VERIFY_KEY_LENGTH, + _ => VERIFY_KEY_LEN, } } @@ -801,7 +801,7 @@ mod tests { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VERIFY_KEY_LEN}, test_util::roundtrip_encoding, time::DurationExt, }; @@ -839,7 +839,7 @@ mod tests { QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -861,7 +861,7 @@ mod tests { QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -883,7 +883,7 @@ mod tests { QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -905,7 +905,7 @@ mod tests { QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Helper, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, @@ -929,7 +929,7 @@ mod tests { QueryType::TimeInterval, VdafInstance::Prio3Count, Role::Leader, - Vec::from([SecretBytes::new([0; PRIO3_VERIFY_KEY_LENGTH].into())]), + Vec::from([SecretBytes::new([0; VERIFY_KEY_LEN].into())]), 0, Time::from_seconds_since_epoch(u64::MAX), None, diff --git a/core/src/task.rs b/core/src/task.rs index fe02d350e..52f18da4f 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -7,8 +7,8 @@ use url::Url; /// HTTP header where auth tokens are provided in messages between participants. pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token"; -/// The length of the verify key parameter for Prio3 VDAF instantiations. -pub const PRIO3_VERIFY_KEY_LENGTH: usize = 16; +/// The length of the verify key parameter for Prio3 & Poplar1 VDAF instantiations. +pub const VERIFY_KEY_LEN: usize = 16; /// Identifiers for supported VDAFs, corresponding to definitions in /// [draft-irtf-cfrg-vdaf-03][1] and implementations in [`prio::vdaf::prio3`]. @@ -62,9 +62,8 @@ impl VdafInstance { | VdafInstance::FakeFailsPrepInit | VdafInstance::FakeFailsPrepStep => 0, - // All "real" VDAFs use a verify key of length 16 currently. (Poplar1 may not, but it's - // not yet done being specified, so choosing 16 bytes is fine for testing.) - _ => PRIO3_VERIFY_KEY_LENGTH, + // All "real" VDAFs use a verify key of length 16 currently. + _ => VERIFY_KEY_LEN, } } } @@ -73,35 +72,41 @@ impl VdafInstance { #[macro_export] macro_rules! vdaf_dispatch_impl_base { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3CountVec { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3SumVec { bits, length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -110,12 +115,12 @@ macro_rules! vdaf_dispatch_impl_base { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match base $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_count(2)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -123,14 +128,14 @@ macro_rules! vdaf_dispatch_impl_base { // Prio3CountVec is implemented as a 1-bit sum vec let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, 1, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Sum { bits } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *bits)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -138,14 +143,21 @@ macro_rules! 
vdaf_dispatch_impl_base { let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum_vec_multithreaded(2, *bits, *length)?; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVecMultithreaded; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } ::janus_core::task::VdafInstance::Prio3Histogram { buckets } => { let $vdaf = ::prio::vdaf::prio3::Prio3::new_histogram(2, buckets)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; + $body + } + + ::janus_core::task::VdafInstance::Poplar1 { bits } => { + let $vdaf = ::prio::vdaf::poplar1::Poplar1::new_sha3(*bits); + type $Vdaf = ::prio::vdaf::poplar1::Poplar1<::prio::vdaf::prg::PrgSha3, 16>; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -159,13 +171,13 @@ macro_rules! vdaf_dispatch_impl_base { #[macro_export] macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -173,7 +185,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -181,7 +193,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -190,7 +202,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => { let $vdaf = @@ -200,7 +212,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI16<::fixed::types::extra::U15>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -212,7 +224,7 @@ macro_rules! 
vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI32<::fixed::types::extra::U31>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -224,7 +236,7 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded< ::fixed::FixedI64<::fixed::types::extra::U63>, >; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LEN: usize = ::janus_core::task::VERIFY_KEY_LEN; $body } @@ -238,23 +250,23 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 { #[macro_export] macro_rules! vdaf_dispatch_impl_test_util { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepInit => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -263,12 +275,12 @@ macro_rules! vdaf_dispatch_impl_test_util { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match test_util $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Fake => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new(); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -278,26 +290,30 @@ macro_rules! vdaf_dispatch_impl_test_util { ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( "FakeFailsPrepInit failed at prep_init".to_string(), )) - } + }, ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } ::janus_core::task::VdafInstance::FakeFailsPrepStep => { let $vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf::new().with_prep_step_fn( - || -> Result< - ::prio::vdaf::PrepareTransition<::janus_core::test_util::dummy_vdaf::Vdaf, 0, 16>, + |_| -> Result< + ::prio::vdaf::PrepareTransition< + ::janus_core::test_util::dummy_vdaf::Vdaf, + 0, + 16, + >, ::prio::vdaf::VdafError, > { ::std::result::Result::Err(::prio::vdaf::VdafError::Uncategorized( "FakeFailsPrepStep failed at prep_step".to_string(), )) - } + }, ); type $Vdaf = ::janus_core::test_util::dummy_vdaf::Vdaf; - const $VERIFY_KEY_LENGTH: usize = 0; + const $VERIFY_KEY_LEN: usize = 0; $body } @@ -311,26 +327,27 @@ macro_rules! vdaf_dispatch_impl_test_util { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -338,26 +355,27 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. 
} => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -370,20 +388,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -391,20 +410,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { .. } | ::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { .. 
} | ::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { .. } => { - ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_fpvec_bounded_l2!(impl match fpvec_bounded_l2 $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -417,20 +437,21 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. - (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -438,20 +459,21 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } ::janus_core::task::VdafInstance::Fake | ::janus_core::task::VdafInstance::FakeFailsPrepInit | ::janus_core::task::VdafInstance::FakeFailsPrepStep => { - ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ::janus_core::vdaf_dispatch_impl_test_util!(impl match test_util $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -464,14 +486,15 @@ macro_rules! vdaf_dispatch_impl { #[macro_export] macro_rules! vdaf_dispatch_impl { // Provide the dispatched type only, don't construct a VDAF instance. 
- (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -479,14 +502,15 @@ macro_rules! vdaf_dispatch_impl { }; // Construct a VDAF instance, and provide that to the block as well. - (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + (impl match all $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { match $vdaf_instance { ::janus_core::task::VdafInstance::Prio3Count | ::janus_core::task::VdafInstance::Prio3CountVec { .. } | ::janus_core::task::VdafInstance::Prio3Sum { .. } | ::janus_core::task::VdafInstance::Prio3SumVec { .. } - | ::janus_core::task::VdafInstance::Prio3Histogram { .. } => { - ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + | ::janus_core::task::VdafInstance::Prio3Histogram { .. } + | ::janus_core::task::VdafInstance::Poplar1 { .. } => { + ::janus_core::vdaf_dispatch_impl_base!(impl match base $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) } _ => panic!("VDAF {:?} is not yet supported", $vdaf_instance), @@ -513,21 +537,21 @@ macro_rules! vdaf_dispatch_impl { /// # } /// # fn test() -> Result<(), prio::vdaf::VdafError> { /// # let vdaf = janus_core::task::VdafInstance::Prio3Count; -/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { -/// handle_request_generic::<VdafType, VERIFY_KEY_LENGTH>(&vdaf) +/// vdaf_dispatch!(&vdaf, (vdaf, VdafType, VERIFY_KEY_LEN) => { +/// handle_request_generic::<VdafType, VERIFY_KEY_LEN>(&vdaf) /// }) /// # } /// ``` #[macro_export] macro_rules! vdaf_dispatch { // Provide the dispatched type only, don't construct a VDAF instance. - ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, (_, $Vdaf, $VERIFY_KEY_LEN) => $body) }; // Construct a VDAF instance, and provide that to the block as well.
- ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH) => $body) + ($vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LEN:ident) => $body:tt) => { + ::janus_core::vdaf_dispatch_impl!(impl match all $vdaf_instance, ($vdaf, $Vdaf, $VERIFY_KEY_LEN) => $body) }; } diff --git a/core/src/test_util/dummy_vdaf.rs b/core/src/test_util/dummy_vdaf.rs index 7d1583112..6d87ea60e 100644 --- a/core/src/test_util/dummy_vdaf.rs +++ b/core/src/test_util/dummy_vdaf.rs @@ -4,20 +4,24 @@ use prio::{ codec::{CodecError, Decode, Encode}, vdaf::{self, Aggregatable, PrepareTransition, VdafError}, }; +use rand::random; use std::fmt::Debug; use std::io::Cursor; use std::sync::Arc; type ArcPrepInitFn = Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>; -type ArcPrepStepFn = - Arc<dyn Fn() -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> + 'static + Send + Sync>; +type ArcPrepStepFn = Arc< + dyn Fn(&PrepareState) -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> + + 'static + + Send + + Sync, +>; #[derive(Clone)] pub struct Vdaf { prep_init_fn: ArcPrepInitFn, prep_step_fn: ArcPrepStepFn, - input_share: InputShare, } impl Debug for Vdaf { @@ -31,15 +35,16 @@ impl Vdaf { /// The length of the verify key parameter for fake VDAF instantiations. - pub const VERIFY_KEY_LENGTH: usize = 0; + pub const VERIFY_KEY_LEN: usize = 0; pub fn new() -> Self { Self { prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }), - prep_step_fn: Arc::new(|| -> Result<PrepareTransition<Self, 0, 16>, VdafError> { - Ok(PrepareTransition::Finish(OutputShare())) - }), - input_share: InputShare::default(), + prep_step_fn: Arc::new( + |state| -> Result<PrepareTransition<Self, 0, 16>, VdafError> { + Ok(PrepareTransition::Finish(OutputShare(state.0))) + }, + ), } } @@ -54,7 +59,9 @@ impl Vdaf { self } - pub fn with_prep_step_fn<F: Fn() -> Result<PrepareTransition<Self, 0, 16>, VdafError>>( + pub fn with_prep_step_fn< + F: Fn(&PrepareState) -> Result<PrepareTransition<Self, 0, 16>, VdafError>, + >( mut self, f: F, ) -> Self @@ -64,11 +71,6 @@ impl Vdaf { self.prep_step_fn = Arc::new(f); self } - - pub fn with_input_share(mut self, input_share: InputShare) -> Self { - self.input_share = input_share; - self - } } impl Default for Vdaf { @@ -80,8 +82,8 @@ impl Default for Vdaf { impl vdaf::Vdaf for Vdaf { const ID: u32 = 0xFFFF0000; - type Measurement = (); - type AggregateResult = (); + type Measurement = u8; + type AggregateResult = u8; type AggregationParam = AggregationParam; type PublicShare = (); type InputShare = InputShare; @@ -120,10 +122,10 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_step( &self, - _: Self::PrepareState, + state: Self::PrepareState, _: Self::PrepareMessage, ) -> Result<PrepareTransition<Self, 0, 16>, VdafError> { - (self.prep_step_fn)() + (self.prep_step_fn)(&state) } fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( @@ -142,10 +144,18 @@ impl vdaf::Client<16> for Vdaf { fn shard( &self, - _measurement: &Self::Measurement, + measurement: &Self::Measurement, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError> { - Ok(((), Vec::from([self.input_share, self.input_share]))) + let first_input_share = random(); + let (second_input_share, _) = measurement.overflowing_sub(first_input_share); + Ok(( + (), + Vec::from([ + InputShare(first_input_share), + InputShare(second_input_share), + ]), + )) } } @@ -173,7 +183,7 @@ pub struct AggregationParam(pub u8); impl Encode for AggregationParam { fn encode(&self, bytes: &mut Vec<u8>) { - self.0.encode(bytes); + self.0.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -188,19 +198,21 @@ impl Decode for
AggregationParam { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct OutputShare(); +pub struct OutputShare(pub u8); impl Decode for OutputShare { - fn decode(_: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { - Ok(Self()) + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self(u8::decode(bytes)?)) } } impl Encode for OutputShare { - fn encode(&self, _: &mut Vec<u8>) {} + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } fn encoded_len(&self) -> Option<usize> { - Some(0) + self.0.encoded_len() } } @@ -234,15 +246,15 @@ impl Aggregatable for AggregateShare { Ok(()) } - fn accumulate(&mut self, _: &Self::OutputShare) -> Result<(), VdafError> { - self.0 += 1; + fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> { + self.0 += u64::from(out_share.0); Ok(()) } } impl From<OutputShare> for AggregateShare { - fn from(_: OutputShare) -> Self { - Self(1) + fn from(out_share: OutputShare) -> Self { + Self(u64::from(out_share.0)) } } diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 60d2e15db..e150b4c9f 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -32,15 +32,15 @@ pub struct VdafTranscript<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>> impl<const SEED_SIZE: usize, V: vdaf::Aggregator<SEED_SIZE, 16>> VdafTranscript<SEED_SIZE, V> { - /// Get the leader's preparation state at the requested round. - pub fn leader_prep_state(&self, round: usize) -> &V::PrepareState { + /// Get the Leader's preparation state and prepare share at the requested round. + pub fn leader_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { assert_matches!( &self.prepare_transitions[Role::Leader.index().unwrap()][round], - PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, _) => prep_state + PrepareTransition::<V, SEED_SIZE, 16>::Continue(prep_state, prep_share) => (prep_state, prep_share) ) } - /// Get the helper's preparation state and prepare share at the requested round. + /// Get the Helper's preparation state and prepare share at the requested round.
pub fn helper_prep_state(&self, round: usize) -> (&V::PrepareState, &V::PrepareShare) { assert_matches!( &self.prepare_transitions[Role::Helper.index().unwrap()][round], diff --git a/db/schema.sql b/db/schema.sql index 3590dd1c5..65f880ab4 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -140,11 +140,10 @@ CREATE TABLE report_aggregations( ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) - prep_msg BYTEA, -- for the leader, the next preparation message to be sent to the helper (opaque VDAF message) - -- for the helper, the next preparation share to be sent to the leader (opaque VDAF message) - -- only non-NULL if in state WAITING + prep_msg BYTEA, -- the next preparation message to be sent to the Helper (opaque VDAF message, populated for Leader only) out_share BYTEA, -- the output share (opaque VDAF message, only if in state FINISHED) error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED + last_prep_step BYTEA, -- the last PrepareStep message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) CONSTRAINT report_aggregations_unique_ord UNIQUE(aggregation_job_id, ord), CONSTRAINT fk_aggregation_job_id FOREIGN KEY(aggregation_job_id) REFERENCES aggregation_jobs(id), diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index b4654b703..70a2bc422 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -2,7 +2,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use futures::future::join_all; use janus_core::{ - task::PRIO3_VERIFY_KEY_LENGTH, + task::VERIFY_KEY_LEN, test_util::{install_test_trace_subscriber, testcontainers::container_client}, time::{Clock, RealClock, TimeExt}, }; @@ -109,7 +109,7 @@ async fn run( let task_id: TaskId = random(); let aggregator_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); let collector_auth_token = URL_SAFE_NO_PAD.encode(random::<[u8; 16]>()); - let vdaf_verify_key = rand::random::<[u8; PRIO3_VERIFY_KEY_LENGTH]>(); + let vdaf_verify_key = rand::random::<[u8; VERIFY_KEY_LEN]>(); let task_id_encoded = URL_SAFE_NO_PAD.encode(task_id.get_encoded()); let vdaf_verify_key_encoded = URL_SAFE_NO_PAD.encode(vdaf_verify_key); diff --git a/messages/Cargo.toml b/messages/Cargo.toml index 52866cc37..da46349d0 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -20,7 +20,7 @@ hex = "0.4" num_enum = "0.5.11" # We can't pull prio in from the workspace because that would enable default features, and we do not # want prio/crypto-dependencies -prio = { version = "0.12.0", default-features = false } +prio = { workspace = true, default-features = false } # XXX: revert to 0.12.0? rand = "0.8" serde = { version = "1.0.158", features = ["derive"] } thiserror = "1.0" diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 231ce0f6e..1d5ec6cf6 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -427,7 +427,6 @@ impl Role { /// If this [`Role`] is one of the aggregators, returns the index at which /// that aggregator's message or data can be found in various lists, or /// `None` if the role is not an aggregator.
- // XXX: can this be removed once all messages are updated to have explicit leader/helper fields? pub fn index(&self) -> Option<usize> { match self { // draft-gpew-priv-ppm §4.2: the leader's endpoint MUST be the first @@ -2028,6 +2027,57 @@ impl Decode for ReportShare { } } +/// DAP protocol message representing information required to initialize preparation of a report for +/// aggregation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ReportPrepInit { + report_share: ReportShare, + leader_prep_share: Vec<u8>, +} + +impl ReportPrepInit { + /// Constructs a new report preparation initialization message from its components. + pub fn new(report_share: ReportShare, leader_prep_share: Vec<u8>) -> Self { + Self { + report_share, + leader_prep_share, + } + } + + /// Gets the report share associated with this report prep init. + pub fn report_share(&self) -> &ReportShare { + &self.report_share + } + + /// Gets the leader preparation share associated with this report prep init. + pub fn leader_prep_share(&self) -> &[u8] { + &self.leader_prep_share + } +} + +impl Encode for ReportPrepInit { + fn encode(&self, bytes: &mut Vec<u8>) { + self.report_share.encode(bytes); + encode_u32_items(bytes, &(), &self.leader_prep_share); + } + + fn encoded_len(&self) -> Option<usize> { + Some(self.report_share.encoded_len()? + 4 + self.leader_prep_share.len()) + } +} + +impl Decode for ReportPrepInit { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let report_share = ReportShare::decode(bytes)?; + let leader_prep_share = decode_u32_items(&(), bytes)?; + + Ok(Self { + report_share, + leader_prep_share, + }) + } +} + /// DAP protocol message representing the result of a preparation step in a VDAF evaluation. #[derive(Clone, Debug, PartialEq, Eq)] pub struct PrepareStep { @@ -2077,8 +2127,16 @@ impl Decode for PrepareStep { #[derive(Clone, Derivative, PartialEq, Eq)] #[derivative(Debug)] pub enum PrepareStepResult { - Continued(#[derivative(Debug = "ignore")] Vec<u8>), // content is a serialized preparation message - Finished, + Continued { + #[derivative(Debug = "ignore")] + prep_msg: Vec<u8>, + #[derivative(Debug = "ignore")] + prep_share: Vec<u8>, + }, + Finished { + #[derivative(Debug = "ignore")] + prep_msg: Vec<u8>, + }, Failed(ReportShareError), } @@ -2087,11 +2145,18 @@ impl Encode for PrepareStepResult { // The encoding includes an implicit discriminator byte, called PrepareStepResult in the // DAP spec.
match self { - Self::Continued(vdaf_msg) => { + Self::Continued { + prep_msg, + prep_share, + } => { 0u8.encode(bytes); - encode_u32_items(bytes, &(), vdaf_msg); + encode_u32_items(bytes, &(), prep_msg); + encode_u32_items(bytes, &(), prep_share); + } + Self::Finished { prep_msg } => { + 1u8.encode(bytes); + encode_u32_items(bytes, &(), prep_msg); } - Self::Finished => 1u8.encode(bytes), Self::Failed(error) => { 2u8.encode(bytes); error.encode(bytes); @@ -2101,8 +2166,11 @@ impl Encode for PrepareStepResult { fn encoded_len(&self) -> Option<usize> { match self { - PrepareStepResult::Continued(vdaf_msg) => Some(1 + 4 + vdaf_msg.len()), - PrepareStepResult::Finished => Some(1), + PrepareStepResult::Continued { + prep_msg, + prep_share, + } => Some(1 + 4 + prep_msg.len() + 4 + prep_share.len()), + PrepareStepResult::Finished { prep_msg } => Some(1 + 4 + prep_msg.len()), PrepareStepResult::Failed(error) => Some(1 + error.encoded_len()?), } } @@ -2112,8 +2180,18 @@ impl Decode for PrepareStepResult { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let val = u8::decode(bytes)?; Ok(match val { - 0 => Self::Continued(decode_u32_items(&(), bytes)?), - 1 => Self::Finished, + 0 => { + let prep_msg = decode_u32_items(&(), bytes)?; + let prep_share = decode_u32_items(&(), bytes)?; + Self::Continued { + prep_msg, + prep_share, + } + } + 1 => { + let prep_msg = decode_u32_items(&(), bytes)?; + Self::Finished { prep_msg } + } 2 => Self::Failed(ReportShareError::decode(bytes)?), _ => return Err(CodecError::UnexpectedValue), }) @@ -2227,7 +2305,7 @@ pub struct AggregationJobInitializeReq<Q: QueryType> { #[derivative(Debug = "ignore")] aggregation_parameter: Vec<u8>, partial_batch_selector: PartialBatchSelector, - report_shares: Vec<ReportShare>, + report_inits: Vec<ReportPrepInit>, } impl<Q: QueryType> AggregationJobInitializeReq<Q> { @@ -2238,12 +2316,12 @@ pub fn new( aggregation_parameter: Vec<u8>, partial_batch_selector: PartialBatchSelector, - report_shares: Vec<ReportShare>, + report_inits: Vec<ReportPrepInit>, ) -> Self { Self { aggregation_parameter, partial_batch_selector, - report_shares, + report_inits, } } @@ -2257,9 +2335,10 @@ &self.partial_batch_selector } - /// Gets the report shares associated with this aggregate initialization request. - pub fn report_shares(&self) -> &[ReportShare] { - &self.report_shares + /// Gets the report preparation initialization messages associated with this aggregate + /// initialization request.
+ pub fn report_inits(&self) -> &[ReportPrepInit] { + &self.report_inits } } @@ -2267,15 +2346,15 @@ impl<Q: QueryType> Encode for AggregationJobInitializeReq<Q> { fn encode(&self, bytes: &mut Vec<u8>) { encode_u32_items(bytes, &(), &self.aggregation_parameter); self.partial_batch_selector.encode(bytes); - encode_u32_items(bytes, &(), &self.report_shares); + encode_u32_items(bytes, &(), &self.report_inits); } fn encoded_len(&self) -> Option<usize> { let mut length = 4 + self.aggregation_parameter.len(); length += self.partial_batch_selector.encoded_len()?; length += 4; - for report_share in self.report_shares.iter() { - length += report_share.encoded_len()?; + for report_init in &self.report_inits { + length += report_init.encoded_len()?; } Some(length) } } @@ -2285,12 +2364,12 @@ impl<Q: QueryType> Decode for AggregationJobInitializeReq<Q> { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { let aggregation_parameter = decode_u32_items(&(), bytes)?; let partial_batch_selector = PartialBatchSelector::decode(bytes)?; - let report_shares = decode_u32_items(&(), bytes)?; + let report_inits = decode_u32_items(&(), bytes)?; Ok(Self { aggregation_parameter, partial_batch_selector, - report_shares, + report_inits, }) } } @@ -2659,8 +2738,8 @@ mod tests { Extension, ExtensionType, FixedSize, FixedSizeQuery, HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare, PrepareStep, PrepareStepResult, Query, Report, - ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, - Time, TimeInterval, + ReportId, ReportIdChecksum, ReportMetadata, ReportPrepInit, ReportShare, ReportShareError, + Role, TaskId, Time, TimeInterval, }; use assert_matches::assert_matches; use prio::codec::{CodecError, Decode, Encode}; @@ -3763,6 +3842,200 @@ mod tests { ]) } + #[test] + fn roundtrip_report_share() { + roundtrip_encoding(&[ + ( + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + concat!( + concat!( + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time + ), + concat!( + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), + ), + ), + ), + ( + ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + concat!( + concat!( + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + ), + ]) + } + + #[test] + fn roundtrip_report_prep_init() {
roundtrip_encoding(&[ + ( + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + leader_prep_share: Vec::from("012345"), + }, + concat!( + concat!( + // report_share + concat!( + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time + ), + concat!( + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), + ), + ), + concat!( + // leader_prep_share + "00000006", // length + "303132333435", // opaque data + ) + ), + ), + ( + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + leader_prep_share: Vec::new(), + }, + concat!( + concat!( + // report_share + concat!( + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time + ), + concat!( + // public_share + "00000004", // length + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), + ), + ), + concat!( + // leader_prep_share + "00000000", // length + "" // opaque data + ) + ), + ), + ]) + } + #[test] fn roundtrip_prepare_step() { roundtrip_encoding(&[ @@ -3771,16 +4044,24 @@ mod tests { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + result: PrepareStepResult::Continued { + prep_msg: Vec::from("012345"), + prep_share: Vec::from("543210"), + }, }, concat!( "0102030405060708090A0B0C0D0E0F10", // report_id "00", // prepare_step_result concat!( - // vdaf_msg + // prep_msg "00000006", // length "303132333435", // opaque data ), + concat!( + // prep_share + "00000006", // length + "353433323130", // opaque data + ), ), ), ( @@ -3788,11 +4069,18 @@ mod tests { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), - result: PrepareStepResult::Finished, + result: PrepareStepResult::Finished { + prep_msg: Vec::from("012345"), + }, }, concat!( "100F0E0D0C0B0A090807060504030201", // report_id "01", // prepare_step_result + concat!( + // prep_msg + "00000006", // length + "303132333435", // opaque data + ), ), ), ( @@ -3828,30 +4116,40 @@ mod tests { AggregationJobInitializeReq { aggregation_parameter: Vec::from("012345"), partial_batch_selector: PartialBatchSelector::new_time_interval(), - report_shares: Vec::from([ - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - public_share: Vec::new(), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(42), - 
Vec::from("012345"), - Vec::from("543210"), - ), + report_inits: Vec::from([ + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + leader_prep_share: Vec::from("012345"), }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + leader_prep_share: Vec::new(), }, ]), }, @@ -3866,58 +4164,74 @@ mod tests { "01", // query_type ), concat!( - // report_shares - "0000005E", // length + // report_inits + "0000006C", // length concat!( concat!( - // metadata - "0102030405060708090A0B0C0D0E0F10", // report_id - "000000000000D431", // time - ), - concat!( - // public_share - "00000000", // length - "", // opaque data - ), - concat!( - // encrypted_input_share - "2A", // config_id + // report_share concat!( - // encapsulated_context - "0006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time ), concat!( - // payload - "00000006", // length - "353433323130", // opaque data + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000006", // length + "303132333435", // opaque data + ) ), concat!( concat!( - // metadata - "100F0E0D0C0B0A090807060504030201", // report_id - "0000000000011F46", // time - ), - concat!( - "00000004", // payload - "30313233", // opaque data - ), - concat!( - // encrypted_input_share - "0D", // config_id concat!( - // encapsulated_context - "0004", // length - "61626365", // opaque data + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time ), concat!( - // payload + // public_share "00000004", // length - "61626664", // opaque data + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000000", // length + "" // opaque data + ) ), ), ), @@ -3930,30 +4244,40 @@ mod tests { partial_batch_selector: PartialBatchSelector::new_fixed_size(BatchId::from( [2u8; 32], )), - report_shares: Vec::from([ - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(54321), - ), - public_share: 
Vec::new(), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), + report_inits: Vec::from([ + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + Time::from_seconds_since_epoch(54321), + ), + public_share: Vec::new(), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(42), + Vec::from("012345"), + Vec::from("543210"), + ), + }, + leader_prep_share: Vec::from("012345"), }, - ReportShare { - metadata: ReportMetadata::new( - ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), - Time::from_seconds_since_epoch(73542), - ), - public_share: Vec::from("0123"), - encrypted_input_share: HpkeCiphertext::new( - HpkeConfigId::from(13), - Vec::from("abce"), - Vec::from("abfd"), - ), + ReportPrepInit { + report_share: ReportShare { + metadata: ReportMetadata::new( + ReportId::from([ + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + ]), + Time::from_seconds_since_epoch(73542), + ), + public_share: Vec::from("0123"), + encrypted_input_share: HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("abce"), + Vec::from("abfd"), + ), + }, + leader_prep_share: Vec::new(), }, ]), }, @@ -3969,59 +4293,74 @@ mod tests { "0202020202020202020202020202020202020202020202020202020202020202", // batch_id ), concat!( - // report_shares - "0000005E", // length + // report_inits + "0000006C", // length concat!( concat!( - // metadata - "0102030405060708090A0B0C0D0E0F10", // report_id - "000000000000D431", // time - ), - concat!( - // public_share - "00000000", // length - "", // opaque data - ), - concat!( - // encrypted_input_share - "2A", // config_id + // report_share concat!( - // encapsulated_context - "0006", // length - "303132333435", // opaque data + // metadata + "0102030405060708090A0B0C0D0E0F10", // report_id + "000000000000D431", // time ), concat!( - // payload - "00000006", // length - "353433323130", // opaque data + // public_share + "00000000", // length + "", // opaque data + ), + concat!( + // encrypted_input_share + "2A", // config_id + concat!( + // encapsulated_context + "0006", // length + "303132333435", // opaque data + ), + concat!( + // payload + "00000006", // length + "353433323130", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000006", // length + "303132333435", // opaque data + ) ), concat!( concat!( - // metadata - "100F0E0D0C0B0A090807060504030201", // report_id - "0000000000011F46", // time - ), - concat!( - // public_share - "00000004", // length - "30313233", // opaque data - ), - concat!( - // encrypted_input_share - "0D", // config_id concat!( - // encapsulated_context - "0004", // length - "61626365", // opaque data + // metadata + "100F0E0D0C0B0A090807060504030201", // report_id + "0000000000011F46", // time ), concat!( - // payload + // public_share "00000004", // length - "61626664", // opaque data + "30313233", // opaque data + ), + concat!( + // encrypted_input_share + "0D", // config_id + concat!( + // encapsulated_context + "0004", // length + "61626365", // opaque data + ), + concat!( + // payload + "00000004", // length + "61626664", // opaque data + ), ), ), + concat!( + // leader_prep_share + "00000000", // length + "" // opaque data + ) ), ), ), @@ -4038,13 +4377,18 @@ mod tests { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: 
PrepareStepResult::Continued(Vec::from("012345")), + result: PrepareStepResult::Continued { + prep_msg: Vec::from("012345"), + prep_share: Vec::from("543210"), + }, }, PrepareStep { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), - result: PrepareStepResult::Finished, + result: PrepareStepResult::Finished { + prep_msg: Vec::from("012345"), + }, }, ]), }, @@ -4052,19 +4396,29 @@ mod tests { "A5A5", // round concat!( // prepare_steps - "0000002C", // length + "00000040", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id "00", // prepare_step_result concat!( - // payload + // prep_msg "00000006", // length "303132333435", // opaque data ), + concat!( + // prep_share + "00000006", // length + "353433323130", // opaque data + ), ), concat!( "100F0E0D0C0B0A090807060504030201", // report_id "01", // prepare_step_result + concat!( + // prep_msg + "00000006", // length + "303132333435", // opaque data + ), ) ), ), @@ -4080,31 +4434,46 @@ mod tests { report_id: ReportId::from([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ]), - result: PrepareStepResult::Continued(Vec::from("012345")), + result: PrepareStepResult::Continued { + prep_msg: Vec::from("012345"), + prep_share: Vec::from("543210"), + }, }, PrepareStep { report_id: ReportId::from([ 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, ]), - result: PrepareStepResult::Finished, + result: PrepareStepResult::Finished { + prep_msg: Vec::from("012345"), + }, }, ]), }, concat!(concat!( // prepare_steps - "0000002C", // length + "00000040", // length concat!( "0102030405060708090A0B0C0D0E0F10", // report_id "00", // prepare_step_result concat!( - // payload + // prep_msg "00000006", // length "303132333435", // opaque data ), + concat!( + // prep_share + "00000006", // length + "353433323130", // opaque data + ), ), concat!( "100F0E0D0C0B0A090807060504030201", // report_id "01", // prepare_step_result + concat!( + // prep_msg + "00000006", // length + "303132333435", // opaque data + ), ) ),), )]) From cc98fcdf3d4aa56a8c0460215dfe108a0760a399 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Mon, 10 Apr 2023 15:59:43 -0700 Subject: [PATCH 6/8] cargo fmt --- aggregator/src/aggregator.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index e11de890c..2618497da 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -62,7 +62,7 @@ use prio::{ poplar1::Poplar1, prg::PrgSha3, prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVecMultithreaded}, - VdafError, PrepareTransition + PrepareTransition, VdafError, }, }; use regex::Regex; @@ -1379,12 +1379,15 @@ impl VdafOps { if incoming_report_share_data .iter() .zip(existing_report_aggregations) - .any(|(incoming_report_share_data, existing_report_aggregation)| { - !existing_report_aggregation - .report_metadata() - .eq(incoming_report_share_data.report_share.metadata()) - || !existing_report_aggregation.eq(&incoming_report_share_data.report_aggregation) - }) + .any( + |(incoming_report_share_data, existing_report_aggregation)| { + !existing_report_aggregation + .report_metadata() + .eq(incoming_report_share_data.report_share.metadata()) + || !existing_report_aggregation + .eq(&incoming_report_share_data.report_aggregation) + }, + ) { return Ok(false); } @@ -1751,7 +1754,7 @@ impl VdafOps { .with_last_prep_step(Some(PrepareStep::new( *report_share_data.report_share.metadata().id(), 
                            PrepareStepResult::Failed(ReportShareError::ReportReplayed))
-                        ));
+                            ));
                     },
                     err => return Err(err),
                 }

From 592be4c5d86e66f149b2151084324aed30604a75 Mon Sep 17 00:00:00 2001
From: Brandon Pitman
Date: Tue, 11 Apr 2023 09:09:22 -0700
Subject: [PATCH 7/8] Remove unnecessary/redundant logic check.

This check was left over from a code-editing error made while refactoring
this section of code.
---
 aggregator/src/aggregator.rs | 59 ++++++++++++++++--------------------
 1 file changed, 26 insertions(+), 33 deletions(-)

diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index 2618497da..acc32dab2 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -1738,40 +1738,33 @@ impl VdafOps {
 
                 for report_share_data in &mut report_share_data {
-                    if !replayed_request {
-                        // Write client report & report aggregation.
-                        if let Err(err) = tx.put_report_share(
-                            task.id(),
-                            &report_share_data.report_share
-                        ).await {
-                            match err {
-                                datastore::Error::MutationTargetAlreadyExists => {
-                                    report_share_data.report_aggregation =
-                                        report_share_data.report_aggregation
-                                            .clone()
-                                            .with_state(ReportAggregationState::Failed(
-                                                ReportShareError::ReportReplayed))
-                                            .with_last_prep_step(Some(PrepareStep::new(
-                                                *report_share_data.report_share.metadata().id(),
-                                                PrepareStepResult::Failed(ReportShareError::ReportReplayed))
-                                            ));
-                                },
-                                err => return Err(err),
-                            }
+                    // Write client report & report aggregation.
+                    if let Err(err) = tx.put_report_share(task.id(), &report_share_data.report_share).await {
+                        match err {
+                            datastore::Error::MutationTargetAlreadyExists => {
+                                report_share_data.report_aggregation =
+                                    report_share_data.report_aggregation
+                                        .clone()
+                                        .with_state(ReportAggregationState::Failed(
+                                            ReportShareError::ReportReplayed))
+                                        .with_last_prep_step(Some(PrepareStep::new(
+                                            *report_share_data.report_share.metadata().id(),
+                                            PrepareStepResult::Failed(ReportShareError::ReportReplayed))
+                                        ));
+                            },
+                            err => return Err(err),
                         }
-                        tx.put_report_aggregation(&report_share_data.report_aggregation).await?;
-
-                        if let ReportAggregationState::<SEED_SIZE, A>::Finished(output_share) =
-                            report_share_data.report_aggregation.state()
-                        {
-                            accumulator.update(
-                                aggregation_job.partial_batch_identifier(),
-                                report_share_data.report_share.metadata().id(),
-                                report_share_data.report_share.metadata().time(),
-                                output_share,
-                            )?;
-                        }
-
+                    }
+                    tx.put_report_aggregation(&report_share_data.report_aggregation).await?;
+
+                    if let ReportAggregationState::Finished(output_share) = report_share_data.report_aggregation.state()
+                    {
+                        accumulator.update(
+                            aggregation_job.partial_batch_identifier(),
+                            report_share_data.report_share.metadata().id(),
+                            report_share_data.report_share.metadata().time(),
+                            output_share,
+                        )?;
                     }
                 }

From c9c9cf6450c128adab4fa56f4fa8454e75a28b03 Mon Sep 17 00:00:00 2001
From: Brandon Pitman
Date: Tue, 11 Apr 2023 09:31:12 -0700
Subject: [PATCH 8/8] Change prio version reference.

It's necessary to reference an unreleased version of libprio-rs until
https://github.com/divviup/libprio-rs/commit/54a46230615d28c7e131d0595cc558e1619b8071
is released. I need to change the reference because that change has now
been merged and the underlying branch has been deleted.
--- Cargo.lock | 2 +- Cargo.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 27a0ea268..aa44eadd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3090,7 +3090,7 @@ dependencies = [ [[package]] name = "prio" version = "0.12.0" -source = "git+https://github.com/divviup/libprio-rs.git?branch=bran/encode-poplar1-state#e0357075d880d209f1021b1e9282261533678ba7" +source = "git+https://github.com/divviup/libprio-rs.git?rev=54a46230615d28c7e131d0595cc558e1619b8071#54a46230615d28c7e131d0595cc558e1619b8071" dependencies = [ "aes", "base64 0.21.0", diff --git a/Cargo.toml b/Cargo.toml index 640ecc3e3..5a0f3f04a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,8 +39,8 @@ janus_interop_binaries = { version = "0.4", path = "interop_binaries" } janus_messages = { version = "0.4", path = "messages" } k8s-openapi = { version = "0.16.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube` kube = { version = "0.75.0", default-features = false, features = ["client"] } -# prio = { version = "0.12.0", features = ["multithreaded"] } # XXX -prio = { git = "https://github.com/divviup/libprio-rs.git", branch = "bran/encode-poplar1-state", features = ["multithreaded", "experimental"] } +# prio = { version = "0.12.0", features = ["multithreaded"] } # XXX: go back to a released version of prio once https://github.com/divviup/libprio-rs/commit/54a46230615d28c7e131d0595cc558e1619b8071 is released. +prio = { git = "https://github.com/divviup/libprio-rs.git", rev = "54a46230615d28c7e131d0595cc558e1619b8071", features = ["multithreaded", "experimental"] } trillium = "0.2.8" trillium-api = { version = "0.2.0-rc.2", default-features = false } trillium-head = "0.2.0"
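
A note for readers hand-checking the updated prepare_steps test vectors
above: under the new message definition, a Continued result carries both a
prep_msg and a prep_share, while a Finished result carries only a prep_msg,
each encoded as an opaque byte string with a 4-byte big-endian length
prefix. The sketch below reproduces that framing and the lengths in the
vectors; encode_u32_prefixed is a hypothetical helper written for
illustration, not part of janus_messages' actual Encode implementation.

// Minimal sketch of the TLS-style framing assumed by the PrepareStep test
// vectors; assumes variant tags 0 = Continued, 1 = Finished, as in the hex
// above.
fn encode_u32_prefixed(buf: &mut Vec<u8>, data: &[u8]) {
    // opaque data<0..2^32-1>: 4-byte big-endian length prefix, then bytes.
    buf.extend_from_slice(&(data.len() as u32).to_be_bytes());
    buf.extend_from_slice(data);
}

fn main() {
    // PrepareStep { report_id, result: Continued { prep_msg, prep_share } }
    let mut continued = Vec::new();
    continued.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); // report_id
    continued.push(0x00); // prepare_step_result tag: Continued
    encode_u32_prefixed(&mut continued, b"012345"); // prep_msg   -> 00000006 303132333435
    encode_u32_prefixed(&mut continued, b"543210"); // prep_share -> 00000006 353433323130
    assert_eq!(continued.len(), 37);

    // PrepareStep { report_id, result: Finished { prep_msg } }
    let mut finished = Vec::new();
    finished.extend_from_slice(&[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); // report_id
    finished.push(0x01); // prepare_step_result tag: Finished
    encode_u32_prefixed(&mut finished, b"012345"); // prep_msg -> 00000006 303132333435
    assert_eq!(finished.len(), 27);

    // Together: 37 + 27 = 64 = 0x40, the prepare_steps length in the vectors.
    assert_eq!(continued.len() + finished.len(), 0x40);
}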