Skip to content

Commit

Permalink
Merge pull request #209 from NethermindEth/feat/horizontal/poll
Browse files Browse the repository at this point in the history
Polling status of execution
  • Loading branch information
taco-paco authored Oct 18, 2024
2 parents b87a0dd + 95da4d4 commit 56906e3
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 9 deletions.
7 changes: 6 additions & 1 deletion crates/lambdas/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ path = "src/compile.rs"
[[bin]]
name = "verify"
version = "0.0.1"
path = "src/verify.rs"
path = "src/verify.rs"

[[bin]]
name = "poll"
version = "0.0.1"
path = "src/poll.rs"
7 changes: 5 additions & 2 deletions crates/lambdas/src/common/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use lambda_http::{Request, RequestPayloadExt, Response};
use serde::de::DeserializeOwned;
use serde::Deserialize;

use crate::common::errors::{Error, Error::HttpError};

const EMPTY_PAYLOAD_ERROR: &str = "Request payload is empty";

pub fn extract_request<T: DeserializeOwned>(request: Request) -> Result<T, Error> {
pub fn extract_request<T>(request: &Request) -> Result<T, Error>
where
T: for<'de> Deserialize<'de>,
{
return match request.payload::<T>() {
Ok(Some(val)) => Ok(val),
Ok(None) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/lambdas/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ async fn process_request(
s3_client: &aws_sdk_s3::Client,
bucket_name: &str,
) -> Result<LambdaResponse<String>, Error> {
let request = extract_request::<CompilationRequest>(request)?;
let request = extract_request::<CompilationRequest>(&request)?;

let objects = s3_client
.list_objects_v2()
Expand Down
2 changes: 1 addition & 1 deletion crates/lambdas/src/generate_presigned_urls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async fn process_request(
bucket_name: &str,
s3_client: &aws_sdk_s3::Client,
) -> Result<LambdaResponse<String>, Error> {
let request = extract_request::<Request>(request)?;
let request = extract_request::<Request>(&request)?;
if request.files.len() > MAX_FILES {
warn!("MAX_FILES limit exceeded");
let response = LambdaResponse::builder()
Expand Down
128 changes: 128 additions & 0 deletions crates/lambdas/src/poll.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use aws_config::BehaviorVersion;
use aws_sdk_dynamodb::types::AttributeValue;
use lambda_http::http::StatusCode;
use lambda_http::{
run, service_fn, Error as LambdaError, Request as LambdaRequest, Response as LambdaResponse,
};
use serde::Deserialize;
use tracing::error;
use types::item::errors::ItemError;
use types::item::task_result::TaskResult;
use types::item::{Item, Status};
use uuid::Uuid;

const TABLE_NAME_DEFAULT: &str = "zksync-table";
const NO_SUCH_ITEM: &str = "No such item";

mod common;
use crate::common::{errors::Error, utils::extract_request};

#[derive(Deserialize)]
struct PollRequest {
pub id: Uuid,
}

#[tracing::instrument(skip(dynamo_client, table_name))]
async fn process_request(
request: LambdaRequest,
dynamo_client: &aws_sdk_dynamodb::Client,
table_name: &str,
) -> Result<LambdaResponse<String>, Error> {
let request = extract_request::<PollRequest>(&request)?;
let output = dynamo_client
.get_item()
.table_name(table_name)
.key(
Item::primary_key_name(),
AttributeValue::S(request.id.to_string()),
)
.send()
.await
.map_err(Box::new)?;

let raw_item = output.item.ok_or_else(|| {
let response = LambdaResponse::builder()
.status(StatusCode::NOT_FOUND)
.header("content-type", "text/html")
.body(NO_SUCH_ITEM.to_string())
.map_err(Error::from);

match response {
Ok(response) => Error::HttpError(response),
Err(err) => err,
}
})?;

let item: Item = raw_item.try_into().map_err(|err: ItemError| {
error!("Failed to deserialize item. id: {}", request.id);
let response = LambdaResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("content-type", "text/html")
.body(err.to_string())
.map_err(Error::from);

match response {
Ok(response) => Error::HttpError(response),
Err(err) => err,
}
})?;

let task_result = if let Status::Done(task_result) = item.status {
Ok(task_result)
} else {
let response = LambdaResponse::builder()
.status(StatusCode::BAD_REQUEST)
.header("content-type", "text/html")
.body("Task isn't ready".to_owned())
.map_err(Error::from)?;

Err(Error::HttpError(response))
}?;

match task_result {
TaskResult::Success(value) => {
let response = LambdaResponse::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(serde_json::to_string(&value)?)?;

Ok(response)
}
TaskResult::Failure(value) => {
let status_code: StatusCode = value.error_type.into();
let response = LambdaResponse::builder()
.status(status_code)
.header("content-type", "text/html")
.body(value.message)
.map_err(Box::new)?;

Err(Error::HttpError(response))
}
}
}

#[tokio::main]
async fn main() -> Result<(), LambdaError> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_ansi(false)
.without_time() // CloudWatch will add the ingestion time
.with_target(false)
.init();

let table_name = std::env::var("TABLE_NAME").unwrap_or(TABLE_NAME_DEFAULT.into());

let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
let dynamo_client = aws_sdk_dynamodb::Client::new(&config);

run(service_fn(|request: LambdaRequest| async {
let result = process_request(request, &dynamo_client, &table_name).await;

match result {
Ok(val) => Ok(val),
Err(Error::HttpError(val)) => Ok(val),
Err(Error::LambdaError(err)) => Err(err),
}
}))
.await
}
2 changes: 1 addition & 1 deletion crates/lambdas/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async fn process_request(
s3_client: &aws_sdk_s3::Client,
bucket_name: &str,
) -> Result<LambdaResponse<String>, Error> {
let request = extract_request::<VerificationRequest>(request)?;
let request = extract_request::<VerificationRequest>(&request)?;

let objects = s3_client
.list_objects_v2()
Expand Down
2 changes: 2 additions & 0 deletions crates/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ serde_json.workspace = true
thiserror.workspace = true
uuid.workspace = true

http = "1.1.0"

[dev-dependencies]
serde_json.workspace = true
16 changes: 14 additions & 2 deletions crates/types/src/item/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl ItemError {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ServerError {
UnsupportedCompilerVersion,
Expand All @@ -37,7 +37,7 @@ impl Into<&'static str> for ServerError {
ServerError::CompilationError => "CompilationError",
ServerError::InternalError => "InternalError",
ServerError::UnknownNetworkError => "UnknownNetworkError",
ServerError::VerificationError => "VerificationError"
ServerError::VerificationError => "VerificationError",
}
}
}
Expand All @@ -55,3 +55,15 @@ impl TryFrom<&str> for ServerError {
}
}
}

impl Into<http::StatusCode> for ServerError {
fn into(self) -> http::StatusCode {
match self {
Self::UnsupportedCompilerVersion
| Self::CompilationError
| Self::UnknownNetworkError
| Self::VerificationError => http::StatusCode::BAD_REQUEST,
Self::InternalError => http::StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
4 changes: 3 additions & 1 deletion crates/types/src/item/task_result.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use aws_sdk_dynamodb::types::AttributeValue;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Formatter;

Expand Down Expand Up @@ -77,8 +78,9 @@ impl TryFrom<&AttributeMap> for TaskResult {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(untagged)]
pub enum TaskSuccess {
Compile { presigned_urls: Vec<String> },
Verify { message: String },
Expand Down

0 comments on commit 56906e3

Please sign in to comment.