diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 1c8fb2dc7..36d962451 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -22,8 +22,8 @@ use janus_aggregator_core::{ self, models::{ AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation, - CollectionJob, CollectionJobState, LeaderStoredReport, PrepareMessageOrShare, - ReportAggregation, ReportAggregationState, + CollectionJob, CollectionJobState, LeaderStoredReport, ReportAggregation, + ReportAggregationState, }, Datastore, Transaction, }, @@ -112,6 +112,7 @@ pub(crate) fn aggregate_step_failure_counter(meter: &Meter) -> Counter { // Initialize counters with desired status labels. This causes Prometheus to see the first // non-zero value we record. for failure_type in [ + "missing_prepare_message", "missing_leader_input_share", "missing_helper_input_share", "prepare_init_failure", @@ -1106,26 +1107,16 @@ where { report_share: ReportShare, report_aggregation: ReportAggregation, - prep_result: PrepareStepResult, - existing_report_aggregation: bool, - conflicting_aggregate_share: bool, } impl> ReportShareData where A: vdaf::Aggregator, { - fn new( - report_share: ReportShare, - report_aggregation: ReportAggregation, - prep_result: PrepareStepResult, - ) -> Self { + fn new(report_share: ReportShare, report_aggregation: ReportAggregation) -> Self { Self { report_share, report_aggregation, - prep_result, - existing_report_aggregation: false, - conflicting_aggregate_share: false, } } } @@ -1178,16 +1169,6 @@ impl VdafOps { ) .await?; - // Filter out any report shares in the incoming message that wouldn't get written out: we - // don't expect to see those in the datastore. - let incoming_report_share_data: Vec<_> = incoming_report_share_data - .iter() - .filter(|report_share_data| { - !report_share_data.existing_report_aggregation - && !report_share_data.conflicting_aggregate_share - }) - .collect(); - if existing_report_aggregations.len() != incoming_report_share_data.len() { return Ok(false); } @@ -1199,12 +1180,15 @@ impl VdafOps { if incoming_report_share_data .iter() .zip(existing_report_aggregations) - .any(|(incoming_report_share, existing_report_share)| { - !existing_report_share - .report_metadata() - .eq(incoming_report_share.report_share.metadata()) - || !existing_report_share.eq(&incoming_report_share.report_aggregation) - }) + .any( + |(incoming_report_share_data, existing_report_aggregation)| { + !existing_report_aggregation + .report_metadata() + .eq(incoming_report_share_data.report_share.metadata()) + || !existing_report_aggregation + .eq(&incoming_report_share_data.report_aggregation) + }, + ) { return Ok(false); } @@ -1370,12 +1354,12 @@ impl VdafOps { *report_share.metadata().id(), *report_share.metadata().time(), ord.try_into()?, - ReportAggregationState::::Waiting( - prep_state, - PrepareMessageOrShare::Helper(prep_share), - ), + Some(PrepareStep::new( + *report_share.metadata().id(), + PrepareStepResult::Continued(encoded_prep_share), + )), + ReportAggregationState::::Waiting(prep_state, None), ), - PrepareStepResult::Continued(encoded_prep_share), ) } @@ -1387,9 +1371,12 @@ impl VdafOps { *report_share.metadata().id(), *report_share.metadata().time(), ord.try_into()?, + Some(PrepareStep::new( + *report_share.metadata().id(), + PrepareStepResult::Failed(err), + )), ReportAggregationState::::Failed(err), ), - PrepareStepResult::Failed(err), ), }); } @@ -1428,7 +1415,7 @@ impl VdafOps { AggregationJobRound::from(0), )); - let prep_steps = datastore + Ok(datastore .run_tx_with_name("aggregate_init", |tx| { let (vdaf, task, req, aggregation_job, mut report_share_data) = ( vdaf.clone(), @@ -1439,14 +1426,14 @@ impl VdafOps { ); Box::pin(async move { - for mut share_data in report_share_data.iter_mut() { + for mut report_share_data in &mut report_share_data { // Verify that we haven't seen this report ID and aggregation parameter // before in another aggregation job, and that the report isn't for a batch // interval that has already started collection. let (report_aggregation_exists, conflicting_aggregate_share_jobs) = try_join!( tx.check_other_report_aggregation_exists::( task.id(), - share_data.report_share.metadata().id(), + report_share_data.report_share.metadata().id(), aggregation_job.aggregation_parameter(), aggregation_job.id(), ), @@ -1455,12 +1442,31 @@ impl VdafOps { &vdaf, task.id(), req.batch_selector().batch_identifier(), - share_data.report_share.metadata() + report_share_data.report_share.metadata() ), )?; - share_data.existing_report_aggregation = report_aggregation_exists; - share_data.conflicting_aggregate_share = !conflicting_aggregate_share_jobs.is_empty(); + if report_aggregation_exists { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportReplayed)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::ReportReplayed)) + )); + } else if !conflicting_aggregate_share_jobs.is_empty() { + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::BatchCollected)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected)) + )); + } } // Write aggregation job. @@ -1493,77 +1499,52 @@ impl VdafOps { // Construct a response and write any new report shares and report aggregations // as we go. - let mut accumulator = Accumulator::::new( - Arc::clone(&task), - batch_aggregation_shard_count, - aggregation_job.aggregation_parameter().clone(), - ); - - let mut prep_steps = Vec::new(); - for report_share_data in report_share_data - { - if report_share_data.existing_report_aggregation { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::ReportReplayed), - )); - continue; - } - if report_share_data.conflicting_aggregate_share { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), - )); - continue; - } + if !replayed_request { + let mut accumulator = Accumulator::::new( + Arc::clone(&task), + batch_aggregation_shard_count, + aggregation_job.aggregation_parameter().clone(), + ); - if !replayed_request { + for report_share_data in &mut report_share_data + { // Write client report & report aggregation. - if let Err(error) = tx.put_report_share( - task.id(), - &report_share_data.report_share - ).await { - match error { + if let Err(err) = tx.put_report_share(task.id(), &report_share_data.report_share).await { + match err { datastore::Error::MutationTargetAlreadyExists => { - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - PrepareStepResult::Failed(ReportShareError::ReportReplayed), - )); - continue; - } - e => return Err(e), + report_share_data.report_aggregation = + report_share_data.report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportReplayed)) + .with_last_prep_step(Some(PrepareStep::new( + *report_share_data.report_share.metadata().id(), + PrepareStepResult::Failed(ReportShareError::ReportReplayed)) + )); + }, + err => return Err(err), } } tx.put_report_aggregation(&report_share_data.report_aggregation).await?; - } - if let ReportAggregationState::::Finished(output_share) = - report_share_data.report_aggregation.state() - { - accumulator.update( - aggregation_job.partial_batch_identifier(), - report_share_data.report_share.metadata().id(), - report_share_data.report_share.metadata().time(), - output_share, - )?; + if let ReportAggregationState::Finished(output_share) = report_share_data.report_aggregation.state() + { + accumulator.update( + aggregation_job.partial_batch_identifier(), + report_share_data.report_share.metadata().id(), + report_share_data.report_share.metadata().time(), + output_share, + )?; + } } - prep_steps.push(PrepareStep::new( - *report_share_data.report_share.metadata().id(), - report_share_data.prep_result.clone(), - )); - } - - if !replayed_request { accumulator.flush_to_datastore(tx, &vdaf).await?; } - Ok(prep_steps) + + Ok(Self::aggregation_job_resp_for(report_share_data.into_iter().map(|data| data.report_aggregation))) }) }) - .await?; - - // Construct response and return. - Ok(AggregationJobResp::new(prep_steps)) + .await?) } async fn handle_aggregate_continue_generic< @@ -1662,9 +1643,7 @@ impl VdafOps { } } } - return Self::replay_aggregation_job_round::( - report_aggregations, - ); + return Ok(Self::aggregation_job_resp_for(report_aggregations)); } else if helper_aggregation_job.round().increment() != leader_aggregation_job.round() { @@ -1685,14 +1664,14 @@ impl VdafOps { // compute the next round of prepare messages and state. Self::step_aggregation_job( tx, - &task, - &vdaf, + task, + vdaf, batch_aggregation_shard_count, helper_aggregation_job, report_aggregations, - &leader_aggregation_job, + leader_aggregation_job, request_hash, - &aggregate_step_failure_counter, + aggregate_step_failure_counter, ) .await }) diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index ae6b5e87b..72a433c92 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -1,13 +1,11 @@ //! Implements portions of aggregation job continuation for the helper. use crate::aggregator::{accumulator::Accumulator, Error, VdafOps}; +use futures::future::try_join_all; use janus_aggregator_core::{ datastore::{ self, - models::{ - AggregationJob, AggregationJobState, PrepareMessageOrShare, ReportAggregation, - ReportAggregationState, - }, + models::{AggregationJob, AggregationJobState, ReportAggregation, ReportAggregationState}, Transaction, }, query_type::AccumulableQueryType, @@ -23,6 +21,7 @@ use prio::{ vdaf::{self, PrepareTransition}, }; use std::{io::Cursor, sync::Arc}; +use tokio::try_join; use tracing::info; impl VdafOps { @@ -31,14 +30,14 @@ impl VdafOps { /// `leader_aggregation_job`. pub(super) async fn step_aggregation_job( tx: &Transaction<'_, C>, - task: &Arc, - vdaf: &Arc, + task: Arc, + vdaf: Arc, batch_aggregation_shard_count: u64, helper_aggregation_job: AggregationJob, - report_aggregations: Vec>, - leader_aggregation_job: &Arc, + mut report_aggregations: Vec>, + leader_aggregation_job: Arc, request_hash: [u8; 32], - aggregate_step_failure_counter: &Counter, + aggregate_step_failure_counter: Counter, ) -> Result where C: Clock, @@ -47,11 +46,9 @@ impl VdafOps { for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, { // Handle each transition in the request. - let mut report_aggregations = report_aggregations.into_iter(); - let (mut saw_continue, mut saw_finish) = (false, false); - let mut response_prep_steps = Vec::new(); + let mut report_aggregations_iter = report_aggregations.iter_mut(); let mut accumulator = Accumulator::::new( - Arc::clone(task), + Arc::clone(&task), batch_aggregation_shard_count, helper_aggregation_job.aggregation_parameter().clone(), ); @@ -60,7 +57,7 @@ impl VdafOps { // Match preparation step received from leader to stored report aggregation, and extract // the stored preparation step. let report_aggregation = loop { - let report_agg = report_aggregations.next().ok_or_else(|| { + let report_agg = report_aggregations_iter.next().ok_or_else(|| { datastore::Error::User( Error::UnrecognizedMessage( Some(*task.id()), @@ -73,10 +70,12 @@ impl VdafOps { // This report was omitted by the leader because of a prior failure. Note that // the report was dropped (if it's not already in an error state) and continue. if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { - tx.update_report_aggregation(&report_agg.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )) - .await?; + *report_agg = report_agg + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportDropped, + )) + .with_last_prep_step(None); } continue; } @@ -86,20 +85,21 @@ impl VdafOps { // Make sure this report isn't in an interval that has already started collection. let conflicting_aggregate_share_jobs = tx .get_aggregate_share_jobs_including_time::( - vdaf, + &vdaf, task.id(), report_aggregation.time(), ) .await?; if !conflicting_aggregate_share_jobs.is_empty() { - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::BatchCollected), - )); - tx.update_report_aggregation(&report_aggregation.with_state( - ReportAggregationState::Failed(ReportShareError::BatchCollected), - )) - .await?; + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::BatchCollected, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected), + ))); continue; } @@ -133,33 +133,32 @@ impl VdafOps { } }; - // Compute the next transition, prepare to respond & update DB. - let next_state = match vdaf.prepare_step(prep_state.clone(), prep_msg) { + // Compute the next transition. + match vdaf.prepare_step(prep_state.clone(), prep_msg) { Ok(PrepareTransition::Continue(prep_state, prep_share)) => { - saw_continue = true; - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Continued(prep_share.get_encoded()), - )); - ReportAggregationState::Waiting( - prep_state, - PrepareMessageOrShare::Helper(prep_share), - ) + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Waiting(prep_state, None)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Continued(prep_share.get_encoded()), + ))); } Ok(PrepareTransition::Finish(output_share)) => { - saw_finish = true; accumulator.update( helper_aggregation_job.partial_batch_identifier(), prep_step.report_id(), report_aggregation.time(), &output_share, )?; - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Finished, - )); - ReportAggregationState::Finished(output_share) + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Finished(output_share)) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Finished, + ))); } Err(error) => { @@ -174,29 +173,44 @@ impl VdafOps { 1, &[KeyValue::new("type", "prepare_step_failure")], ); - response_prep_steps.push(PrepareStep::new( - *prep_step.report_id(), - PrepareStepResult::Failed(ReportShareError::VdafPrepError), - )); - ReportAggregationState::Failed(ReportShareError::VdafPrepError) + *report_aggregation = report_aggregation + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::VdafPrepError, + )) + .with_last_prep_step(Some(PrepareStep::new( + *prep_step.report_id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError), + ))) } }; - - tx.update_report_aggregation(&report_aggregation.with_state(next_state)) - .await?; } - for report_agg in report_aggregations { + for report_agg in report_aggregations_iter { // This report was omitted by the leader because of a prior failure. Note that the // report was dropped (if it's not already in an error state) and continue. if matches!(report_agg.state(), ReportAggregationState::Waiting(_, _)) { - tx.update_report_aggregation(&report_agg.with_state( - ReportAggregationState::Failed(ReportShareError::ReportDropped), - )) - .await?; + *report_agg = report_agg + .clone() + .with_state(ReportAggregationState::Failed( + ReportShareError::ReportDropped, + )) + .with_last_prep_step(None); } } + let saw_continue = report_aggregations.iter().any(|report_agg| { + matches!( + report_agg.last_prep_step().map(PrepareStep::result), + Some(PrepareStepResult::Continued(_)) + ) + }); + let saw_finish = report_aggregations.iter().any(|report_agg| { + matches!( + report_agg.last_prep_step().map(PrepareStep::result), + Some(PrepareStepResult::Finished) + ) + }); let helper_aggregation_job = helper_aggregation_job // Advance the job to the leader's round .with_round(leader_aggregation_job.round()) @@ -215,53 +229,33 @@ impl VdafOps { } }) .with_last_continue_request_hash(request_hash); - tx.update_aggregation_job(&helper_aggregation_job).await?; - accumulator.flush_to_datastore(tx, vdaf).await?; + try_join!( + tx.update_aggregation_job(&helper_aggregation_job), + try_join_all( + report_aggregations + .iter() + .map(|ra| tx.update_report_aggregation(ra)) + ), + accumulator.flush_to_datastore(tx, &vdaf), + )?; - Ok(AggregationJobResp::new(response_prep_steps)) + Ok(Self::aggregation_job_resp_for(report_aggregations)) } - /// Fetch previously-computed prepare message shares and replay them back to the leader. - pub(super) fn replay_aggregation_job_round( - report_aggregations: Vec>, - ) -> Result - where - C: Clock, - Q: AccumulableQueryType, - A: vdaf::Aggregator + 'static + Send + Sync, - for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, - { - let response_prep_steps = report_aggregations - .iter() - .map(|report_aggregation| { - let prepare_step_state = match report_aggregation.state() { - ReportAggregationState::Waiting(_, prep_msg) => PrepareStepResult::Continued( - prep_msg.get_helper_prepare_share()?.get_encoded(), - ), - ReportAggregationState::Finished(_) => PrepareStepResult::Finished, - ReportAggregationState::Failed(report_share_error) => { - PrepareStepResult::Failed(*report_share_error) - } - state => { - return Err(datastore::Error::User( - Error::Internal(format!( - "report aggregation {} unexpectedly in state {state:?}", - report_aggregation.report_id() - )) - .into(), - )); - } - }; - - Ok(PrepareStep::new( - *report_aggregation.report_id(), - prepare_step_state, - )) - }) - .collect::>()?; - - Ok(AggregationJobResp::new(response_prep_steps)) + /// Constructs an AggregationJobResp from a given set of Helper report aggregations. + pub(super) fn aggregation_job_resp_for< + const SEED_SIZE: usize, + A: vdaf::Aggregator, + >( + report_aggregations: impl IntoIterator>, + ) -> AggregationJobResp { + AggregationJobResp::new( + report_aggregations + .into_iter() + .filter_map(|ra| ra.last_prep_step().cloned()) + .collect(), + ) } } @@ -380,8 +374,7 @@ mod tests { use janus_aggregator_core::{ datastore::{ models::{ - AggregationJob, AggregationJobState, PrepareMessageOrShare, ReportAggregation, - ReportAggregationState, + AggregationJob, AggregationJobState, ReportAggregation, ReportAggregationState, }, test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, @@ -456,17 +449,15 @@ mod tests { .await .unwrap(); - let (prep_state, prep_share) = report.1.helper_prep_state(0); + let (prep_state, _) = report.1.helper_prep_state(0); tx.put_report_aggregation::<0, dummy_vdaf::Vdaf>(&ReportAggregation::new( *task.id(), aggregation_job_id, *report.0.metadata().id(), *report.0.metadata().time(), 0, - ReportAggregationState::Waiting( - *prep_state, - PrepareMessageOrShare::Helper(*prep_share), - ), + None, + ReportAggregationState::Waiting(*prep_state, None), )) .await .unwrap(); diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index a61779595..1c4913a59 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -415,6 +415,7 @@ impl AggregationJobCreator { *report_id, *time, ord.try_into()?, + None, ReportAggregationState::Start, )); } @@ -571,6 +572,7 @@ impl AggregationJobCreator { report_id, client_timestamp, ord.try_into()?, + None, ReportAggregationState::Start, )) }) @@ -732,6 +734,7 @@ impl AggregationJobCreator { *report_id, *time, ord.try_into()?, + None, ReportAggregationState::Start, )); } diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 66169cc39..e0be308a5 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -9,7 +9,7 @@ use janus_aggregator_core::{ self, models::{ AcquiredAggregationJob, AggregationJob, AggregationJobState, LeaderStoredReport, Lease, - PrepareMessageOrShare, ReportAggregation, ReportAggregationState, + ReportAggregation, ReportAggregationState, }, Datastore, }, @@ -437,9 +437,25 @@ impl AggregationJobDriver { if let ReportAggregationState::Waiting(prep_state, prep_msg) = report_aggregation.state() { - let prep_msg = prep_msg - .get_leader_prepare_message() - .context("report aggregation missing prepare message")?; + let prep_msg = match prep_msg.as_ref() { + Some(prep_msg) => prep_msg, + None => { + // This error indicates programmer/system error (i.e. it cannot possibly be + // the fault of our co-aggregator). We still record this failure against a + // single report, rather than failing the entire request, to minimize impact + // if we ever encounter this bug. + info!(report_id = %report_aggregation.report_id(), "Report aggregation is missing prepare message"); + self.aggregate_step_failure_counter.add( + &Context::current(), + 1, + &[KeyValue::new("type", "missing_prepare_message")], + ); + report_aggregations_to_write.push(report_aggregation.with_state( + ReportAggregationState::Failed(ReportShareError::VdafPrepError), + )); + continue; + } + }; // Step our own state. let leader_transition = match vdaf @@ -572,10 +588,9 @@ impl AggregationJobDriver { ) }); match prep_msg { - Ok(prep_msg) => ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + Ok(prep_msg) => { + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)) + } Err(error) => { info!(report_id = %report_aggregation.report_id(), ?error, "Couldn't compute prepare message"); self.aggregate_step_failure_counter.add( @@ -859,7 +874,7 @@ mod tests { datastore::{ models::{ AggregationJob, AggregationJobState, BatchAggregation, LeaderStoredReport, - PrepareMessageOrShare, ReportAggregation, ReportAggregationState, + ReportAggregation, ReportAggregationState, }, test_util::ephemeral_datastore, }, @@ -980,6 +995,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, ), ) @@ -1090,6 +1106,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); @@ -1223,6 +1240,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, )) .await?; @@ -1235,6 +1253,7 @@ mod tests { *repeated_extension_report.metadata().id(), *repeated_extension_report.metadata().time(), 1, + None, ReportAggregationState::Start, )) .await?; @@ -1345,10 +1364,8 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), ); let want_repeated_extension_report_aggregation = ReportAggregation::::new( @@ -1357,6 +1374,7 @@ mod tests { *repeated_extension_report.metadata().id(), *repeated_extension_report.metadata().time(), 1, + None, ReportAggregationState::Failed(ReportShareError::UnrecognizedMessage), ); @@ -1501,6 +1519,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, )) .await?; @@ -1609,9 +1628,10 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Waiting( transcript.leader_prep_state(0).clone(), - PrepareMessageOrShare::Leader(transcript.prepare_messages[0].clone()), + Some(transcript.prepare_messages[0].clone()), ), ); @@ -1739,10 +1759,8 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), )) .await?; @@ -1844,6 +1862,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(transcript.output_share(Role::Leader).clone()), ); let batch_interval_start = report @@ -2031,10 +2050,8 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, - ReportAggregationState::Waiting( - leader_prep_state, - PrepareMessageOrShare::Leader(prep_msg), - ), + None, + ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), )) .await?; @@ -2137,6 +2154,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(leader_output_share.clone()), ); let want_batch_aggregations = Vec::from([BatchAggregation::< @@ -2270,6 +2288,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, ); @@ -2472,6 +2491,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Start, ), ) diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index bf2cd6708..21ac91332 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -570,6 +570,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(OutputShare()), )) .await?; @@ -698,6 +699,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), 0, + None, ReportAggregationState::Finished(OutputShare()), )) .await?; diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 44c4fb2e4..242a8bcc2 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -196,6 +196,7 @@ async fn setup_fixed_size_current_batch_collection_job_test_case( *report.metadata().id(), time, ord, + None, ReportAggregationState::Finished(dummy_vdaf::OutputShare()), )) .await diff --git a/aggregator/src/aggregator/garbage_collector.rs b/aggregator/src/aggregator/garbage_collector.rs index db8808f35..b90e5fb26 100644 --- a/aggregator/src/aggregator/garbage_collector.rs +++ b/aggregator/src/aggregator/garbage_collector.rs @@ -159,6 +159,7 @@ mod tests { *report.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -320,6 +321,7 @@ mod tests { *report_share.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -477,6 +479,7 @@ mod tests { *report.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -642,6 +645,7 @@ mod tests { *report_share.metadata().id(), client_timestamp, 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation) diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index b2e114912..89028f4e0 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -549,8 +549,7 @@ mod tests { datastore::{ models::{ AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation, - CollectionJobState, PrepareMessageOrShare, ReportAggregation, - ReportAggregationState, + CollectionJobState, ReportAggregation, ReportAggregationState, }, test_util::ephemeral_datastore, }, @@ -1563,6 +1562,7 @@ mod tests { *report_share_4.metadata().id(), *report_share_4.metadata().time(), 0, + None, ReportAggregationState::Start, )) .await @@ -1593,6 +1593,7 @@ mod tests { *report_share_8.metadata().id(), *report_share_8.metadata().time(), 0, + None, ReportAggregationState::Start, )) .await @@ -2070,7 +2071,7 @@ mod tests { report_metadata_0.id(), &0, ); - let (prep_state_0, prep_share_0) = transcript_0.helper_prep_state(0); + let (prep_state_0, _) = transcript_0.helper_prep_state(0); let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( *task.id(), @@ -2097,7 +2098,7 @@ mod tests { &0, ); - let (prep_state_1, prep_share_1) = transcript_1.helper_prep_state(0); + let (prep_state_1, _) = transcript_1.helper_prep_state(0); let report_share_1 = generate_helper_report_share::( *task.id(), report_metadata_1.clone(), @@ -2125,7 +2126,7 @@ mod tests { report_metadata_2.id(), &0, ); - let (prep_state_2, prep_share_2) = transcript_2.helper_prep_state(0); + let (prep_state_2, _) = transcript_2.helper_prep_state(0); let prep_msg_2 = transcript_2.prepare_messages[0].clone(); let report_share_2 = generate_helper_report_share::( *task.id(), @@ -2144,11 +2145,6 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_share_0, prep_share_1, prep_share_2) = ( - prep_share_0.clone(), - prep_share_1.clone(), - prep_share_2.clone(), - ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), @@ -2190,10 +2186,8 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - ReportAggregationState::Waiting( - prep_state_0, - PrepareMessageOrShare::Helper(prep_share_0), - ), + None, + ReportAggregationState::Waiting(prep_state_0, None), ), ) .await?; @@ -2204,10 +2198,8 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, - ReportAggregationState::Waiting( - prep_state_1, - PrepareMessageOrShare::Helper(prep_share_1), - ), + None, + ReportAggregationState::Waiting(prep_state_1, None), ), ) .await?; @@ -2218,10 +2210,8 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, - ReportAggregationState::Waiting( - prep_state_2, - PrepareMessageOrShare::Helper(prep_share_2), - ), + None, + ReportAggregationState::Waiting(prep_state_2, None), ), ) .await?; @@ -2329,6 +2319,10 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, + Some(PrepareStep::new( + *report_metadata_0.id(), + PrepareStepResult::Finished + )), ReportAggregationState::Finished( transcript_0.output_share(Role::Helper).clone() ), @@ -2339,6 +2333,7 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, + None, ReportAggregationState::Failed(ReportShareError::ReportDropped), ), ReportAggregation::new( @@ -2347,6 +2342,10 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, + Some(PrepareStep::new( + *report_metadata_2.id(), + PrepareStepResult::Failed(ReportShareError::BatchCollected) + )), ReportAggregationState::Failed(ReportShareError::BatchCollected), ) ]) @@ -2395,7 +2394,7 @@ mod tests { report_metadata_0.id(), &0, ); - let (prep_state_0, prep_share_0) = transcript_0.helper_prep_state(0); + let (prep_state_0, _) = transcript_0.helper_prep_state(0); let out_share_0 = transcript_0.output_share(Role::Helper); let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( @@ -2423,7 +2422,7 @@ mod tests { report_metadata_1.id(), &0, ); - let (prep_state_1, prep_share_1) = transcript_1.helper_prep_state(0); + let (prep_state_1, _) = transcript_1.helper_prep_state(0); let out_share_1 = transcript_1.output_share(Role::Helper); let prep_msg_1 = transcript_1.prepare_messages[0].clone(); let report_share_1 = generate_helper_report_share::( @@ -2450,7 +2449,7 @@ mod tests { report_metadata_2.id(), &0, ); - let (prep_state_2, prep_share_2) = transcript_2.helper_prep_state(0); + let (prep_state_2, _) = transcript_2.helper_prep_state(0); let out_share_2 = transcript_2.output_share(Role::Helper); let prep_msg_2 = transcript_2.prepare_messages[0].clone(); let report_share_2 = generate_helper_report_share::( @@ -2470,11 +2469,6 @@ mod tests { report_share_1.clone(), report_share_2.clone(), ); - let (prep_share_0, prep_share_1, prep_share_2) = ( - prep_share_0.clone(), - prep_share_1.clone(), - prep_share_2.clone(), - ); let (prep_state_0, prep_state_1, prep_state_2) = ( prep_state_0.clone(), prep_state_1.clone(), @@ -2518,10 +2512,8 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - ReportAggregationState::Waiting( - prep_state_0, - PrepareMessageOrShare::Helper(prep_share_0), - ), + None, + ReportAggregationState::Waiting(prep_state_0, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2533,10 +2525,8 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, - ReportAggregationState::Waiting( - prep_state_1, - PrepareMessageOrShare::Helper(prep_share_1), - ), + None, + ReportAggregationState::Waiting(prep_state_1, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2548,10 +2538,8 @@ mod tests { *report_metadata_2.id(), *report_metadata_2.time(), 2, - ReportAggregationState::Waiting( - prep_state_2, - PrepareMessageOrShare::Helper(prep_share_2), - ), + None, + ReportAggregationState::Waiting(prep_state_2, None), )) .await?; @@ -2698,7 +2686,7 @@ mod tests { report_metadata_3.id(), &0, ); - let (prep_state_3, prep_share_3) = transcript_3.helper_prep_state(0); + let (prep_state_3, _) = transcript_3.helper_prep_state(0); let out_share_3 = transcript_3.output_share(Role::Helper); let prep_msg_3 = transcript_3.prepare_messages[0].clone(); let report_share_3 = generate_helper_report_share::( @@ -2725,7 +2713,7 @@ mod tests { report_metadata_4.id(), &0, ); - let (prep_state_4, prep_share_4) = transcript_4.helper_prep_state(0); + let (prep_state_4, _) = transcript_4.helper_prep_state(0); let out_share_4 = transcript_4.output_share(Role::Helper); let prep_msg_4 = transcript_4.prepare_messages[0].clone(); let report_share_4 = generate_helper_report_share::( @@ -2752,7 +2740,7 @@ mod tests { report_metadata_5.id(), &0, ); - let (prep_state_5, prep_share_5) = transcript_5.helper_prep_state(0); + let (prep_state_5, _) = transcript_5.helper_prep_state(0); let out_share_5 = transcript_5.output_share(Role::Helper); let prep_msg_5 = transcript_5.prepare_messages[0].clone(); let report_share_5 = generate_helper_report_share::( @@ -2772,11 +2760,6 @@ mod tests { report_share_4.clone(), report_share_5.clone(), ); - let (prep_share_3, prep_share_4, prep_share_5) = ( - prep_share_3.clone(), - prep_share_4.clone(), - prep_share_5.clone(), - ); let (prep_state_3, prep_state_4, prep_state_5) = ( prep_state_3.clone(), prep_state_4.clone(), @@ -2818,10 +2801,8 @@ mod tests { *report_metadata_3.id(), *report_metadata_3.time(), 3, - ReportAggregationState::Waiting( - prep_state_3, - PrepareMessageOrShare::Helper(prep_share_3), - ), + None, + ReportAggregationState::Waiting(prep_state_3, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2833,10 +2814,8 @@ mod tests { *report_metadata_4.id(), *report_metadata_4.time(), 4, - ReportAggregationState::Waiting( - prep_state_4, - PrepareMessageOrShare::Helper(prep_share_4), - ), + None, + ReportAggregationState::Waiting(prep_state_4, None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -2848,10 +2827,8 @@ mod tests { *report_metadata_5.id(), *report_metadata_5.time(), 5, - ReportAggregationState::Waiting( - prep_state_5, - PrepareMessageOrShare::Helper(prep_share_5), - ), + None, + ReportAggregationState::Waiting(prep_state_5, None), )) .await?; @@ -3063,10 +3040,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -3163,10 +3138,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -3244,6 +3217,10 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, + Some(PrepareStep::new( + *report_metadata.id(), + PrepareStepResult::Failed(ReportShareError::VdafPrepError) + )), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ) ); @@ -3310,10 +3287,8 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -3433,10 +3408,8 @@ mod tests { *report_metadata_0.id(), *report_metadata_0.time(), 0, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await?; tx.put_report_aggregation(&ReportAggregation::< @@ -3448,10 +3421,8 @@ mod tests { *report_metadata_1.id(), *report_metadata_1.time(), 1, - ReportAggregationState::Waiting( - dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Helper(()), - ), + None, + ReportAggregationState::Waiting(dummy_vdaf::PrepareState::default(), None), )) .await }) @@ -3551,6 +3522,7 @@ mod tests { *report_metadata.id(), *report_metadata.time(), 0, + None, ReportAggregationState::Invalid, )) .await diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index a9a113dcb..737922810 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -762,6 +762,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), ord.try_into().unwrap(), + None, ReportAggregationState::Start, ), ) diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 37732cd5c..03c237d98 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -3,8 +3,8 @@ use self::models::{ AcquiredAggregationJob, AcquiredCollectionJob, AggregateShareJob, AggregationJob, AggregatorRole, BatchAggregation, CollectionJob, CollectionJobState, CollectionJobStateCode, - LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, PrepareMessageOrShare, - ReportAggregation, ReportAggregationState, ReportAggregationStateCode, SqlInterval, + LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, ReportAggregation, + ReportAggregationState, ReportAggregationStateCode, SqlInterval, }; #[cfg(feature = "test-util")] use crate::VdafHasAggregationParameter; @@ -24,7 +24,8 @@ use janus_core::{ use janus_messages::{ query_type::{QueryType, TimeInterval}, AggregationJobId, BatchId, CollectionJobId, Duration, Extension, HpkeCiphertext, HpkeConfig, - Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, Time, + Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, TaskId, + Time, }; use opentelemetry::{ metrics::{Counter, Histogram}, @@ -1734,10 +1735,12 @@ impl Transaction<'_, C> { { let stmt = self .prepare_cached( - "SELECT client_reports.report_id, client_reports.client_timestamp, - report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + "SELECT + client_reports.report_id, client_reports.client_timestamp, + report_aggregations.ord, report_aggregations.state, + report_aggregations.prep_state, report_aggregations.prep_msg, + report_aggregations.out_share, report_aggregations.error_code, + report_aggregations.last_prep_step, aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1787,10 +1790,12 @@ impl Transaction<'_, C> { { let stmt = self .prepare_cached( - "SELECT client_reports.report_id, client_reports.client_timestamp, - report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, - report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + "SELECT + client_reports.report_id, client_reports.client_timestamp, + report_aggregations.ord, report_aggregations.state, + report_aggregations.prep_state, report_aggregations.prep_msg, + report_aggregations.out_share, report_aggregations.error_code, + report_aggregations.last_prep_step, aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1844,7 +1849,8 @@ impl Transaction<'_, C> { client_reports.client_timestamp, report_aggregations.ord, report_aggregations.state, report_aggregations.prep_state, report_aggregations.prep_msg, report_aggregations.out_share, - report_aggregations.error_code, aggregation_jobs.aggregation_param + report_aggregations.error_code, report_aggregations.last_prep_step, + aggregation_jobs.aggregation_param FROM report_aggregations JOIN client_reports ON client_reports.id = report_aggregations.client_report_id JOIN aggregation_jobs @@ -1886,6 +1892,7 @@ impl Transaction<'_, C> { let prep_msg_bytes: Option> = row.get("prep_msg"); let out_share_bytes: Option> = row.get("out_share"); let error_code: Option = row.get("error_code"); + let last_prep_step_bytes: Option> = row.get("last_prep_step"); let aggregation_param_bytes = row.get("aggregation_param"); let error_code = match error_code { @@ -1900,8 +1907,13 @@ impl Transaction<'_, C> { None => None, }; + let last_prep_step = last_prep_step_bytes + .map(|bytes| PrepareStep::get_decoded(&bytes)) + .transpose()?; + let agg_state = match state { ReportAggregationStateCode::Start => ReportAggregationState::Start, + ReportAggregationStateCode::Waiting => { let agg_index = role.index().ok_or_else(|| { Error::User(anyhow!("unexpected role: {}", role.as_str()).into()) @@ -1915,23 +1927,13 @@ impl Transaction<'_, C> { ) })?, )?; - let prep_msg_bytes = prep_msg_bytes.ok_or_else(|| { - Error::DbState( - "report aggregation in state WAITING but prep_msg is NULL".to_string(), - ) - })?; - let prep_msg = match role { - Role::Leader => PrepareMessageOrShare::Leader( - A::PrepareMessage::get_decoded_with_param(&prep_state, &prep_msg_bytes)?, - ), - Role::Helper => PrepareMessageOrShare::Helper( - A::PrepareShare::get_decoded_with_param(&prep_state, &prep_msg_bytes)?, - ), - _ => return Err(Error::DbState(format!("unexpected role {role}"))), - }; + let prep_msg = prep_msg_bytes + .map(|bytes| A::PrepareMessage::get_decoded_with_param(&prep_state, &bytes)) + .transpose()?; ReportAggregationState::Waiting(prep_state, prep_msg) } + ReportAggregationStateCode::Finished => { let aggregation_param = A::AggregationParam::get_decoded(aggregation_param_bytes)?; ReportAggregationState::Finished(A::OutputShare::get_decoded_with_param( @@ -1944,6 +1946,7 @@ impl Transaction<'_, C> { })?, )?) } + ReportAggregationStateCode::Failed => { ReportAggregationState::Failed(error_code.ok_or_else(|| { Error::DbState( @@ -1951,6 +1954,7 @@ impl Transaction<'_, C> { ) })?) } + ReportAggregationStateCode::Invalid => ReportAggregationState::Invalid, }; @@ -1960,6 +1964,7 @@ impl Transaction<'_, C> { *report_id, time, ord, + last_prep_step, agg_state, )) } @@ -1977,17 +1982,20 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); + let encoded_last_prep_step = report_aggregation + .last_prep_step() + .map(PrepareStep::get_encoded); let stmt = self .prepare_cached( "INSERT INTO report_aggregations - (aggregation_job_id, client_report_id, ord, state, prep_state, prep_msg, out_share, - error_code) + (aggregation_job_id, client_report_id, ord, state, prep_state, prep_msg, + out_share, error_code, last_prep_step) VALUES ((SELECT id FROM aggregation_jobs WHERE aggregation_job_id = $1), (SELECT id FROM client_reports WHERE task_id = (SELECT id FROM tasks WHERE task_id = $2) AND report_id = $3), - $4, $5, $6, $7, $8, $9)", + $4, $5, $6, $7, $8, $9, $10)", ) .await?; self.execute( @@ -2003,6 +2011,7 @@ impl Transaction<'_, C> { /* prep_msg */ &encoded_state_values.prep_msg, /* out_share */ &encoded_state_values.output_share, /* error_code */ &encoded_state_values.report_share_err, + /* last_prep_step */ &encoded_last_prep_step, ], ) .await?; @@ -2021,16 +2030,20 @@ impl Transaction<'_, C> { A::PrepareState: Encode, { let encoded_state_values = report_aggregation.state().encoded_values_from_state(); + let encoded_last_prep_step = report_aggregation + .last_prep_step() + .map(PrepareStep::get_encoded); let stmt = self .prepare_cached( - "UPDATE report_aggregations SET ord = $1, state = $2, prep_state = $3, - prep_msg = $4, out_share = $5, error_code = $6 + "UPDATE report_aggregations SET + ord = $1, state = $2, prep_state = $3, prep_msg = $4, out_share = $5, + error_code = $6, last_prep_step = $7 WHERE aggregation_job_id = (SELECT id FROM aggregation_jobs WHERE - aggregation_job_id = $7) + aggregation_job_id = $8) AND client_report_id = (SELECT id FROM client_reports - WHERE task_id = (SELECT id FROM tasks WHERE task_id = $8) - AND report_id = $9)", + WHERE task_id = (SELECT id FROM tasks WHERE task_id = $9) + AND report_id = $10)", ) .await?; check_single_row_mutation( @@ -2043,6 +2056,7 @@ impl Transaction<'_, C> { /* prep_msg */ &encoded_state_values.prep_msg, /* out_share */ &encoded_state_values.output_share, /* error_code */ &encoded_state_values.report_share_err, + /* last_prep_step */ &encoded_last_prep_step, /* aggregation_job_id */ &report_aggregation.aggregation_job_id().as_ref(), /* task_id */ &report_aggregation.task_id().as_ref(), @@ -3860,8 +3874,8 @@ pub mod models { use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, AggregationJobId, AggregationJobRound, BatchId, CollectionJobId, Duration, Extension, - HpkeCiphertext, Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShareError, - Role, TaskId, Time, + HpkeCiphertext, Interval, PrepareStep, ReportId, ReportIdChecksum, ReportMetadata, + ReportShareError, Role, TaskId, Time, }; use postgres_protocol::types::{ range_from_sql, range_to_sql, timestamp_from_sql, timestamp_to_sql, Range, RangeBound, @@ -4400,6 +4414,7 @@ pub mod models { report_id: ReportId, time: Time, ord: u64, + last_prep_step: Option, state: ReportAggregationState, } @@ -4411,6 +4426,7 @@ pub mod models { report_id: ReportId, time: Time, ord: u64, + last_prep_step: Option, state: ReportAggregationState, ) -> Self { Self { @@ -4419,6 +4435,7 @@ pub mod models { report_id, time, ord, + last_prep_step, state, } } @@ -4453,12 +4470,26 @@ pub mod models { self.ord } + /// Returns the last preparation step returned by the Helper, if any. + pub fn last_prep_step(&self) -> Option<&PrepareStep> { + self.last_prep_step.as_ref() + } + + /// Returns a new [`ReportAggregation`] corresponding to this report aggregation updated to + /// have the given state. + pub fn with_last_prep_step(self, last_prep_step: Option) -> Self { + Self { + last_prep_step, + ..self + } + } + /// Returns the state of the report aggregation. pub fn state(&self) -> &ReportAggregationState { &self.state } - /// Returns a new [`ReportAggregation`] corresponding to this aggregation job updated to + /// Returns a new [`ReportAggregation`] corresponding to this report aggregation updated to /// have the given state. pub fn with_state(self, state: ReportAggregationState) -> Self { Self { state, ..self } @@ -4479,6 +4510,7 @@ pub mod models { && self.report_id == other.report_id && self.time == other.time && self.ord == other.ord + && self.last_prep_step == other.last_prep_step && self.state == other.state } } @@ -4493,76 +4525,6 @@ pub mod models { { } - /// Represents either a preprocessed VDAF preparation message (for the leader) or a VDAF - /// preparation message share (for the helper). - #[derive(Clone, Derivative)] - #[derivative(Debug)] - pub enum PrepareMessageOrShare> { - /// The helper stores a prepare message share - Helper(#[derivative(Debug = "ignore")] A::PrepareShare), - /// The leader stores a combined prepare message - Leader(#[derivative(Debug = "ignore")] A::PrepareMessage), - } - - impl> - PrepareMessageOrShare - { - /// Get the leader's preprocessed prepare message, or an error if this is a helper's prepare - /// share. - pub fn get_leader_prepare_message(&self) -> Result<&A::PrepareMessage, Error> - where - A::PrepareMessage: Encode, - { - if let Self::Leader(prep_msg) = self { - Ok(prep_msg) - } else { - Err(Error::InvalidParameter( - "does not contain a prepare message", - )) - } - } - - /// Get the helper's prepare share, or an error if this is a leader's preprocessed prepare - /// message. - pub fn get_helper_prepare_share(&self) -> Result<&A::PrepareShare, Error> - where - A::PrepareShare: Encode, - { - if let Self::Helper(prep_share) = self { - Ok(prep_share) - } else { - Err(Error::InvalidParameter("does not contain a prepare share")) - } - } - } - - impl PartialEq for PrepareMessageOrShare - where - A: vdaf::Aggregator, - A::PrepareShare: PartialEq, - A::PrepareMessage: PartialEq, - { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Helper(self_prep_share), Self::Helper(other_prep_share)) => { - self_prep_share.eq(other_prep_share) - } - (Self::Leader(self_prep_msg), Self::Leader(other_prep_msg)) => { - self_prep_msg.eq(other_prep_msg) - } - _ => false, - } - } - } - - impl Eq for PrepareMessageOrShare - where - A: vdaf::Aggregator, - A::PrepareShare: Eq, - A::PrepareMessage: Eq, - { - } - /// ReportAggregationState represents the state of a single report aggregation. It corresponds /// to the REPORT_AGGREGATION_STATE enum in the schema, along with the state-specific data. #[derive(Clone, Derivative)] @@ -4571,7 +4533,7 @@ pub mod models { Start, Waiting( #[derivative(Debug = "ignore")] A::PrepareState, - #[derivative(Debug = "ignore")] PrepareMessageOrShare, + #[derivative(Debug = "ignore")] Option, ), Finished(#[derivative(Debug = "ignore")] A::OutputShare), Failed(ReportShareError), @@ -4598,37 +4560,33 @@ pub mod models { where A::PrepareState: Encode, { - let (prep_state, prep_msg, output_share, report_share_err) = match self { - ReportAggregationState::Start => (None, None, None, None), + match self { + ReportAggregationState::Start => EncodedReportAggregationStateValues::default(), ReportAggregationState::Waiting(prep_state, prep_msg) => { - let encoded_msg = match prep_msg { - PrepareMessageOrShare::Leader(prep_msg) => prep_msg.get_encoded(), - PrepareMessageOrShare::Helper(prep_share) => prep_share.get_encoded(), - }; - ( - Some(prep_state.get_encoded()), - Some(encoded_msg), - None, - None, - ) + EncodedReportAggregationStateValues { + prep_state: Some(prep_state.get_encoded()), + prep_msg: prep_msg.as_ref().map(Encode::get_encoded), + ..Default::default() + } } ReportAggregationState::Finished(output_share) => { - (None, None, Some(output_share.get_encoded()), None) + EncodedReportAggregationStateValues { + output_share: Some(output_share.get_encoded()), + ..Default::default() + } } ReportAggregationState::Failed(report_share_err) => { - (None, None, None, Some(*report_share_err as i16)) + EncodedReportAggregationStateValues { + report_share_err: Some(*report_share_err as i16), + ..Default::default() + } } - ReportAggregationState::Invalid => (None, None, None, None), - }; - EncodedReportAggregationStateValues { - prep_state, - prep_msg, - output_share, - report_share_err, + ReportAggregationState::Invalid => EncodedReportAggregationStateValues::default(), } } } + #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { pub(super) prep_state: Option>, pub(super) prep_msg: Option>, @@ -5384,8 +5342,8 @@ mod tests { models::{ AcquiredAggregationJob, AcquiredCollectionJob, AggregateShareJob, AggregationJob, AggregationJobState, BatchAggregation, CollectionJob, CollectionJobState, - LeaderStoredReport, Lease, OutstandingBatch, PrepareMessageOrShare, - ReportAggregation, ReportAggregationState, SqlInterval, + LeaderStoredReport, Lease, OutstandingBatch, ReportAggregation, + ReportAggregationState, SqlInterval, }, test_util::{ephemeral_datastore, generate_aead_key}, Crypter, Datastore, Error, Transaction, @@ -5410,8 +5368,8 @@ mod tests { query_type::{FixedSize, QueryType, TimeInterval}, AggregateShareAad, AggregationJobId, AggregationJobRound, BatchId, BatchSelector, CollectionJobId, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, - Interval, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, - TaskId, Time, + Interval, PrepareStep, PrepareStepResult, ReportId, ReportIdChecksum, ReportMetadata, + ReportShare, ReportShareError, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, @@ -5610,6 +5568,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), ord.try_into().unwrap(), + None, ReportAggregationState::Start, ) }) @@ -5625,6 +5584,7 @@ mod tests { *report.metadata().id(), *report.metadata().time(), ord.try_into().unwrap(), + None, ReportAggregationState::Start, ) }) @@ -6153,6 +6113,7 @@ mod tests { aggregated_report_id, aggregated_report_time, 0, + None, ReportAggregationState::Start, )) .await @@ -6423,6 +6384,7 @@ mod tests { *report_0.metadata().id(), *report_0.metadata().time(), 0, + None, ReportAggregationState::Start, ); let aggregation_job_0_report_aggregation_1 = @@ -6432,6 +6394,7 @@ mod tests { *report_1.metadata().id(), *report_1.metadata().time(), 1, + None, ReportAggregationState::Start, ); @@ -6452,6 +6415,7 @@ mod tests { *report_0.metadata().id(), *report_0.metadata().time(), 0, + None, ReportAggregationState::Start, ); let aggregation_job_1_report_aggregation_1 = @@ -6461,6 +6425,7 @@ mod tests { *report_1.metadata().id(), *report_1.metadata().time(), 1, + None, ReportAggregationState::Start, ); @@ -7219,39 +7184,18 @@ mod tests { let vdaf = Arc::new(Prio3::new_count(2).unwrap()); let verify_key: [u8; PRIO3_VERIFY_KEY_LENGTH] = random(); let vdaf_transcript = run_vdaf(vdaf.as_ref(), &verify_key, &(), &report_id, &0); - let (helper_prep_state, helper_prep_share) = vdaf_transcript.helper_prep_state(0); let leader_prep_state = vdaf_transcript.leader_prep_state(0); - for (ord, (role, state)) in [ - ( - Role::Leader, - ReportAggregationState::::Start, - ), - ( - Role::Helper, - ReportAggregationState::Waiting( - helper_prep_state.clone(), - PrepareMessageOrShare::Helper(helper_prep_share.clone()), - ), - ), - ( - Role::Leader, - ReportAggregationState::Waiting( - leader_prep_state.clone(), - PrepareMessageOrShare::Leader(vdaf_transcript.prepare_messages[0].clone()), - ), - ), - ( - Role::Leader, - ReportAggregationState::Finished( - vdaf_transcript.output_share(Role::Leader).clone(), - ), + for (ord, state) in [ + ReportAggregationState::::Start, + ReportAggregationState::Waiting( + leader_prep_state.clone(), + Some(vdaf_transcript.prepare_messages[0].clone()), ), - ( - Role::Leader, - ReportAggregationState::Failed(ReportShareError::VdafPrepError), - ), - (Role::Leader, ReportAggregationState::Invalid), + ReportAggregationState::Waiting(leader_prep_state.clone(), None), + ReportAggregationState::Finished(vdaf_transcript.output_share(Role::Leader).clone()), + ReportAggregationState::Failed(ReportShareError::VdafPrepError), + ReportAggregationState::Invalid, ] .into_iter() .enumerate() @@ -7259,14 +7203,14 @@ mod tests { let task = TaskBuilder::new( task::QueryType::TimeInterval, VdafInstance::Prio3Count, - role, + Role::Leader, ) .build(); let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); - let report_aggregation = ds + let want_report_aggregation = ds .run_tx(|tx| { let (task, state) = (task.clone(), state.clone()); Box::pin(async move { @@ -7309,6 +7253,10 @@ mod tests { report_id, time, ord.try_into().unwrap(), + Some(PrepareStep::new( + report_id, + PrepareStepResult::Continued(format!("prep_msg_{ord}").into()), + )), state, ); tx.put_report_aggregation(&report_aggregation).await?; @@ -7324,7 +7272,7 @@ mod tests { Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &role, + &Role::Leader, task.id(), &aggregation_job_id, &report_id, @@ -7335,33 +7283,24 @@ mod tests { .await .unwrap() .unwrap(); - assert_eq!(report_aggregation, got_report_aggregation); - if let ReportAggregationState::Waiting(_, message) = got_report_aggregation.state() { - match role { - Role::Leader => { - assert!(message.get_leader_prepare_message().is_ok()); - assert!(message.get_helper_prepare_share().is_err()); - } - Role::Helper => { - assert!(message.get_helper_prepare_share().is_ok()); - assert!(message.get_leader_prepare_message().is_err()); - } - _ => panic!("unexpected role"), - } - } + assert_eq!(want_report_aggregation, got_report_aggregation); - let new_report_aggregation = ReportAggregation::new( - *report_aggregation.task_id(), - *report_aggregation.aggregation_job_id(), - *report_aggregation.report_id(), - *report_aggregation.time(), - report_aggregation.ord() + 10, - report_aggregation.state().clone(), + let want_report_aggregation = ReportAggregation::new( + *want_report_aggregation.task_id(), + *want_report_aggregation.aggregation_job_id(), + *want_report_aggregation.report_id(), + *want_report_aggregation.time(), + want_report_aggregation.ord() + 10, + want_report_aggregation.last_prep_step().cloned(), + want_report_aggregation.state().clone(), ); + ds.run_tx(|tx| { - let new_report_aggregation = new_report_aggregation.clone(); - Box::pin(async move { tx.update_report_aggregation(&new_report_aggregation).await }) + let want_report_aggregation = want_report_aggregation.clone(); + Box::pin( + async move { tx.update_report_aggregation(&want_report_aggregation).await }, + ) }) .await .unwrap(); @@ -7372,7 +7311,7 @@ mod tests { Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), - &role, + &Role::Leader, task.id(), &aggregation_job_id, &report_id, @@ -7382,7 +7321,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(Some(new_report_aggregation), got_report_aggregation); + assert_eq!(Some(want_report_aggregation), got_report_aggregation); } } @@ -7437,6 +7376,7 @@ mod tests { report_id, Time::from_seconds_since_epoch(12345), 0, + None, ReportAggregationState::Start, ); tx.put_report_aggregation(&report_aggregation).await?; @@ -7545,6 +7485,7 @@ mod tests { ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(12345), 0, + None, ReportAggregationState::Invalid, )) .await @@ -7574,7 +7515,7 @@ mod tests { let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); - let report_aggregations = ds + let want_report_aggregations = ds .run_tx(|tx| { let (task, prep_msg, prep_state, output_share) = ( task.clone(), @@ -7600,13 +7541,10 @@ mod tests { )) .await?; - let mut report_aggregations = Vec::new(); + let mut want_report_aggregations = Vec::new(); for (ord, state) in [ ReportAggregationState::::Start, - ReportAggregationState::Waiting( - prep_state.clone(), - PrepareMessageOrShare::Leader(prep_msg), - ), + ReportAggregationState::Waiting(prep_state.clone(), Some(prep_msg)), ReportAggregationState::Finished(output_share), ReportAggregationState::Failed(ReportShareError::VdafPrepError), ReportAggregationState::Invalid, @@ -7635,12 +7573,13 @@ mod tests { report_id, time, ord.try_into().unwrap(), + Some(PrepareStep::new(report_id, PrepareStepResult::Finished)), state.clone(), ); tx.put_report_aggregation(&report_aggregation).await?; - report_aggregations.push(report_aggregation); + want_report_aggregations.push(report_aggregation); } - Ok(report_aggregations) + Ok(want_report_aggregations) }) }) .await @@ -7661,7 +7600,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(report_aggregations, got_report_aggregations); + assert_eq!(want_report_aggregations, got_report_aggregations); } #[tokio::test] @@ -8117,6 +8056,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Doesn't matter what state the report aggregation is in )]); @@ -8247,6 +8187,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Doesn't matter what state the report aggregation is in )]); @@ -8483,6 +8424,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, // Shouldn't matter what state the report aggregation is in )]); @@ -8547,6 +8489,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, )]); @@ -8624,6 +8567,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8632,6 +8576,7 @@ mod tests { *reports[1].metadata().id(), *reports[1].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -8706,6 +8651,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8714,6 +8660,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -8869,6 +8816,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8877,6 +8825,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -8885,6 +8834,7 @@ mod tests { *reports[0].metadata().id(), *reports[0].metadata().time(), 0, + None, ReportAggregationState::Start, ), ]); @@ -9443,6 +9393,7 @@ mod tests { random(), clock.now(), 0, + None, ReportAggregationState::Start, // Counted among max_size. ); let report_aggregation_0_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9451,9 +9402,10 @@ mod tests { random(), clock.now(), 1, + None, ReportAggregationState::Waiting( dummy_vdaf::PrepareState::default(), - PrepareMessageOrShare::Leader(()), + Some(()), ), // Counted among max_size. ); let report_aggregation_0_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9462,6 +9414,7 @@ mod tests { random(), clock.now(), 2, + None, ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. ); let report_aggregation_0_3 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9470,6 +9423,7 @@ mod tests { random(), clock.now(), 3, + None, ReportAggregationState::Invalid, // Not counted among min_size or max_size. ); @@ -9489,6 +9443,7 @@ mod tests { random(), clock.now(), 0, + None, ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. ); let report_aggregation_1_1 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9497,6 +9452,7 @@ mod tests { random(), clock.now(), 1, + None, ReportAggregationState::Finished(dummy_vdaf::OutputShare()), // Counted among min_size and max_size. ); let report_aggregation_1_2 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9505,6 +9461,7 @@ mod tests { random(), clock.now(), 2, + None, ReportAggregationState::Failed(ReportShareError::VdafPrepError), // Not counted among min_size or max_size. ); let report_aggregation_1_3 = ReportAggregation::<0, dummy_vdaf::Vdaf>::new( @@ -9513,6 +9470,7 @@ mod tests { random(), clock.now(), 3, + None, ReportAggregationState::Invalid, // Not counted among min_size or max_size. ); @@ -9760,6 +9718,7 @@ mod tests { *attached_report.metadata().id(), *attached_report.metadata().time(), 0, + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); @@ -9888,6 +9847,7 @@ mod tests { *report_id, *client_timestamp, ord.try_into().unwrap(), + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); tx.put_report_aggregation(&report_aggregation) @@ -10598,6 +10558,7 @@ mod tests { *report.metadata().id(), *client_timestamp, 0, + None, ReportAggregationState::<0, dummy_vdaf::Vdaf>::Start, ); tx.put_report_aggregation(&report_aggregation) diff --git a/db/20230405185602_initial-schema.up.sql b/db/20230405185602_initial-schema.up.sql index a47f70a4e..0b2ec2c2b 100644 --- a/db/20230405185602_initial-schema.up.sql +++ b/db/20230405185602_initial-schema.up.sql @@ -140,9 +140,7 @@ CREATE TABLE report_aggregations( ord BIGINT NOT NULL, -- a value used to specify the ordering of client reports in the aggregation job state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation prep_state BYTEA, -- the current preparation state (opaque VDAF message, only if in state WAITING) - prep_msg BYTEA, -- for the leader, the next preparation message to be sent to the helper (opaque VDAF message) - -- for the helper, the next preparation share to be sent to the leader (opaque VDAF message) - -- only non-NULL if in state WAITING + prep_msg BYTEA, -- the next preparation message to be sent to the helper (opaque VDAF message, populated for Leader only) out_share BYTEA, -- the output share (opaque VDAF message, only if in state FINISHED) error_code SMALLINT, -- error code corresponding to a DAP ReportShareError value; null if in a state other than FAILED diff --git a/db/20230417204528_last-prep-step.down.sql b/db/20230417204528_last-prep-step.down.sql new file mode 100644 index 000000000..025b3efa0 --- /dev/null +++ b/db/20230417204528_last-prep-step.down.sql @@ -0,0 +1 @@ +ALTER TABLE report_aggregations DROP COLUMN last_prep_step; \ No newline at end of file diff --git a/db/20230417204528_last-prep-step.up.sql b/db/20230417204528_last-prep-step.up.sql new file mode 100644 index 000000000..a3d2a0bd5 --- /dev/null +++ b/db/20230417204528_last-prep-step.up.sql @@ -0,0 +1 @@ +ALTER TABLE report_aggregations ADD COLUMN last_prep_step BYTEA; -- the last PreparationStep message sent to the Leader, to assist in replay (opaque VDAF message, populated for Helper only) \ No newline at end of file