Skip to content

Commit

Permalink
Basic implementation of combinators (#14)
Browse files Browse the repository at this point in the history
* First pass at a combinator implementation

* We need to suspend when we can't make progress!

* Missing a bunch of instrument macro calls

* More logging changes to make it a tad bit better with the output

* Allow to disable the AwaitTwoAsyncResults error

* Make sure to pin prost-build

* Update the list of SDKs using it
  • Loading branch information
slinkydeveloper authored Sep 27, 2024
1 parent 3c690e0 commit 3f6f6d8
Show file tree
Hide file tree
Showing 18 changed files with 565 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ http = { version = "1.1.0", optional = true }
googletest = "0.11.0"
test-log = { version = "0.2.16", default-features = false, features = ["trace", "color"] }
assert2 = "0.3.14"
prost-build = "0.13.2"
prost-build = "=0.13.3"
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Shared core to build SDKs in various languages. Currently used by:

* [Python SDK](https://github.com/restatedev/sdk-python)
* [Rust SDK](https://github.com/restatedev/sdk-rust)

## Versions

Expand Down
22 changes: 22 additions & 0 deletions service-protocol-ext/combinators.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
*
* This file is part of the Restate SDK for Node.js/TypeScript,
* which is released under the MIT license.
*
* You can find a copy of the license in file LICENSE in the root
* directory of this repository or package, or at
* https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
*/

syntax = "proto3";

package dev.restate.service.protocol.extensions;

// Type: 0xFC00 + 2
message CombinatorEntryMessage {
repeated uint32 completed_entries_order = 1;

// Entry name
string name = 12;
}
48 changes: 46 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ mod vm;

use bytes::Bytes;
use std::borrow::Cow;
use std::fmt;
use std::time::Duration;

pub use crate::retries::RetryPolicy;
use crate::vm::AsyncResultAccessTrackerInner;
pub use headers::HeaderMap;
pub use request_identity::*;
pub use vm::CoreVM;
Expand Down Expand Up @@ -83,7 +85,7 @@ pub struct Target {
pub key: Option<String>,
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq)]
pub struct AsyncResultHandle(u32);

impl From<u32> for AsyncResultHandle {
Expand All @@ -106,6 +108,7 @@ pub enum Value {
Failure(Failure),
/// Only returned for get_state_keys
StateKeys(Vec<String>),
CombinatorResult(Vec<AsyncResultHandle>),
}

/// Terminal failure
Expand Down Expand Up @@ -162,8 +165,21 @@ pub enum TakeOutputResult {

pub type VMResult<T> = Result<T, VMError>;

pub struct VMOptions {
/// If true, false when two concurrent async results are awaited at the same time. If false, just log it.
pub fail_on_wait_concurrent_async_result: bool,
}

impl Default for VMOptions {
fn default() -> Self {
Self {
fail_on_wait_concurrent_async_result: true,
}
}
}

pub trait VM: Sized {
fn new(request_headers: impl HeaderMap) -> VMResult<Self>;
fn new(request_headers: impl HeaderMap, options: VMOptions) -> VMResult<Self>;

fn get_response_head(&self) -> ResponseHead;

Expand Down Expand Up @@ -257,6 +273,12 @@ pub trait VM: Sized {

/// Returns true if the state machine is between a sys_run_enter and sys_run_exit
fn is_inside_run(&self) -> bool;

/// Returns false if the combinator can't be completed yet.
fn sys_try_complete_combinator(
&mut self,
combinator: impl AsyncResultCombinator + fmt::Debug,
) -> VMResult<Option<AsyncResultHandle>>;
}

// HOW TO USE THIS API
Expand Down Expand Up @@ -300,5 +322,27 @@ pub trait VM: Sized {
// }
// io.close()

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsyncResultState {
Success,
Failure,
NotReady,
}

pub struct AsyncResultAccessTracker(AsyncResultAccessTrackerInner);

impl AsyncResultAccessTracker {
pub fn get_state(&mut self, handle: AsyncResultHandle) -> AsyncResultState {
self.0.get_state(handle)
}
}

pub trait AsyncResultCombinator {
fn try_complete(
&self,
tracker: &mut AsyncResultAccessTracker,
) -> Option<Vec<AsyncResultHandle>>;
}

#[cfg(test)]
mod tests;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// This file is @generated by prost-build.
/// Type: 0xFC00 + 2
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CombinatorEntryMessage {
#[prost(uint32, repeated, tag = "1")]
pub completed_entries_order: ::prost::alloc::vec::Vec<u32>,
/// Entry name
#[prost(string, tag = "12")]
pub name: ::prost::alloc::string::String,
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ impl ServiceProtocolVersion {
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
ServiceProtocolVersion::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED",
ServiceProtocolVersion::V1 => "V1",
ServiceProtocolVersion::V2 => "V2",
Self::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED",
Self::V1 => "V1",
Self::V2 => "V2",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
Expand Down
5 changes: 5 additions & 0 deletions src/service_protocol/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub enum MessageType {
GetPromiseEntry,
PeekPromiseEntry,
CompletePromiseEntry,
CombinatorEntry,
CustomEntry(u16),
}

Expand Down Expand Up @@ -75,6 +76,7 @@ impl MessageType {
MessageType::GetPromiseEntry => MessageKind::State,
MessageType::PeekPromiseEntry => MessageKind::State,
MessageType::CompletePromiseEntry => MessageKind::State,
MessageType::CombinatorEntry => MessageKind::Syscall,
MessageType::CustomEntry(_) => MessageKind::CustomEntry,
}
}
Expand Down Expand Up @@ -127,6 +129,7 @@ const BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C02;
const AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C03;
const COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C04;
const SIDE_EFFECT_ENTRY_MESSAGE_TYPE: u16 = 0x0C05;
const COMBINATOR_ENTRY_MESSAGE_TYPE: u16 = 0xFC02;

impl From<MessageType> for MessageTypeId {
fn from(mt: MessageType) -> Self {
Expand All @@ -153,6 +156,7 @@ impl From<MessageType> for MessageTypeId {
MessageType::GetPromiseEntry => GET_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::PeekPromiseEntry => PEEK_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::CompletePromiseEntry => COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE,
MessageType::CombinatorEntry => COMBINATOR_ENTRY_MESSAGE_TYPE,
MessageType::CustomEntry(id) => id,
}
}
Expand Down Expand Up @@ -189,6 +193,7 @@ impl TryFrom<MessageTypeId> for MessageType {
PEEK_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::PeekPromiseEntry),
COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompletePromiseEntry),
SIDE_EFFECT_ENTRY_MESSAGE_TYPE => Ok(MessageType::RunEntry),
COMBINATOR_ENTRY_MESSAGE_TYPE => Ok(MessageType::CombinatorEntry),
v if ((v & CUSTOM_MESSAGE_MASK) != 0) => Ok(MessageType::CustomEntry(v)),
v => Err(UnknownMessageType(v)),
}
Expand Down
19 changes: 19 additions & 0 deletions src/service_protocol/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl<M: CompletableEntryMessage> WriteableRestateMessage for M {
}

include!("./generated/dev.restate.service.protocol.rs");
include!("./generated/dev.restate.service.protocol.extensions.rs");

macro_rules! impl_message_traits {
($name:ident: core) => {
Expand Down Expand Up @@ -233,6 +234,24 @@ impl EntryMessageHeaderEq for RunEntryMessage {
}
}

impl_message_traits!(CombinatorEntry: message);
impl_message_traits!(CombinatorEntry: entry);
impl WriteableRestateMessage for CombinatorEntryMessage {
fn generate_header(&self, never_ack: bool) -> MessageHeader {
MessageHeader::new_ackable_entry_header(
MessageType::CombinatorEntry,
None,
if never_ack { Some(false) } else { Some(true) },
self.encoded_len() as u32,
)
}
}
impl EntryMessageHeaderEq for CombinatorEntryMessage {
fn header_eq(&self, _: &Self) -> bool {
true
}
}

// --- Completion extraction

impl TryFrom<get_state_entry_message::Result> for Value {
Expand Down
6 changes: 5 additions & 1 deletion src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ use test_log::test;

impl CoreVM {
fn mock_init(version: Version) -> CoreVM {
let vm = CoreVM::new(vec![("content-type".to_owned(), version.to_string())]).unwrap();
let vm = CoreVM::new(
vec![("content-type".to_owned(), version.to_string())],
VMOptions::default(),
)
.unwrap();

assert_that!(
vm.get_response_head().headers,
Expand Down
14 changes: 7 additions & 7 deletions src/tests/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn get_state_handler(vm: &mut CoreVM) {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice(
Expand Down Expand Up @@ -374,7 +374,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice(
Expand Down Expand Up @@ -619,7 +619,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_set(
Expand All @@ -644,7 +644,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_write_output(NonEmptyValue::Success(second_get_result))
Expand Down Expand Up @@ -799,7 +799,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_clear("STATE".to_owned()).unwrap();
Expand Down Expand Up @@ -958,7 +958,7 @@ mod eager {
vm.sys_end().unwrap();
return;
}
Value::StateKeys(_) => panic!("Unexpected variant"),
_ => panic!("Unexpected variant"),
};

vm.sys_state_clear_all().unwrap();
Expand Down Expand Up @@ -1232,9 +1232,9 @@ mod state_keys {
}

let output = match h1_result.unwrap().unwrap() {
Value::Void | Value::Success(_) => panic!("Unexpected variants"),
Value::Failure(f) => NonEmptyValue::Failure(f),
Value::StateKeys(keys) => NonEmptyValue::Success(Bytes::from(keys.join(","))),
_ => panic!("Unexpected variants"),
};

vm.sys_write_output(output).unwrap();
Expand Down
22 changes: 21 additions & 1 deletion src/vm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::service_protocol::messages::{
WriteableRestateMessage,
};
use crate::service_protocol::{Encoder, MessageType, Version};
use crate::{EntryRetryInfo, VMError, Value};
use crate::{AsyncResultHandle, AsyncResultState, EntryRetryInfo, VMError, VMOptions, Value};
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
use std::collections::{HashMap, VecDeque};
Expand Down Expand Up @@ -182,6 +182,24 @@ impl AsyncResultsState {
self.ready_results.insert(idx, value);
}
}

pub(crate) fn get_ready_results_state(&self) -> HashMap<AsyncResultHandle, AsyncResultState> {
self.ready_results
.iter()
.map(|(idx, val)| {
(
AsyncResultHandle(*idx),
match val {
Value::Void
| Value::Success(_)
| Value::StateKeys(_)
| Value::CombinatorResult(_) => AsyncResultState::Success,
Value::Failure(_) => AsyncResultState::Failure,
},
)
})
.collect()
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -286,6 +304,8 @@ pub(crate) struct Context {

// Used by the error handler to set ErrorMessage.next_retry_delay
pub(crate) next_retry_delay: Option<Duration>,

pub(crate) options: VMOptions,
}

impl Context {
Expand Down
8 changes: 7 additions & 1 deletion src/vm/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod codes {
pub const UNSUPPORTED_MEDIA_TYPE: InvocationErrorCode = InvocationErrorCode(415);
pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570);
pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571);
pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572);
}

// Const errors
Expand Down Expand Up @@ -118,6 +119,11 @@ pub const INPUT_CLOSED_WHILE_WAITING_ENTRIES: VMError = VMError::new_const(
"The input was closed while still waiting to receive all the `known_entries`",
);

pub const BAD_COMBINATOR_ENTRY: VMError = VMError::new_const(
codes::PROTOCOL_VIOLATION,
"The combinator cannot be replayed. This is most likely caused by non deterministic code.",
);

// Other errors

#[derive(Debug, Clone, thiserror::Error)]
Expand Down Expand Up @@ -223,7 +229,7 @@ impl WithInvocationErrorCode for DecodingError {
}
impl_error_code!(UnavailableEntryError, JOURNAL_MISMATCH);
impl_error_code!(UnexpectedStateError, PROTOCOL_VIOLATION);
impl_error_code!(AwaitingTwoAsyncResultError, INTERNAL);
impl_error_code!(AwaitingTwoAsyncResultError, AWAITING_TWO_ASYNC_RESULTS);
impl_error_code!(BadEagerStateKeyError, INTERNAL);
impl_error_code!(DecodeStateKeysProst, PROTOCOL_VIOLATION);
impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION);
Expand Down
Loading

0 comments on commit 3f6f6d8

Please sign in to comment.