diff --git a/crates/invoker-api/src/entry_enricher.rs b/crates/invoker-api/src/entry_enricher.rs index 7426aaa04..1115a8094 100644 --- a/crates/invoker-api/src/entry_enricher.rs +++ b/crates/invoker-api/src/entry_enricher.rs @@ -111,6 +111,10 @@ pub mod test_util { } PlainEntryHeader::Run {} => EnrichedEntryHeader::Run {}, PlainEntryHeader::Custom { code } => EnrichedEntryHeader::Custom { code }, + PlainEntryHeader::CancelInvocation => EnrichedEntryHeader::CancelInvocation, + PlainEntryHeader::GetCallInvocationId { is_completed } => { + EnrichedEntryHeader::GetCallInvocationId { is_completed } + } }; Ok(RawEntry::new(enriched_header, entry)) diff --git a/crates/invoker-impl/src/invocation_task/mod.rs b/crates/invoker-impl/src/invocation_task/mod.rs index c2f125f28..de2d2c184 100644 --- a/crates/invoker-impl/src/invocation_task/mod.rs +++ b/crates/invoker-impl/src/invocation_task/mod.rs @@ -60,6 +60,10 @@ const SERVICE_PROTOCOL_VERSION_V1: HeaderValue = const SERVICE_PROTOCOL_VERSION_V2: HeaderValue = HeaderValue::from_static("application/vnd.restate.invocation.v2"); +#[allow(clippy::declare_interior_mutable_const)] +const SERVICE_PROTOCOL_VERSION_V3: HeaderValue = + HeaderValue::from_static("application/vnd.restate.invocation.v3"); + #[allow(clippy::declare_interior_mutable_const)] const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server"); @@ -511,6 +515,7 @@ fn service_protocol_version_to_header_value( } ServiceProtocolVersion::V1 => SERVICE_PROTOCOL_VERSION_V1, ServiceProtocolVersion::V2 => SERVICE_PROTOCOL_VERSION_V2, + ServiceProtocolVersion::V3 => SERVICE_PROTOCOL_VERSION_V3, } } diff --git a/crates/service-protocol/src/codec.rs b/crates/service-protocol/src/codec.rs index 1bcdc32ea..2b45ff185 100644 --- a/crates/service-protocol/src/codec.rs +++ b/crates/service-protocol/src/codec.rs @@ -95,7 +95,9 @@ impl RawEntryCodec for ProtobufRawEntryCodec { OneWayCall, Awakeable, CompleteAwakeable, - Run + Run, + CancelInvocation, + GetCallInvocationId }) } @@ -174,14 +176,17 @@ mod test_util { AwakeableEnrichmentResult, CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, }; use restate_types::journal::{ - AwakeableEntry, CompletableEntry, CompleteAwakeableEntry, EntryResult, GetStateKeysEntry, - GetStateKeysResult, InputEntry, OutputEntry, + AwakeableEntry, CancelInvocationEntry, CancelInvocationTarget, CompletableEntry, + CompleteAwakeableEntry, EntryResult, GetCallInvocationIdEntry, GetCallInvocationIdResult, + GetStateKeysEntry, GetStateKeysResult, InputEntry, OutputEntry, }; use restate_types::service_protocol::{ - awakeable_entry_message, call_entry_message, complete_awakeable_entry_message, + awakeable_entry_message, call_entry_message, cancel_invocation_entry_message, + complete_awakeable_entry_message, get_call_invocation_id_entry_message, get_state_entry_message, get_state_keys_entry_message, output_entry_message, - AwakeableEntryMessage, CallEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, - CompleteAwakeableEntryMessage, Failure, GetStateEntryMessage, GetStateKeysEntryMessage, + AwakeableEntryMessage, CallEntryMessage, CancelInvocationEntryMessage, + ClearAllStateEntryMessage, ClearStateEntryMessage, CompleteAwakeableEntryMessage, Failure, + GetCallInvocationIdEntryMessage, GetStateEntryMessage, GetStateKeysEntryMessage, InputEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, SetStateEntryMessage, }; @@ -345,6 +350,16 @@ mod test_util { }, Self::serialize_awakeable_entry(entry), ), + Entry::CancelInvocation(entry) => EnrichedRawEntry::new( + EnrichedEntryHeader::CancelInvocation {}, + Self::serialize_cancel_invocation_entry(entry), + ), + Entry::GetCallInvocationId(entry) => EnrichedRawEntry::new( + EnrichedEntryHeader::GetCallInvocationId { + is_completed: entry.is_completed(), + }, + Self::serialize_get_call_invocation_id_entry(entry), + ), _ => unimplemented!(), } } @@ -437,6 +452,49 @@ mod test_util { .encode_to_vec() .into() } + + fn serialize_cancel_invocation_entry( + CancelInvocationEntry { target }: CancelInvocationEntry, + ) -> Bytes { + CancelInvocationEntryMessage { + target: Some(match target { + CancelInvocationTarget::InvocationId(id) => { + cancel_invocation_entry_message::Target::InvocationId(id.to_string()) + } + CancelInvocationTarget::CallEntryIndex(idx) => { + cancel_invocation_entry_message::Target::CallEntryIndex(idx) + } + }), + ..Default::default() + } + .encode_to_vec() + .into() + } + + fn serialize_get_call_invocation_id_entry( + GetCallInvocationIdEntry { + call_entry_index, + result, + }: GetCallInvocationIdEntry, + ) -> Bytes { + GetCallInvocationIdEntryMessage { + call_entry_index, + result: result.map(|res| match res { + GetCallInvocationIdResult::InvocationId(success) => { + get_call_invocation_id_entry_message::Result::Value(success) + } + GetCallInvocationIdResult::Failure(code, reason) => { + get_call_invocation_id_entry_message::Result::Failure(Failure { + code: code.into(), + message: reason.to_string(), + }) + } + }), + ..Default::default() + } + .encode_to_vec() + .into() + } } } diff --git a/crates/service-protocol/src/message/encoding.rs b/crates/service-protocol/src/message/encoding.rs index 843819ed2..c9ac47dbe 100644 --- a/crates/service-protocol/src/message/encoding.rs +++ b/crates/service-protocol/src/message/encoding.rs @@ -328,6 +328,10 @@ fn message_header_to_raw_header(message_header: &MessageHeader) -> PlainEntryHea enrichment_result: (), }, MessageType::SideEffectEntry => PlainEntryHeader::Run {}, + MessageType::CancelInvocationEntry => PlainEntryHeader::CancelInvocation {}, + MessageType::GetCallInvocationIdEntry => PlainEntryHeader::GetCallInvocationId { + is_completed: expect_flag!(message_header, completed), + }, MessageType::CustomEntry(code) => PlainEntryHeader::Custom { code }, } } @@ -350,6 +354,8 @@ fn raw_header_to_message_type(entry_header: &PlainEntryHeader) -> MessageType { PlainEntryHeader::Awakeable { .. } => MessageType::AwakeableEntry, PlainEntryHeader::CompleteAwakeable { .. } => MessageType::CompleteAwakeableEntry, PlainEntryHeader::Run { .. } => MessageType::SideEffectEntry, + PlainEntryHeader::CancelInvocation => MessageType::CancelInvocationEntry, + PlainEntryHeader::GetCallInvocationId { .. } => MessageType::GetCallInvocationIdEntry, PlainEntryHeader::Custom { code, .. } => MessageType::CustomEntry(*code), } } diff --git a/crates/service-protocol/src/message/header.rs b/crates/service-protocol/src/message/header.rs index 349e33aca..7150ecd89 100644 --- a/crates/service-protocol/src/message/header.rs +++ b/crates/service-protocol/src/message/header.rs @@ -49,6 +49,8 @@ pub enum MessageType { GetPromiseEntry, PeekPromiseEntry, CompletePromiseEntry, + CancelInvocationEntry, + GetCallInvocationIdEntry, CustomEntry(u16), } @@ -77,6 +79,8 @@ impl MessageType { MessageType::GetPromiseEntry => MessageKind::State, MessageType::PeekPromiseEntry => MessageKind::State, MessageType::CompletePromiseEntry => MessageKind::State, + MessageType::CancelInvocationEntry => MessageKind::Syscall, + MessageType::GetCallInvocationIdEntry => MessageKind::Syscall, MessageType::CustomEntry(_) => MessageKind::CustomEntry, } } @@ -92,6 +96,7 @@ impl MessageType { | MessageType::GetPromiseEntry | MessageType::PeekPromiseEntry | MessageType::CompletePromiseEntry + | MessageType::GetCallInvocationIdEntry ) } @@ -125,6 +130,8 @@ const BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C02; const AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C03; const COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C04; const SIDE_EFFECT_ENTRY_MESSAGE_TYPE: u16 = 0x0C05; +const CANCEL_INVOCATION_ENTRY_MESSAGE_TYPE: u16 = 0x0C06; +const GET_CALL_INVOCATION_ID_ENTRY_MESSAGE_TYPE: u16 = 0x0C07; impl From for MessageTypeId { fn from(mt: MessageType) -> Self { @@ -151,6 +158,8 @@ impl From for MessageTypeId { MessageType::GetPromiseEntry => GET_PROMISE_ENTRY_MESSAGE_TYPE, MessageType::PeekPromiseEntry => PEEK_PROMISE_ENTRY_MESSAGE_TYPE, MessageType::CompletePromiseEntry => COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE, + MessageType::CancelInvocationEntry => CANCEL_INVOCATION_ENTRY_MESSAGE_TYPE, + MessageType::GetCallInvocationIdEntry => GET_CALL_INVOCATION_ID_ENTRY_MESSAGE_TYPE, MessageType::CustomEntry(id) => id, } } @@ -187,6 +196,8 @@ impl TryFrom for MessageType { PEEK_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::PeekPromiseEntry), COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompletePromiseEntry), SIDE_EFFECT_ENTRY_MESSAGE_TYPE => Ok(MessageType::SideEffectEntry), + CANCEL_INVOCATION_ENTRY_MESSAGE_TYPE => Ok(MessageType::CancelInvocationEntry), + GET_CALL_INVOCATION_ID_ENTRY_MESSAGE_TYPE => Ok(MessageType::GetCallInvocationIdEntry), v if ((v & CUSTOM_MESSAGE_MASK) != 0) => Ok(MessageType::CustomEntry(v)), v => Err(UnknownMessageType(v)), } @@ -214,6 +225,8 @@ impl TryFrom for EntryType { MessageType::GetPromiseEntry => Ok(EntryType::GetPromise), MessageType::PeekPromiseEntry => Ok(EntryType::PeekPromise), MessageType::CompletePromiseEntry => Ok(EntryType::CompletePromise), + MessageType::CancelInvocationEntry => Ok(EntryType::CancelInvocation), + MessageType::GetCallInvocationIdEntry => Ok(EntryType::GetCallInvocationId), MessageType::CustomEntry(_) => Ok(EntryType::Custom), MessageType::Start | MessageType::Completion diff --git a/crates/storage-api/proto/dev/restate/storage/v1/domain.proto b/crates/storage-api/proto/dev/restate/storage/v1/domain.proto index 9b53c582b..7497e76fa 100644 --- a/crates/storage-api/proto/dev/restate/storage/v1/domain.proto +++ b/crates/storage-api/proto/dev/restate/storage/v1/domain.proto @@ -389,6 +389,13 @@ message EnrichedEntryHeader { message SideEffect { } + message CancelInvocation { + } + + message GetCallInvocationId { + bool is_completed = 1; + } + message Custom { uint32 code = 1; } @@ -411,6 +418,8 @@ message EnrichedEntryHeader { CompleteAwakeable complete_awakeable = 10; Custom custom = 11; SideEffect side_effect = 14; + CancelInvocation cancel_invocation = 18; + GetCallInvocationId get_call_invocation_id = 19; } } diff --git a/crates/storage-api/src/storage.rs b/crates/storage-api/src/storage.rs index 44ee0c2fd..752f6c033 100644 --- a/crates/storage-api/src/storage.rs +++ b/crates/storage-api/src/storage.rs @@ -82,9 +82,9 @@ pub mod v1 { use crate::storage::v1::dedup_sequence_number::Variant; use crate::storage::v1::enriched_entry_header::{ - Awakeable, BackgroundCall, ClearAllState, ClearState, CompleteAwakeable, - CompletePromise, Custom, GetPromise, GetState, GetStateKeys, Input, Invoke, Output, - PeekPromise, SetState, SideEffect, Sleep, + Awakeable, BackgroundCall, CancelInvocation, ClearAllState, ClearState, + CompleteAwakeable, CompletePromise, Custom, GetCallInvocationId, GetPromise, GetState, + GetStateKeys, Input, Invoke, Output, PeekPromise, SetState, SideEffect, Sleep, }; use crate::storage::v1::invocation_status::{Completed, Free, Inboxed, Invoked, Suspended}; use crate::storage::v1::journal_entry::completion_result::{Empty, Failure, Success}; @@ -2210,6 +2210,14 @@ pub mod v1 { enriched_entry_header::Kind::SideEffect(_) => { restate_types::journal::enriched::EnrichedEntryHeader::Run {} } + enriched_entry_header::Kind::CancelInvocation(_) => { + restate_types::journal::enriched::EnrichedEntryHeader::CancelInvocation {} + } + enriched_entry_header::Kind::GetCallInvocationId(entry) => { + restate_types::journal::enriched::EnrichedEntryHeader::GetCallInvocationId { + is_completed: entry.is_completed, + } + } enriched_entry_header::Kind::Custom(custom) => { restate_types::journal::enriched::EnrichedEntryHeader::Custom { code: u16::try_from(custom.code) @@ -2306,6 +2314,13 @@ pub mod v1 { } => enriched_entry_header::Kind::CompletePromise(CompletePromise { is_completed, }), + restate_types::journal::enriched::EnrichedEntryHeader::CancelInvocation { + .. + } => enriched_entry_header::Kind::CancelInvocation(CancelInvocation {}), + restate_types::journal::enriched::EnrichedEntryHeader::GetCallInvocationId { + is_completed, + .. + } => enriched_entry_header::Kind::GetCallInvocationId(GetCallInvocationId { is_completed }), }; EnrichedEntryHeader { kind: Some(kind) } diff --git a/crates/types/service-protocol/buf.lock b/crates/types/service-protocol/buf.lock new file mode 100644 index 000000000..4f98143f5 --- /dev/null +++ b/crates/types/service-protocol/buf.lock @@ -0,0 +1,2 @@ +# Generated by buf. DO NOT EDIT. +version: v2 diff --git a/crates/types/service-protocol/buf.yaml b/crates/types/service-protocol/buf.yaml new file mode 100644 index 000000000..ab3bd5be4 --- /dev/null +++ b/crates/types/service-protocol/buf.yaml @@ -0,0 +1,8 @@ +version: v2 +name: buf.build/restatedev/service-protocol +lint: + use: + - DEFAULT +breaking: + use: + - FILE diff --git a/crates/types/service-protocol/dev/restate/service/protocol.proto b/crates/types/service-protocol/dev/restate/service/protocol.proto index f7aa8b891..57cc5c63c 100644 --- a/crates/types/service-protocol/dev/restate/service/protocol.proto +++ b/crates/types/service-protocol/dev/restate/service/protocol.proto @@ -22,6 +22,10 @@ enum ServiceProtocolVersion { // Added // * Entry retry mechanism: ErrorMessage.next_retry_delay, StartMessage.retry_count_since_last_stored_entry and StartMessage.duration_since_last_stored_entry V2 = 2; + // Added + // * New entry to cancel invocations: CancelInvocationEntryMessage + // * New entry to retrieve the invocation id: GetCallInvocationIdEntryMessage + V3 = 3; } // --- Core frames --- @@ -53,8 +57,8 @@ message StartMessage { // Retry count since the last stored entry. // - // Please not this count might not be accurate, as it's not durably stored, - // thus it's susceptible to Restate's crashes/leader election changes. + // Please note that this count might not be accurate, as it's not durably stored, + // thus it might get reset in case Restate crashes/changes leader. uint32 retry_count_since_last_stored_entry = 7; // Duration since the last stored entry, in milliseconds. @@ -383,13 +387,43 @@ message CompleteAwakeableEntryMessage { message RunEntryMessage { oneof result { bytes value = 14; - dev.restate.service.protocol.Failure failure = 15; + Failure failure = 15; }; // Entry name string name = 12; } +// Completable: No +// Fallible: Yes +// Type: 0x0C00 + 6 +message CancelInvocationEntryMessage { + oneof target { + // Target invocation id to cancel + string invocation_id = 1; + // Target index of the call/one way call journal entry in this journal. + uint32 call_entry_index = 2; + } + + // Entry name + string name = 12; +} + +// Completable: Yes +// Fallible: Yes +// Type: 0x0C00 + 7 +message GetCallInvocationIdEntryMessage { + // Index of the call/one way call journal entry in this journal. + uint32 call_entry_index = 1; + + oneof result { + string value = 14; + Failure failure = 15; + }; + + string name = 12; +} + // --- Nested messages // This failure object carries user visible errors, diff --git a/crates/types/service-protocol/service-invocation-protocol.md b/crates/types/service-protocol/service-invocation-protocol.md index 22635bc19..89e23cbdf 100644 --- a/crates/types/service-protocol/service-invocation-protocol.md +++ b/crates/types/service-protocol/service-invocation-protocol.md @@ -330,24 +330,26 @@ used for observability purposes by Restate observability tools. The following tables describe the currently available journal entries. For more details, check the protobuf message descriptions in [`protocol.proto`](dev/restate/service/protocol.proto). -| Message | Type | Completable | Fallible | Description | -| ------------------------------- | -------- | ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `InputEntryMessage` | `0x0400` | No | No | Carries the invocation input message(s) of the invocation. | -| `GetStateEntryMessage` | `0x0800` | Yes | No | Get the value of a service instance state key. | -| `GetStateKeysEntryMessage` | `0x0804` | Yes | No | Get all the known state keys for this service instance. Note: the completion value for this message is a protobuf of type `GetStateKeysEntryMessage.StateKeys`. | -| `SleepEntryMessage` | `0x0C00` | Yes | No | Initiate a timer that completes after the given time. | -| `CallEntryMessage` | `0x0C01` | Yes | Yes | Invoke another Restate service. | -| `AwakeableEntryMessage` | `0x0C03` | Yes | No | Arbitrary result container which can be completed from another service, given a specific id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OneWayCallEntryMessage` | `0x0C02` | No | Yes | Invoke another Restate service at the given time, without waiting for the response. | -| `CompleteAwakeableEntryMessage` | `0x0C04` | No | Yes | Complete an `Awakeable`, given its id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OutputEntryMessage` | `0x0401` | No | No | Carries the invocation output message(s) or terminal failure of the invocation. | -| `SetStateEntryMessage` | `0x0800` | No | No | Set the value of a service instance state key. | -| `ClearStateEntryMessage` | `0x0801` | No | No | Clear the value of a service instance state key. | -| `ClearAllStateEntryMessage` | `0x0802` | No | No | Clear all the values of the service instance state. | -| `RunEntryMessage` | `0x0C05` | No | No | Run non-deterministic user provided code and persist the result. | -| `GetPromiseEntryMessage` | `0x0808` | Yes | No | Get or wait the value of the given promise. If the value is not present yet, this entry will block waiting for the value. | -| `PeekPromiseEntryMessage` | `0x0809` | Yes | No | Get the value of the given promise. If the value is not present, this entry completes immediately with empty completion. | -| `CompletePromiseEntryMessage` | `0x080A` | Yes | No | Complete the given promise. If the promise was completed already, this entry completes with a failure. | +| Message | Type | Completable | Fallible | Description | +|-----------------------------------|----------|-------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `InputEntryMessage` | `0x0400` | No | No | Carries the invocation input message(s) of the invocation. | +| `GetStateEntryMessage` | `0x0800` | Yes | No | Get the value of a service instance state key. | +| `GetStateKeysEntryMessage` | `0x0804` | Yes | No | Get all the known state keys for this service instance. Note: the completion value for this message is a protobuf of type `GetStateKeysEntryMessage.StateKeys`. | +| `SleepEntryMessage` | `0x0C00` | Yes | No | Initiate a timer that completes after the given time. | +| `CallEntryMessage` | `0x0C01` | Yes | Yes | Invoke another Restate service. | +| `AwakeableEntryMessage` | `0x0C03` | Yes | No | Arbitrary result container which can be completed from another service, given a specific id. See [Awakeable identifier](#awakeable-identifier) for more details. | +| `OneWayCallEntryMessage` | `0x0C02` | No | Yes | Invoke another Restate service at the given time, without waiting for the response. | +| `CompleteAwakeableEntryMessage` | `0x0C04` | No | Yes | Complete an `Awakeable`, given its id. See [Awakeable identifier](#awakeable-identifier) for more details. | +| `OutputEntryMessage` | `0x0401` | No | No | Carries the invocation output message(s) or terminal failure of the invocation. | +| `SetStateEntryMessage` | `0x0800` | No | No | Set the value of a service instance state key. | +| `ClearStateEntryMessage` | `0x0801` | No | No | Clear the value of a service instance state key. | +| `ClearAllStateEntryMessage` | `0x0802` | No | No | Clear all the values of the service instance state. | +| `RunEntryMessage` | `0x0C05` | No | No | Run non-deterministic user provided code and persist the result. | +| `GetPromiseEntryMessage` | `0x0808` | Yes | No | Get or wait the value of the given promise. If the value is not present yet, this entry will block waiting for the value. | +| `PeekPromiseEntryMessage` | `0x0809` | Yes | No | Get the value of the given promise. If the value is not present, this entry completes immediately with empty completion. | +| `CompletePromiseEntryMessage` | `0x080A` | Yes | No | Complete the given promise. If the promise was completed already, this entry completes with a failure. | +| `CancelInvocationEntryMessage` | `0x0C06` | No | Yes | Cancel the target invocation id or the target journal entry. | +| `GetCallInvocationIdEntryMessage` | `0x0C07` | Yes | Yes | Get the invocation id of a previously created call/one way call. | #### Awakeable identifier diff --git a/crates/types/src/journal/entries.rs b/crates/types/src/journal/entries.rs index 48b2928b8..04f7af3ec 100644 --- a/crates/types/src/journal/entries.rs +++ b/crates/types/src/journal/entries.rs @@ -41,6 +41,8 @@ pub enum Entry { Awakeable(AwakeableEntry), CompleteAwakeable(CompleteAwakeableEntry), Run(RunEntry), + CancelInvocation(CancelInvocationEntry), + GetCallInvocationId(GetCallInvocationIdEntry), Custom(Bytes), } @@ -105,6 +107,20 @@ impl Entry { pub fn awakeable(result: Option) -> Self { Entry::Awakeable(AwakeableEntry { result }) } + + pub fn cancel_invocation(target: CancelInvocationTarget) -> Entry { + Entry::CancelInvocation(CancelInvocationEntry { target }) + } + + pub fn get_call_invocation_id( + call_entry_index: EntryIndex, + result: Option, + ) -> Entry { + Entry::GetCallInvocationId(GetCallInvocationIdEntry { + call_entry_index, + result, + }) + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -171,6 +187,8 @@ pub enum EntryType { Awakeable, CompleteAwakeable, Run, + CancelInvocation, + GetCallInvocationId, Custom, } @@ -214,6 +232,7 @@ mod private { impl Sealed for SleepEntry {} impl Sealed for InvokeEntry {} impl Sealed for AwakeableEntry {} + impl Sealed for GetCallInvocationIdEntry {} } #[derive(Debug, Clone, PartialEq, Eq)] @@ -396,3 +415,32 @@ pub struct CompleteAwakeableEntry { pub struct RunEntry { pub result: EntryResult, } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CancelInvocationTarget { + InvocationId(ByteString), + CallEntryIndex(EntryIndex), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CancelInvocationEntry { + pub target: CancelInvocationTarget, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum GetCallInvocationIdResult { + InvocationId(String), + Failure(InvocationErrorCode, ByteString), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GetCallInvocationIdEntry { + pub call_entry_index: EntryIndex, + pub result: Option, +} + +impl CompletableEntry for GetCallInvocationIdEntry { + fn is_completed(&self) -> bool { + self.result.is_some() + } +} diff --git a/crates/types/src/journal/raw.rs b/crates/types/src/journal/raw.rs index f44da8030..439c44b10 100644 --- a/crates/types/src/journal/raw.rs +++ b/crates/types/src/journal/raw.rs @@ -143,6 +143,10 @@ pub enum EntryHeader { enrichment_result: AwakeableEnrichmentResult, }, Run, + CancelInvocation, + GetCallInvocationId { + is_completed: bool, + }, Custom { code: u16, }, @@ -170,6 +174,8 @@ impl EntryHeader::GetPromise { is_completed } => Some(*is_completed), EntryHeader::PeekPromise { is_completed } => Some(*is_completed), EntryHeader::CompletePromise { is_completed } => Some(*is_completed), + EntryHeader::CancelInvocation => None, + EntryHeader::GetCallInvocationId { is_completed } => Some(*is_completed), } } @@ -192,6 +198,8 @@ impl EntryHeader::GetPromise { is_completed } => *is_completed = true, EntryHeader::PeekPromise { is_completed } => *is_completed = true, EntryHeader::CompletePromise { is_completed } => *is_completed = true, + EntryHeader::CancelInvocation => {} + EntryHeader::GetCallInvocationId { is_completed } => *is_completed = true, } } @@ -214,6 +222,8 @@ impl EntryHeader::GetPromise { .. } => EntryType::GetPromise, EntryHeader::PeekPromise { .. } => EntryType::PeekPromise, EntryHeader::CompletePromise { .. } => EntryType::CompletePromise, + EntryHeader::CancelInvocation => EntryType::CancelInvocation, + EntryHeader::GetCallInvocationId { .. } => EntryType::GetCallInvocationId, } } @@ -247,6 +257,10 @@ impl EntryHeader::CompletePromise { is_completed } => { EntryHeader::CompletePromise { is_completed } } + EntryHeader::CancelInvocation => EntryHeader::CancelInvocation, + EntryHeader::GetCallInvocationId { is_completed } => { + EntryHeader::GetCallInvocationId { is_completed } + } } } } diff --git a/crates/types/src/service_protocol.rs b/crates/types/src/service_protocol.rs index c7abdd310..965b16101 100644 --- a/crates/types/src/service_protocol.rs +++ b/crates/types/src/service_protocol.rs @@ -13,7 +13,7 @@ use std::ops::RangeInclusive; // Range of supported service protocol versions by this server pub const MIN_SERVICE_PROTOCOL_VERSION: ServiceProtocolVersion = ServiceProtocolVersion::V1; -pub const MAX_SERVICE_PROTOCOL_VERSION: ServiceProtocolVersion = ServiceProtocolVersion::V2; +pub const MAX_SERVICE_PROTOCOL_VERSION: ServiceProtocolVersion = ServiceProtocolVersion::V3; pub const MAX_SERVICE_PROTOCOL_VERSION_VALUE: i32 = i32::MAX; @@ -79,11 +79,12 @@ mod pb_into { use super::*; use crate::journal::{ - AwakeableEntry, ClearStateEntry, CompleteAwakeableEntry, CompletePromiseEntry, - CompleteResult, CompletionResult, Entry, EntryResult, GetPromiseEntry, GetStateEntry, - GetStateKeysEntry, GetStateKeysResult, InputEntry, InvokeEntry, InvokeRequest, - OneWayCallEntry, OutputEntry, PeekPromiseEntry, RunEntry, SetStateEntry, SleepEntry, - SleepResult, + AwakeableEntry, CancelInvocationEntry, CancelInvocationTarget, ClearStateEntry, + CompleteAwakeableEntry, CompletePromiseEntry, CompleteResult, CompletionResult, Entry, + EntryResult, GetCallInvocationIdEntry, GetCallInvocationIdResult, GetPromiseEntry, + GetStateEntry, GetStateKeysEntry, GetStateKeysResult, InputEntry, InvokeEntry, + InvokeRequest, OneWayCallEntry, OutputEntry, PeekPromiseEntry, RunEntry, SetStateEntry, + SleepEntry, SleepResult, }; impl TryFrom for Entry { @@ -328,4 +329,40 @@ mod pb_into { })) } } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: CancelInvocationEntryMessage) -> Result { + Ok(Self::CancelInvocation(CancelInvocationEntry { + target: match msg.target.ok_or("target")? { + cancel_invocation_entry_message::Target::InvocationId(s) => { + CancelInvocationTarget::InvocationId(s.into()) + } + cancel_invocation_entry_message::Target::CallEntryIndex(i) => { + CancelInvocationTarget::CallEntryIndex(i) + } + }, + })) + } + } + + impl TryFrom for Entry { + type Error = &'static str; + + fn try_from(msg: GetCallInvocationIdEntryMessage) -> Result { + Ok(Self::GetCallInvocationId(GetCallInvocationIdEntry { + call_entry_index: msg.call_entry_index, + result: msg.result.map(|v| match v { + get_call_invocation_id_entry_message::Result::Value(r) => { + GetCallInvocationIdResult::InvocationId(r) + } + get_call_invocation_id_entry_message::Result::Failure(Failure { + code, + message, + }) => GetCallInvocationIdResult::Failure(code.into(), message.into()), + }), + })) + } + } } diff --git a/crates/worker/src/invoker_integration.rs b/crates/worker/src/invoker_integration.rs index b0c129251..45278fefb 100644 --- a/crates/worker/src/invoker_integration.rs +++ b/crates/worker/src/invoker_integration.rs @@ -26,7 +26,10 @@ use restate_types::journal::enriched::{ AwakeableEnrichmentResult, CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, }; use restate_types::journal::raw::{PlainEntryHeader, PlainRawEntry, RawEntry, RawEntryCodec}; -use restate_types::journal::{CompleteAwakeableEntry, Entry, InvokeEntry, OneWayCallEntry}; +use restate_types::journal::{ + CancelInvocationEntry, CancelInvocationTarget, CompleteAwakeableEntry, Entry, InvokeEntry, + OneWayCallEntry, +}; use restate_types::journal::{EntryType, InvokeRequest}; use restate_types::live::Live; use restate_types::schema::invocation_target::InvocationTargetResolver; @@ -250,6 +253,29 @@ where } } PlainEntryHeader::Run { .. } => EnrichedEntryHeader::Run {}, + PlainEntryHeader::CancelInvocation { .. } => { + // Validate the invocation id is valid + let entry = + Codec::deserialize(EntryType::CancelInvocation, serialized_entry.clone()) + .map_err(InvocationError::internal)?; + let_assert!(Entry::CancelInvocation(CancelInvocationEntry { target }) = entry); + if let CancelInvocationTarget::InvocationId(id) = target { + if let Err(e) = id.parse::() { + return Err(InvocationError::new( + codes::BAD_REQUEST, + format!( + "The given invocation id '{}' to cancel is invalid: {}", + id, e + ), + )); + } + } + + EnrichedEntryHeader::CancelInvocation {} + } + PlainEntryHeader::GetCallInvocationId { is_completed } => { + EnrichedEntryHeader::GetCallInvocationId { is_completed } + } PlainEntryHeader::Custom { code } => EnrichedEntryHeader::Custom { code }, }; diff --git a/crates/worker/src/partition/state_machine/mod.rs b/crates/worker/src/partition/state_machine/mod.rs index 77f6f99a4..4a17776dd 100644 --- a/crates/worker/src/partition/state_machine/mod.rs +++ b/crates/worker/src/partition/state_machine/mod.rs @@ -69,7 +69,7 @@ use restate_types::journal::enriched::EnrichedRawEntry; use restate_types::journal::enriched::{ AwakeableEnrichmentResult, CallEnrichmentResult, EnrichedEntryHeader, }; -use restate_types::journal::raw::{RawEntryCodec, RawEntryCodecError}; +use restate_types::journal::raw::{EntryHeader, RawEntryCodec, RawEntryCodecError}; use restate_types::journal::Completion; use restate_types::journal::CompletionResult; use restate_types::journal::EntryType; @@ -2372,6 +2372,48 @@ impl StateMachine { EnrichedEntryHeader::Run { .. } | EnrichedEntryHeader::Custom { .. } => { // We just store it } + EntryHeader::CancelInvocation => { + let_assert!( + Entry::CancelInvocation(entry) = + journal_entry.deserialize_entry_ref::()? + ); + self.apply_cancel_invocation_journal_entry_action(ctx, &invocation_id, entry) + .await?; + } + EntryHeader::GetCallInvocationId { is_completed } => { + if !is_completed { + let_assert!( + Entry::GetCallInvocationId(entry) = + journal_entry.deserialize_entry_ref::()? + ); + let callee_invocation_id = Self::get_journal_entry_callee_invocation_id( + ctx, + &invocation_id, + entry.call_entry_index, + ) + .await?; + + if let Some(callee_invocation_id) = callee_invocation_id { + let completion_result = CompletionResult::Success(Bytes::from( + callee_invocation_id.to_string(), + )); + + Codec::write_completion(&mut journal_entry, completion_result.clone())?; + Self::forward_completion( + ctx, + invocation_id, + Completion::new(entry_index, completion_result), + ); + } else { + // Nothing we can do here, just forward an empty completion (which is invalid for this entry). + Self::forward_completion( + ctx, + invocation_id, + Completion::new(entry_index, CompletionResult::Empty), + ); + } + } + } } Self::append_journal_entry( @@ -2390,6 +2432,94 @@ impl StateMachine { Ok(()) } + async fn apply_cancel_invocation_journal_entry_action< + State: OutboxTable + FsmTable + ReadOnlyJournalTable, + >( + &mut self, + ctx: &mut StateMachineApplyContext<'_, State>, + invocation_id: &InvocationId, + entry: CancelInvocationEntry, + ) -> Result<(), Error> { + let target_invocation_id = match entry.target { + CancelInvocationTarget::InvocationId(id) => { + if let Ok(id) = id.parse::() { + Some(id) + } else { + warn!( + "Error when trying to parse the invocation id '{}' of CancelInvocation. \ + This should have been previously checked by the invoker.", + id + ); + None + } + } + CancelInvocationTarget::CallEntryIndex(call_entry_index) => { + // Look for the given entry index, then resolve the invocation id. + Self::get_journal_entry_callee_invocation_id(ctx, invocation_id, call_entry_index) + .await? + } + }; + + if let Some(target_invocation_id) = target_invocation_id { + self.handle_outgoing_message( + ctx, + OutboxMessage::InvocationTermination(InvocationTermination { + invocation_id: target_invocation_id, + flavor: TerminationFlavor::Cancel, + }), + ) + .await?; + } + Ok(()) + } + + async fn get_journal_entry_callee_invocation_id( + ctx: &mut StateMachineApplyContext<'_, State>, + invocation_id: &InvocationId, + call_entry_index: EntryIndex, + ) -> Result, Error> { + Ok( + match ctx + .storage + .get_journal_entry(invocation_id, call_entry_index) + .await? + { + Some(JournalEntry::Entry(e)) => { + match e.header() { + EnrichedEntryHeader::Call { + enrichment_result: Some(CallEnrichmentResult { invocation_id, .. }), + .. + } + | EnrichedEntryHeader::OneWayCall { + enrichment_result: CallEnrichmentResult { invocation_id, .. }, + .. + } => Some(*invocation_id), + // This is the corner case when there is no enrichment result due to + // the invocation being already completed from the SDK. Nothing to do here. + EnrichedEntryHeader::Call { + enrichment_result: None, + .. + } => None, + _ => { + warn!( + "The given journal entry index '{}' is not a Call/OneWayCall entry.", + call_entry_index + ); + None + } + } + } + _ => { + warn!( + "The given journal entry index '{}' does not exist.", + call_entry_index + ); + None + } + }, + ) + } + async fn handle_completion( ctx: &mut StateMachineApplyContext<'_, State>, invocation_id: InvocationId, diff --git a/crates/worker/src/partition/state_machine/tests/fixtures.rs b/crates/worker/src/partition/state_machine/tests/fixtures.rs new file mode 100644 index 000000000..d4f3ac0a7 --- /dev/null +++ b/crates/worker/src/partition/state_machine/tests/fixtures.rs @@ -0,0 +1,121 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use crate::partition::state_machine::tests::TestEnv; +use crate::partition::state_machine::Action; +use bytes::Bytes; +use googletest::prelude::*; +use restate_invoker_api::InvokeInputJournal; +use restate_storage_api::journal_table::JournalEntry; +use restate_types::identifiers::{InvocationId, ServiceId}; +use restate_types::invocation::{ + InvocationTarget, ServiceInvocation, ServiceInvocationSpanContext, Source, +}; +use restate_types::journal::enriched::{ + CallEnrichmentResult, EnrichedEntryHeader, EnrichedRawEntry, +}; +use restate_wal_protocol::Command; + +pub fn completed_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Call { + is_completed: true, + enrichment_result: Some(CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) +} + +pub fn background_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::OneWayCall { + enrichment_result: CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }, + }, + Bytes::default(), + )) +} + +pub fn incomplete_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Call { + is_completed: false, + enrichment_result: Some(CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) +} + +pub async fn mock_start_invocation_with_service_id( + state_machine: &mut TestEnv, + service_id: ServiceId, +) -> InvocationId { + mock_start_invocation_with_invocation_target( + state_machine, + InvocationTarget::mock_from_service_id(service_id), + ) + .await +} + +pub async fn mock_start_invocation_with_invocation_target( + state_machine: &mut TestEnv, + invocation_target: InvocationTarget, +) -> InvocationId { + let invocation_id = InvocationId::mock_generate(&invocation_target); + + let actions = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + argument: Default::default(), + source: Source::Ingress, + response_sink: None, + span_context: Default::default(), + headers: vec![], + execution_time: None, + completion_retention_duration: None, + idempotency_key: None, + submit_notification_sink: None, + })) + .await; + + assert_that!( + actions, + contains(pat!(Action::Invoke { + invocation_id: eq(invocation_id), + invocation_target: eq(invocation_target), + invoke_input_journal: pat!(InvokeInputJournal::CachedJournal(_, _)) + })) + ); + + invocation_id +} + +pub async fn mock_start_invocation(state_machine: &mut TestEnv) -> InvocationId { + mock_start_invocation_with_invocation_target( + state_machine, + InvocationTarget::mock_virtual_object(), + ) + .await +} diff --git a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs index c7e8b5532..f90560495 100644 --- a/crates/worker/src/partition/state_machine/tests/kill_cancel.rs +++ b/crates/worker/src/partition/state_machine/tests/kill_cancel.rs @@ -8,7 +8,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use super::*; +use super::{fixtures, matchers, *}; use assert2::assert; use assert2::let_assert; @@ -17,8 +17,8 @@ use prost::Message; use restate_storage_api::journal_table::JournalTable; use restate_storage_api::timer_table::{Timer, TimerKey, TimerKeyKind, TimerTable}; use restate_types::identifiers::EntryIndex; -use restate_types::invocation::{ServiceInvocationSpanContext, TerminationFlavor}; -use restate_types::journal::enriched::{CallEnrichmentResult, EnrichedEntryHeader}; +use restate_types::invocation::TerminationFlavor; +use restate_types::journal::enriched::EnrichedEntryHeader; use restate_types::service_protocol; use test_log::test; @@ -142,19 +142,19 @@ async fn kill_call_tree() -> anyhow::Result<()> { tx.put_journal_entry( &invocation_id, 1, - &uncompleted_invoke_entry(call_invocation_id), + &fixtures::incomplete_invoke_entry(call_invocation_id), ) .await; tx.put_journal_entry( &invocation_id, 2, - &background_invoke_entry(background_call_invocation_id), + &fixtures::background_invoke_entry(background_call_invocation_id), ) .await; tx.put_journal_entry( &invocation_id, 3, - &completed_invoke_entry(finished_call_invocation_id), + &fixtures::completed_invoke_entry(finished_call_invocation_id), ) .await; let mut invocation_status = tx.get_invocation_status(&invocation_id).await?; @@ -195,7 +195,7 @@ async fn kill_call_tree() -> anyhow::Result<()> { invocation_id: eq(enqueued_invocation_id_on_same_target), invocation_target: eq(invocation_target) })), - contains(terminate_invocation_outbox_message_matcher( + contains(matchers::actions::terminate_invocation( call_invocation_id, TerminationFlavor::Kill )), @@ -307,23 +307,23 @@ async fn cancel_invoked_invocation() -> Result<(), Error> { // Entries are completed for idx in 4..=9 { - assert_entry_completed(&mut test_env, invocation_id, idx).await?; + assert_entry_completed(&mut test_env, invocation_id, idx).await; } assert_that!( actions, all!( - contains(terminate_invocation_outbox_message_matcher( + contains(matchers::actions::terminate_invocation( call_invocation_id, TerminationFlavor::Cancel )), - contains(forward_canceled_completion_matcher(4)), - contains(forward_canceled_completion_matcher(5)), - contains(forward_canceled_completion_matcher(6)), - contains(forward_canceled_completion_matcher(7)), - contains(forward_canceled_completion_matcher(8)), - contains(forward_canceled_completion_matcher(9)), - contains(delete_timer_matcher(5)), + contains(matchers::actions::forward_canceled_completion(4)), + contains(matchers::actions::forward_canceled_completion(5)), + contains(matchers::actions::forward_canceled_completion(6)), + contains(matchers::actions::forward_canceled_completion(7)), + contains(matchers::actions::forward_canceled_completion(8)), + contains(matchers::actions::forward_canceled_completion(9)), + contains(matchers::actions::delete_sleep_timer(5)), ) ); @@ -426,17 +426,17 @@ async fn cancel_suspended_invocation() -> Result<(), Error> { // Entries are completed for idx in 4..=9 { - assert_entry_completed(&mut test_env, invocation_id, idx).await?; + assert_entry_completed(&mut test_env, invocation_id, idx).await; } assert_that!( actions, all!( - contains(terminate_invocation_outbox_message_matcher( + contains(matchers::actions::terminate_invocation( call_invocation_id, TerminationFlavor::Cancel )), - contains(delete_timer_matcher(5)), + contains(matchers::actions::delete_sleep_timer(5)), contains(pat!(Action::Invoke { invocation_id: eq(invocation_id), invocation_target: eq(invocation_target) @@ -448,48 +448,101 @@ async fn cancel_suspended_invocation() -> Result<(), Error> { Ok(()) } -fn completed_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Call { - is_completed: true, - enrichment_result: Some(CallEnrichmentResult { +#[test(tokio::test)] +async fn cancel_invocation_entry_referring_to_previous_entry() { + let mut test_env = TestEnv::create().await; + + let invocation_target = InvocationTarget::mock_service(); + let invocation_id = InvocationId::mock_random(); + + let callee_1 = InvocationId::mock_random(); + let callee_2 = InvocationId::mock_random(); + + let _ = test_env + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Add call and one way call journal entry + let mut tx = test_env.storage.transaction(); + tx.put_journal_entry( + &invocation_id, + 1, + &fixtures::background_invoke_entry(callee_1), + ) + .await; + tx.put_journal_entry( + &invocation_id, + 2, + &fixtures::incomplete_invoke_entry(callee_2), + ) + .await; + let mut invocation_status = tx.get_invocation_status(&invocation_id).await.unwrap(); + invocation_status.get_journal_metadata_mut().unwrap().length = 3; + tx.put_invocation_status(&invocation_id, &invocation_status) + .await; + tx.commit().await.unwrap(); + + // Now create cancel invocation entry + let actions = test_env + .apply_multiple(vec![ + Command::InvokerEffect(InvokerEffect { invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), + kind: InvokerEffectKind::JournalEntry { + entry_index: 3, + entry: ProtobufRawEntryCodec::serialize_enriched(Entry::cancel_invocation( + CancelInvocationTarget::InvocationId(callee_1.to_string().into()), + )), + }, }), - }, - Bytes::default(), - )) -} - -fn background_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::OneWayCall { - enrichment_result: CallEnrichmentResult { + Command::InvokerEffect(InvokerEffect { invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), - }, - }, - Bytes::default(), - )) + kind: InvokerEffectKind::JournalEntry { + entry_index: 4, + entry: ProtobufRawEntryCodec::serialize_enriched(Entry::cancel_invocation( + CancelInvocationTarget::CallEntryIndex(2), + )), + }, + }), + ]) + .await; + + assert_that!( + actions, + all!( + contains(matchers::actions::terminate_invocation( + callee_1, + TerminationFlavor::Cancel + )), + contains(matchers::actions::terminate_invocation( + callee_2, + TerminationFlavor::Cancel + )), + ) + ); + assert_that!( + test_env.storage.get_invocation_status(&invocation_id).await, + ok(pat!(InvocationStatus::Invoked { .. })) + ); + test_env.shutdown().await; } -fn uncompleted_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Call { - is_completed: false, - enrichment_result: Some(CallEnrichmentResult { - invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), - }), - }, - Bytes::default(), - )) +async fn assert_entry_completed( + test_env: &mut TestEnv, + invocation_id: InvocationId, + idx: EntryIndex, +) { + assert_that!( + test_env + .storage + .get_journal_entry(&invocation_id, idx) + .await + .unwrap(), + some(pat!(JournalEntry::Entry(matchers::completed_entry()))) + ); } fn create_termination_journal( @@ -498,9 +551,9 @@ fn create_termination_journal( finished_call_invocation_id: InvocationId, ) -> Vec { vec![ - uncompleted_invoke_entry(call_invocation_id), - completed_invoke_entry(finished_call_invocation_id), - background_invoke_entry(background_invocation_id), + fixtures::incomplete_invoke_entry(call_invocation_id), + fixtures::completed_invoke_entry(finished_call_invocation_id), + fixtures::background_invoke_entry(background_invocation_id), JournalEntry::Entry(EnrichedRawEntry::new( EnrichedEntryHeader::GetState { is_completed: false, @@ -560,66 +613,3 @@ fn create_termination_journal( )), ] } - -async fn assert_entry_completed( - test_env: &mut TestEnv, - invocation_id: InvocationId, - idx: EntryIndex, -) -> Result<(), Error> { - assert_that!( - test_env - .storage - .get_journal_entry(&invocation_id, idx) - .await?, - some(pat!(JournalEntry::Entry(entry_completed_matcher()))) - ); - Ok(()) -} - -fn canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Completion { - entry_index: eq(entry_index), - result: pat!(CompletionResult::Failure( - eq(codes::ABORTED), - eq(ByteString::from_static("canceled")) - )) - }) -} - -fn entry_completed_matcher() -> impl Matcher { - predicate(|e: &EnrichedRawEntry| e.header().is_completed().unwrap_or(false)) - .with_description("completed entry", "uncompleted entry") -} - -fn forward_canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Action::ForwardCompletion { - completion: canceled_completion_matcher(entry_index), - }) -} - -fn delete_timer_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Action::DeleteTimer { - timer_key: pat!(TimerKey { - kind: pat!(TimerKeyKind::CompleteJournalEntry { - journal_index: eq(entry_index), - }), - timestamp: eq(1337), - }) - }) -} - -fn terminate_invocation_outbox_message_matcher( - target_invocation_id: InvocationId, - termination_flavor: TerminationFlavor, -) -> impl Matcher { - pat!(Action::NewOutboxMessage { - message: pat!( - restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( - InvocationTermination { - invocation_id: eq(target_invocation_id), - flavor: eq(termination_flavor) - } - )) - ) - }) -} diff --git a/crates/worker/src/partition/state_machine/tests/matchers.rs b/crates/worker/src/partition/state_machine/tests/matchers.rs index 38a542afa..4aeeb7aa9 100644 --- a/crates/worker/src/partition/state_machine/tests/matchers.rs +++ b/crates/worker/src/partition/state_machine/tests/matchers.rs @@ -8,14 +8,25 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use bytes::Bytes; +use bytestring::ByteString; use googletest::prelude::*; +use restate_storage_api::timer_table::{TimerKey, TimerKeyKind}; +use restate_types::errors::codes; +use restate_types::identifiers::EntryIndex; +use restate_types::invocation::{InvocationTermination, TerminationFlavor}; +use restate_types::journal::enriched::EnrichedRawEntry; +use restate_types::journal::{Completion, CompletionResult}; pub mod storage { use super::*; + use restate_service_protocol::codec::ProtobufRawEntryCodec; use restate_storage_api::inbox_table::{InboxEntry, SequenceNumberInboxEntry}; + use restate_storage_api::journal_table::JournalEntry; use restate_types::identifiers::InvocationId; use restate_types::invocation::InvocationTarget; + use restate_types::journal::Entry; pub fn invocation_inbox_entry( invocation_id: InvocationId, @@ -28,6 +39,12 @@ pub mod storage { )) }) } + + pub fn is_entry(entry: Entry) -> impl Matcher { + pat!(JournalEntry::Entry(eq( + ProtobufRawEntryCodec::serialize_enriched(entry) + ))) + } } pub mod actions { @@ -41,4 +58,76 @@ pub mod actions { invocation_id: eq(invocation_id) }) } + + pub fn delete_sleep_timer(entry_index: EntryIndex) -> impl Matcher { + pat!(Action::DeleteTimer { + timer_key: pat!(TimerKey { + kind: pat!(TimerKeyKind::CompleteJournalEntry { + journal_index: eq(entry_index), + }), + timestamp: eq(1337), + }) + }) + } + + pub fn terminate_invocation( + target_invocation_id: InvocationId, + termination_flavor: TerminationFlavor, + ) -> impl Matcher { + pat!(Action::NewOutboxMessage { + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( + InvocationTermination { + invocation_id: eq(target_invocation_id), + flavor: eq(termination_flavor) + } + )) + ) + }) + } + + pub fn forward_canceled_completion(entry_index: EntryIndex) -> impl Matcher { + pat!(Action::ForwardCompletion { + completion: canceled_completion(entry_index), + }) + } + + pub fn forward_completion( + invocation_id: InvocationId, + inner: impl Matcher + 'static, + ) -> impl Matcher { + pat!(Action::ForwardCompletion { + invocation_id: eq(invocation_id), + completion: inner, + }) + } +} + +pub fn completion( + entry_index: EntryIndex, + completion_result: CompletionResult, +) -> impl Matcher { + pat!(Completion { + entry_index: eq(entry_index), + result: eq(completion_result) + }) +} + +pub fn success_completion( + entry_index: EntryIndex, + bytes: impl Into, +) -> impl Matcher { + completion(entry_index, CompletionResult::Success(bytes.into())) +} + +pub fn canceled_completion(entry_index: EntryIndex) -> impl Matcher { + completion( + entry_index, + CompletionResult::Failure(codes::ABORTED, ByteString::from_static("canceled")), + ) +} + +pub fn completed_entry() -> impl Matcher { + predicate(|e: &EnrichedRawEntry| e.header().is_completed().unwrap_or(false)) + .with_description("completed entry", "uncompleted entry") } diff --git a/crates/worker/src/partition/state_machine/tests/mod.rs b/crates/worker/src/partition/state_machine/tests/mod.rs index a220be4ad..c980d5314 100644 --- a/crates/worker/src/partition/state_machine/tests/mod.rs +++ b/crates/worker/src/partition/state_machine/tests/mod.rs @@ -11,11 +11,17 @@ use super::*; mod delayed_send; +mod fixtures; mod idempotency; mod kill_cancel; mod matchers; mod workflow; +use crate::partition::state_machine::tests::fixtures::{ + background_invoke_entry, incomplete_invoke_entry, +}; +use crate::partition::state_machine::tests::matchers::storage::is_entry; +use crate::partition::state_machine::tests::matchers::success_completion; use crate::partition::types::{InvokerEffect, InvokerEffectKind}; use ::tracing::info; use bytes::Bytes; @@ -176,7 +182,7 @@ type TestResult = Result<(), anyhow::Error>; #[test(tokio::test)] async fn start_invocation() -> TestResult { let mut test_env = TestEnv::create().await; - let id = mock_start_invocation(&mut test_env).await; + let id = fixtures::mock_start_invocation(&mut test_env).await; let invocation_status = test_env.storage().get_invocation_status(&id).await.unwrap(); assert_that!(invocation_status, pat!(InvocationStatus::Invoked(_))); @@ -205,9 +211,11 @@ async fn shared_invocation_skips_inbox() -> TestResult { tx.commit().await.unwrap(); // Start the invocation - let invocation_id = - mock_start_invocation_with_invocation_target(&mut test_env, invocation_target.clone()) - .await; + let invocation_id = fixtures::mock_start_invocation_with_invocation_target( + &mut test_env, + invocation_target.clone(), + ) + .await; // Should be in invoked status let invocation_status = test_env @@ -230,7 +238,7 @@ async fn shared_invocation_skips_inbox() -> TestResult { #[test(tokio::test)] async fn awakeable_completion_received_before_entry() -> TestResult { let mut test_env = TestEnv::create().await; - let invocation_id = mock_start_invocation(&mut test_env).await; + let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; // Send completion first let _ = test_env @@ -334,7 +342,7 @@ async fn awakeable_completion_received_before_entry() -> TestResult { #[test(tokio::test)] async fn complete_awakeable_with_success() { let mut test_env = TestEnv::create().await; - let invocation_id = mock_start_invocation(&mut test_env).await; + let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; let callee_invocation_id = InvocationId::mock_random(); let callee_entry_index = 10; @@ -377,7 +385,7 @@ async fn complete_awakeable_with_success() { #[test(tokio::test)] async fn complete_awakeable_with_failure() { let mut test_env = TestEnv::create().await; - let invocation_id = mock_start_invocation(&mut test_env).await; + let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; let callee_invocation_id = InvocationId::mock_random(); let callee_entry_index = 10; @@ -425,7 +433,7 @@ async fn invoke_with_headers() -> TestResult { let mut test_env = TestEnv::create().await; let service_id = ServiceId::mock_random(); let invocation_id = - mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; + fixtures::mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; let actions = test_env .apply(Command::InvokerEffect(InvokerEffect { @@ -467,9 +475,11 @@ async fn mutate_state() -> anyhow::Result<()> { let mut test_env = TestEnv::create().await; let invocation_target = InvocationTarget::mock_virtual_object(); let keyed_service_id = invocation_target.as_keyed_service_id().unwrap(); - let invocation_id = - mock_start_invocation_with_invocation_target(&mut test_env, invocation_target.clone()) - .await; + let invocation_id = fixtures::mock_start_invocation_with_invocation_target( + &mut test_env, + invocation_target.clone(), + ) + .await; let first_state_mutation: HashMap = [ (Bytes::from_static(b"foobar"), Bytes::from_static(b"foobar")), @@ -543,7 +553,7 @@ async fn clear_all_user_states() -> anyhow::Result<()> { txn.commit().await.unwrap(); let invocation_id = - mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; + fixtures::mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; test_env .apply(Command::InvokerEffect(InvokerEffect { @@ -571,7 +581,7 @@ async fn get_state_keys() -> TestResult { let mut test_env = TestEnv::create().await; let service_id = ServiceId::mock_random(); let invocation_id = - mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; + fixtures::mock_start_invocation_with_service_id(&mut test_env, service_id.clone()).await; // Mock some state let mut txn = test_env.storage.transaction(); @@ -592,21 +602,110 @@ async fn get_state_keys() -> TestResult { // At this point we expect the completion to be forwarded to the invoker assert_that!( actions, - contains(pat!(Action::ForwardCompletion { - invocation_id: eq(invocation_id), - completion: eq(Completion::new( + contains(matchers::actions::forward_completion( + invocation_id, + matchers::completion( 1, ProtobufRawEntryCodec::serialize_get_state_keys_completion(vec![ Bytes::copy_from_slice(b"key1"), Bytes::copy_from_slice(b"key2"), ]) - )) - })) + ) + )) ); test_env.shutdown().await; Ok(()) } +#[test(tokio::test)] +async fn get_invocation_id_entry() { + let mut test_env = TestEnv::create().await; + let invocation_id = fixtures::mock_start_invocation(&mut test_env).await; + + let callee_1 = InvocationId::mock_random(); + let callee_2 = InvocationId::mock_random(); + + // Mock some state + // Add call and one way call journal entry + let mut tx = test_env.storage.transaction(); + tx.put_journal_entry(&invocation_id, 1, &background_invoke_entry(callee_1)) + .await; + tx.put_journal_entry(&invocation_id, 2, &incomplete_invoke_entry(callee_2)) + .await; + let mut invocation_status = tx.get_invocation_status(&invocation_id).await.unwrap(); + invocation_status.get_journal_metadata_mut().unwrap().length = 3; + tx.put_invocation_status(&invocation_id, &invocation_status) + .await; + tx.commit().await.unwrap(); + + let actions = test_env + .apply_multiple(vec![ + Command::InvokerEffect(InvokerEffect { + invocation_id, + kind: InvokerEffectKind::JournalEntry { + entry_index: 3, + entry: ProtobufRawEntryCodec::serialize_enriched( + Entry::get_call_invocation_id(1, None), + ), + }, + }), + Command::InvokerEffect(InvokerEffect { + invocation_id, + kind: InvokerEffectKind::JournalEntry { + entry_index: 4, + entry: ProtobufRawEntryCodec::serialize_enriched( + Entry::get_call_invocation_id(2, None), + ), + }, + }), + ]) + .await; + + // Assert completion is forwarded and stored + assert_that!( + actions, + all!( + contains(matchers::actions::forward_completion( + invocation_id, + success_completion(3, callee_1.to_string()) + )), + contains(matchers::actions::forward_completion( + invocation_id, + success_completion(4, callee_2.to_string()) + )) + ) + ); + + assert_that!( + test_env + .storage + .get_journal_entry(&invocation_id, 3) + .await + .unwrap(), + some(is_entry(Entry::get_call_invocation_id( + 1, + Some(GetCallInvocationIdResult::InvocationId( + callee_1.to_string() + )) + ))) + ); + assert_that!( + test_env + .storage + .get_journal_entry(&invocation_id, 4) + .await + .unwrap(), + some(is_entry(Entry::get_call_invocation_id( + 2, + Some(GetCallInvocationIdResult::InvocationId( + callee_2.to_string() + )) + ))) + ); + + test_env.shutdown().await; +} + #[test(tokio::test)] async fn send_ingress_response_to_multiple_targets() -> TestResult { let mut test_env = TestEnv::create().await; @@ -900,56 +999,3 @@ async fn consecutive_exclusive_handler_invocations_will_use_inbox() -> TestResul test_env.shutdown().await; Ok(()) } - -async fn mock_start_invocation_with_service_id( - state_machine: &mut TestEnv, - service_id: ServiceId, -) -> InvocationId { - mock_start_invocation_with_invocation_target( - state_machine, - InvocationTarget::mock_from_service_id(service_id), - ) - .await -} - -async fn mock_start_invocation_with_invocation_target( - state_machine: &mut TestEnv, - invocation_target: InvocationTarget, -) -> InvocationId { - let invocation_id = InvocationId::mock_generate(&invocation_target); - - let actions = state_machine - .apply(Command::Invoke(ServiceInvocation { - invocation_id, - invocation_target: invocation_target.clone(), - argument: Default::default(), - source: Source::Ingress, - response_sink: None, - span_context: Default::default(), - headers: vec![], - execution_time: None, - completion_retention_duration: None, - idempotency_key: None, - submit_notification_sink: None, - })) - .await; - - assert_that!( - actions, - contains(pat!(Action::Invoke { - invocation_id: eq(invocation_id), - invocation_target: eq(invocation_target), - invoke_input_journal: pat!(InvokeInputJournal::CachedJournal(_, _)) - })) - ); - - invocation_id -} - -async fn mock_start_invocation(state_machine: &mut TestEnv) -> InvocationId { - mock_start_invocation_with_invocation_target( - state_machine, - InvocationTarget::mock_virtual_object(), - ) - .await -}