From 0e08bf77d56cd9c772df194235582bd8dbbb9f6e Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 17 Sep 2024 13:43:51 +0200 Subject: [PATCH] Implement CancelInvocationEntry --- crates/service-protocol/src/codec.rs | 34 ++++++-- crates/types/src/journal/entries.rs | 4 + crates/worker/src/invoker_integration.rs | 26 +++++- .../worker/src/partition/state_machine/mod.rs | 84 ++++++++++++++++++- .../state_machine/tests/kill_cancel.rs | 74 ++++++++++++++++ 5 files changed, 214 insertions(+), 8 deletions(-) diff --git a/crates/service-protocol/src/codec.rs b/crates/service-protocol/src/codec.rs index 08cb90253d..4e3c5b7e71 100644 --- a/crates/service-protocol/src/codec.rs +++ b/crates/service-protocol/src/codec.rs @@ -176,13 +176,15 @@ mod test_util { AwakeableEnrichmentResult, CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, }; use restate_types::journal::{ - AwakeableEntry, CompletableEntry, CompleteAwakeableEntry, EntryResult, GetStateKeysEntry, - GetStateKeysResult, InputEntry, OutputEntry, + AwakeableEntry, CancelInvocationEntry, CancelInvocationTarget, CompletableEntry, + CompleteAwakeableEntry, EntryResult, GetStateKeysEntry, GetStateKeysResult, InputEntry, + OutputEntry, }; use restate_types::service_protocol::{ - awakeable_entry_message, call_entry_message, complete_awakeable_entry_message, - get_state_entry_message, get_state_keys_entry_message, output_entry_message, - AwakeableEntryMessage, CallEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, + awakeable_entry_message, call_entry_message, cancel_invocation_entry_message, + complete_awakeable_entry_message, get_state_entry_message, get_state_keys_entry_message, + output_entry_message, AwakeableEntryMessage, CallEntryMessage, + CancelInvocationEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, CompleteAwakeableEntryMessage, Failure, GetStateEntryMessage, GetStateKeysEntryMessage, InputEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, SetStateEntryMessage, }; @@ -347,6 +349,10 @@ mod test_util { }, Self::serialize_awakeable_entry(entry), ), + Entry::CancelInvocation(entry) => EnrichedRawEntry::new( + EnrichedEntryHeader::CancelInvocation {}, + Self::serialize_cancel_invocation_entry(entry), + ), _ => unimplemented!(), } } @@ -439,6 +445,24 @@ mod test_util { .encode_to_vec() .into() } + + fn serialize_cancel_invocation_entry( + CancelInvocationEntry { target }: CancelInvocationEntry, + ) -> Bytes { + CancelInvocationEntryMessage { + target: Some(match target { + CancelInvocationTarget::InvocationId(id) => { + cancel_invocation_entry_message::Target::InvocationId(id.to_string()) + } + CancelInvocationTarget::CallEntryIndex(idx) => { + cancel_invocation_entry_message::Target::CallEntryIndex(idx) + } + }), + ..Default::default() + } + .encode_to_vec() + .into() + } } } diff --git a/crates/types/src/journal/entries.rs b/crates/types/src/journal/entries.rs index d49de67594..de85d67339 100644 --- a/crates/types/src/journal/entries.rs +++ b/crates/types/src/journal/entries.rs @@ -107,6 +107,10 @@ impl Entry { pub fn awakeable(result: Option) -> Self { Entry::Awakeable(AwakeableEntry { result }) } + + pub fn cancel_invocation(target: CancelInvocationTarget) -> Entry { + Entry::CancelInvocation(CancelInvocationEntry { target }) + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/worker/src/invoker_integration.rs b/crates/worker/src/invoker_integration.rs index 697786fca2..83d9274d9f 100644 --- a/crates/worker/src/invoker_integration.rs +++ b/crates/worker/src/invoker_integration.rs @@ -26,7 +26,10 @@ use restate_types::journal::enriched::{ AwakeableEnrichmentResult, CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, }; use restate_types::journal::raw::{PlainEntryHeader, PlainRawEntry, RawEntry, RawEntryCodec}; -use restate_types::journal::{CompleteAwakeableEntry, Entry, InvokeEntry, OneWayCallEntry}; +use restate_types::journal::{ + CancelInvocationEntry, CancelInvocationTarget, CompleteAwakeableEntry, Entry, InvokeEntry, + OneWayCallEntry, +}; use restate_types::journal::{EntryType, InvokeRequest}; use restate_types::live::Live; use restate_types::schema::invocation_target::InvocationTargetResolver; @@ -250,7 +253,26 @@ where } } PlainEntryHeader::Run { .. } => EnrichedEntryHeader::Run {}, - PlainEntryHeader::CancelInvocation { .. } => EnrichedEntryHeader::CancelInvocation {}, + PlainEntryHeader::CancelInvocation { .. } => { + // Validate the invocation id is valid + let entry = + Codec::deserialize(EntryType::CancelInvocation, serialized_entry.clone()) + .map_err(InvocationError::internal)?; + let_assert!(Entry::CancelInvocation(CancelInvocationEntry { target }) = entry); + if let CancelInvocationTarget::InvocationId(id) = target { + if let Err(e) = id.parse::() { + return Err(InvocationError::new( + codes::BAD_REQUEST, + format!( + "The given invocation id '{}' to cancel is invalid: {}", + id, e + ), + )); + } + } + + EnrichedEntryHeader::CancelInvocation {} + } PlainEntryHeader::GetCallInvocationId { is_completed } => { EnrichedEntryHeader::GetCallInvocationId { is_completed } } diff --git a/crates/worker/src/partition/state_machine/mod.rs b/crates/worker/src/partition/state_machine/mod.rs index 3f14881796..c557f2e16d 100644 --- a/crates/worker/src/partition/state_machine/mod.rs +++ b/crates/worker/src/partition/state_machine/mod.rs @@ -2371,7 +2371,12 @@ impl StateMachine { // We just store it } EntryHeader::CancelInvocation => { - todo!() + let_assert!( + Entry::CancelInvocation(entry) = + journal_entry.deserialize_entry_ref::()? + ); + self.apply_cancel_invocation_journal_entry_action(ctx, &invocation_id, entry) + .await?; } EntryHeader::GetCallInvocationId { .. } => { todo!() @@ -2394,6 +2399,83 @@ impl StateMachine { Ok(()) } + async fn apply_cancel_invocation_journal_entry_action< + State: OutboxTable + FsmTable + ReadOnlyJournalTable, + >( + &mut self, + ctx: &mut StateMachineApplyContext<'_, State>, + invocation_id: &InvocationId, + entry: CancelInvocationEntry, + ) -> Result<(), Error> { + let target_invocation_id = match entry.target { + CancelInvocationTarget::InvocationId(id) => { + if let Ok(id) = id.parse::() { + Some(id) + } else { + warn!( + "Error when trying to parse the invocation id '{}' of CancelInvocation. \ + This should have been previously checked by the invoker.", + id + ); + None + } + } + CancelInvocationTarget::CallEntryIndex(call_entry_index) => { + // Look for the given entry index, then resolve the invocation id. + match ctx + .storage + .get_journal_entry(invocation_id, call_entry_index) + .await? + { + Some(JournalEntry::Entry(e)) => { + match e.header() { + EnrichedEntryHeader::Call { + enrichment_result: Some(CallEnrichmentResult { invocation_id, .. }), + .. + } + | EnrichedEntryHeader::OneWayCall { + enrichment_result: CallEnrichmentResult { invocation_id, .. }, + .. + } => Some(*invocation_id), + // This is the corner case when there is no enrichment result due to + // the invocation being already completed from the SDK. Nothing to do here. + EnrichedEntryHeader::Call { + enrichment_result: None, + .. + } => None, + _ => { + warn!( + "The given journal entry index '{}' is not a Call/OneWayCall entry. This is potentially an SDK bug.", + call_entry_index + ); + None + } + } + } + _ => { + warn!( + "The given journal entry index '{}' does not exist. This is potentially an SDK bug.", + call_entry_index + ); + None + } + } + } + }; + + if let Some(target_invocation_id) = target_invocation_id { + self.handle_outgoing_message( + ctx, + OutboxMessage::InvocationTermination(InvocationTermination { + invocation_id: target_invocation_id, + flavor: TerminationFlavor::Cancel, + }), + ) + .await?; + } + Ok(()) + } + async fn handle_completion( ctx: &mut StateMachineApplyContext<'_, State>, invocation_id: InvocationId, diff --git a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs index 8ac633060c..d987b6c98b 100644 --- a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs +++ b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs @@ -620,3 +620,77 @@ fn terminate_invocation_outbox_message_matcher( ) }) } + +#[test(tokio::test)] +async fn cancel_invocation_entry_referring_to_previous_entry() { + let mut test_env = TestEnv::create().await; + + let invocation_target = InvocationTarget::mock_service(); + let invocation_id = InvocationId::mock_random(); + + let callee_1 = InvocationId::mock_random(); + let callee_2 = InvocationId::mock_random(); + + let _ = test_env + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Add call and one way call journal entry + let mut tx = test_env.storage.transaction(); + tx.put_journal_entry(&invocation_id, 1, &background_invoke_entry(callee_1)) + .await; + tx.put_journal_entry(&invocation_id, 2, &uncompleted_invoke_entry(callee_2)) + .await; + let mut invocation_status = tx.get_invocation_status(&invocation_id).await.unwrap(); + invocation_status.get_journal_metadata_mut().unwrap().length = 3; + tx.put_invocation_status(&invocation_id, &invocation_status) + .await; + tx.commit().await.unwrap(); + + // Now create cancel invocation entry + let actions = test_env + .apply_multiple(vec![ + Command::InvokerEffect(InvokerEffect { + invocation_id, + kind: InvokerEffectKind::JournalEntry { + entry_index: 3, + entry: ProtobufRawEntryCodec::serialize_enriched(Entry::cancel_invocation( + CancelInvocationTarget::InvocationId(callee_1.to_string().into()), + )), + }, + }), + Command::InvokerEffect(InvokerEffect { + invocation_id, + kind: InvokerEffectKind::JournalEntry { + entry_index: 4, + entry: ProtobufRawEntryCodec::serialize_enriched(Entry::cancel_invocation( + CancelInvocationTarget::CallEntryIndex(2), + )), + }, + }), + ]) + .await; + + assert_that!( + actions, + all!( + contains(terminate_invocation_outbox_message_matcher( + callee_1, + TerminationFlavor::Cancel + )), + contains(terminate_invocation_outbox_message_matcher( + callee_2, + TerminationFlavor::Cancel + )), + ) + ); + assert_that!( + test_env.storage.get_invocation_status(&invocation_id).await, + ok(pat!(InvocationStatus::Invoked { .. })) + ); + test_env.shutdown().await; +}