diff --git a/Cargo.lock b/Cargo.lock index 85aeaab15..bc7113212 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6347,6 +6347,7 @@ dependencies = [ "async-trait", "cap-std", "clap 4.0.27", + "futures", "hyper", "rmp-serde", "shuttle-common", diff --git a/codegen/src/next/mod.rs b/codegen/src/next/mod.rs index f0983bc18..520b9ee95 100644 --- a/codegen/src/next/mod.rs +++ b/codegen/src/next/mod.rs @@ -238,15 +238,12 @@ impl ToTokens for App { let Self { endpoints } = self; let app = quote!( - async fn __app(request: http::Request) -> axum::response::Response - where - B: axum::body::HttpBody + Send + 'static, + async fn __app(request: http::Request,) -> axum::response::Response { use tower_service::Service; let mut router = axum::Router::new() - #(#endpoints)* - .into_service(); + #(#endpoints)*; let response = router.call(request).await.unwrap(); @@ -268,12 +265,12 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { pub extern "C" fn __SHUTTLE_Axum_call( fd_3: std::os::wasi::prelude::RawFd, fd_4: std::os::wasi::prelude::RawFd, + fd_5: std::os::wasi::prelude::RawFd, ) { use axum::body::HttpBody; use std::io::{Read, Write}; use std::os::wasi::io::FromRawFd; - - println!("inner handler awoken; interacting with fd={fd_3},{fd_4}"); + println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}"); // file descriptor 3 for reading and writing http parts let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) }; @@ -283,28 +280,22 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { // deserialize request parts from rust messagepack let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap(); - // file descriptor 4 for reading and writing http body - let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) }; + // file descriptor 4 for reading http body into wasm + let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) }; - // read body from host + let mut reader = std::io::BufReader::new(&mut body_read_stream); let mut body_buf = Vec::new(); - let mut c_buf: [u8; 1] = [0; 1]; - loop { - body_fd.read(&mut c_buf).unwrap(); - if c_buf[0] == 0 { - break; - } else { - body_buf.push(c_buf[0]); - } - } + reader.read_to_end(&mut body_buf).unwrap(); - let request: http::Request = wrapper + let body = axum::body::Body::from(body_buf); + + let request = wrapper .into_request_builder() - .body(body_buf.into()) + .body(axum::body::boxed(body)) .unwrap(); println!("inner router received request: {:?}", &request); - let res = futures_executor::block_on(__app(request)); + let res = handle_request(request); let (parts, mut body) = res.into_parts(); @@ -314,12 +305,13 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { // write response parts parts_fd.write_all(&response_parts).unwrap(); + // file descriptor 5 for writing http body to host + let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) }; + // write body if there is one if let Some(body) = futures_executor::block_on(body.data()) { - body_fd.write_all(body.unwrap().as_ref()).unwrap(); + body_write_stream.write_all(body.unwrap().as_ref()).unwrap(); } - // signal to the reader that end of file has been reached - body_fd.write(&[0]).unwrap(); } ) } @@ -367,16 +359,14 @@ mod tests { let actual = quote!(#app); let expected = quote!( - async fn __app(request: http::Request) -> axum::response::Response - where - B: axum::body::HttpBody + Send + 'static, - { + async fn __app( + request: http::Request, + ) -> axum::response::Response { use tower_service::Service; let mut router = axum::Router::new() .route("/hello", axum::routing::get(hello)) - .route("/goodbye", axum::routing::post(goodbye)) - .into_service(); + .route("/goodbye", axum::routing::post(goodbye)); let response = router.call(request).await.unwrap(); diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 6840847f8..89fbf12b4 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -23,6 +23,7 @@ uuid = { workspace = true, features = ["v4"] } wasi-common = "4.0.0" wasmtime = "4.0.0" wasmtime-wasi = "4.0.0" +futures = "0.3.25" [dependencies.shuttle-common] workspace = true diff --git a/runtime/src/axum/mod.rs b/runtime/src/axum/mod.rs index 5bd8fcfe6..e3dada547 100644 --- a/runtime/src/axum/mod.rs +++ b/runtime/src/axum/mod.rs @@ -9,6 +9,8 @@ use std::sync::Mutex; use async_trait::async_trait; use cap_std::os::unix::net::UnixStream; +use futures::TryStreamExt; +use hyper::body::HttpBody; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response}; use shuttle_common::wasm::{RequestWrapper, ResponseWrapper}; @@ -197,18 +199,22 @@ impl RouterInner { .unwrap(); let (mut parts_stream, parts_client) = UnixStream::pair().unwrap(); - let (mut body_stream, body_client) = UnixStream::pair().unwrap(); + let (mut body_write_stream, body_write_client) = UnixStream::pair().unwrap(); + let (body_read_stream, body_read_client) = UnixStream::pair().unwrap(); let parts_client = WasiUnixStream::from_cap_std(parts_client); - let body_client = WasiUnixStream::from_cap_std(body_client); + let body_write_client = WasiUnixStream::from_cap_std(body_write_client); + let body_read_client = WasiUnixStream::from_cap_std(body_read_client); store .data_mut() .insert_file(3, Box::new(parts_client), FileCaps::all()); - store .data_mut() - .insert_file(4, Box::new(body_client), FileCaps::all()); + .insert_file(4, Box::new(body_write_client), FileCaps::all()); + store + .data_mut() + .insert_file(5, Box::new(body_read_client), FileCaps::all()); let (parts, body) = req.into_parts(); @@ -218,12 +224,27 @@ impl RouterInner { // write request parts parts_stream.write_all(&request_rmp).unwrap(); - // write body - body_stream - .write_all(hyper::body::to_bytes(body).await.unwrap().as_ref()) - .unwrap(); - // signal to the receiver that end of file has been reached - body_stream.write_all(&[0]).unwrap(); + // To protect our server, reject requests with bodies larger than + // 64kbs of data. + let body_size = body.size_hint().upper().unwrap_or(u64::MAX); + + if body_size > 1024 * 64 { + let response = Response::builder() + .status(hyper::http::StatusCode::PAYLOAD_TOO_LARGE) + .body(Body::empty()) + .unwrap(); + + // Return early if body is too big + return Ok(response); + } + + let body_bytes = hyper::body::to_bytes(body).await.unwrap(); + + // write body to axum + body_write_stream.write_all(body_bytes.as_ref()).unwrap(); + + // drop stream to signal EOF + drop(body_write_stream); println!("calling inner Router"); self.linker @@ -231,9 +252,9 @@ impl RouterInner { .unwrap() .into_func() .unwrap() - .typed::<(RawFd, RawFd), ()>(&store) + .typed::<(RawFd, RawFd, RawFd), ()>(&store) .unwrap() - .call(&mut store, (3, 4)) + .call(&mut store, (3, 4, 5)) .unwrap(); // read response parts from host @@ -242,22 +263,12 @@ impl RouterInner { // deserialize response parts from rust messagepack let wrapper: ResponseWrapper = rmps::from_read(reader).unwrap(); - // read response body from wasm router - let mut body_buf = Vec::new(); - let mut c_buf: [u8; 1] = [0; 1]; - loop { - body_stream.read_exact(&mut c_buf).unwrap(); - if c_buf[0] == 0 { - break; - } else { - body_buf.push(c_buf[0]); - } - } + // read response body from wasm and stream it to our hyper server + let reader = BufReader::new(body_read_stream); + let stream = futures::stream::iter(reader.bytes()).try_chunks(2); + let body = hyper::Body::wrap_stream(stream); - let response: Response = wrapper - .into_response_builder() - .body(body_buf.into()) - .unwrap(); + let response: Response = wrapper.into_response_builder().body(body).unwrap(); Ok(response) } @@ -382,5 +393,28 @@ pub mod tests { let res = inner.clone().handle_request(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // POST /uppercase + let request: Request = Request::builder() + .method(Method::POST) + .version(Version::HTTP_11) + .header("test", HeaderValue::from_static("invalid")) + .uri("https://axum-wasm.example/uppercase") + .body("this should be uppercased".into()) + .unwrap(); + + let res = inner.clone().handle_request(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + &hyper::body::to_bytes(res.into_body()) + .await + .unwrap() + .iter() + .cloned() + .collect::>() + .as_ref(), + b"THIS SHOULD BE UPPERCASED" + ); } } diff --git a/tmp/axum-wasm/Cargo.toml b/tmp/axum-wasm/Cargo.toml index 1b01b15a9..740001a64 100644 --- a/tmp/axum-wasm/Cargo.toml +++ b/tmp/axum-wasm/Cargo.toml @@ -14,6 +14,7 @@ futures-executor = "0.3.21" http = "0.2.7" tower-service = "0.3.1" rmp-serde = { version = "1.1.1" } +futures = "0.3.25" [dependencies.shuttle-common] path = "../../common" diff --git a/tmp/axum-wasm/src/lib.rs b/tmp/axum-wasm/src/lib.rs index 23f49b0c5..e0801ac9a 100644 --- a/tmp/axum-wasm/src/lib.rs +++ b/tmp/axum-wasm/src/lib.rs @@ -1,19 +1,21 @@ -pub fn handle_request(req: http::Request) -> axum::response::Response -where - B: axum::body::HttpBody + Send + 'static, -{ +use axum::{ + body::BoxBody, + extract::BodyStream, + response::{IntoResponse, Response}, +}; +use futures::TryStreamExt; + +pub fn handle_request(req: http::Request) -> axum::response::Response { futures_executor::block_on(app(req)) } -async fn app(request: http::Request) -> axum::response::Response -where - B: axum::body::HttpBody + Send + 'static, -{ +async fn app(request: http::Request) -> axum::response::Response { use tower_service::Service; let mut router = axum::Router::new() .route("/hello", axum::routing::get(hello)) - .route("/goodbye", axum::routing::get(goodbye)); + .route("/goodbye", axum::routing::get(goodbye)) + .route("/uppercase", axum::routing::post(uppercase)); let response = router.call(request).await.unwrap(); @@ -28,17 +30,29 @@ async fn goodbye() -> &'static str { "Goodbye, World!" } +// Map the bytes of the body stream to uppercase and return the stream directly. +async fn uppercase(body: BodyStream) -> impl IntoResponse { + let chunk_stream = body.map_ok(|chunk| { + chunk + .iter() + .map(|byte| byte.to_ascii_uppercase()) + .collect::>() + }); + Response::new(axum::body::StreamBody::new(chunk_stream)) +} + #[no_mangle] #[allow(non_snake_case)] pub extern "C" fn __SHUTTLE_Axum_call( fd_3: std::os::wasi::prelude::RawFd, fd_4: std::os::wasi::prelude::RawFd, + fd_5: std::os::wasi::prelude::RawFd, ) { use axum::body::HttpBody; use std::io::{Read, Write}; use std::os::wasi::io::FromRawFd; - println!("inner handler awoken; interacting with fd={fd_3},{fd_4}"); + println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}"); // file descriptor 3 for reading and writing http parts let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) }; @@ -48,24 +62,18 @@ pub extern "C" fn __SHUTTLE_Axum_call( // deserialize request parts from rust messagepack let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap(); - // file descriptor 4 for reading and writing http body - let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) }; + // file descriptor 4 for reading http body into wasm + let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) }; - // read body from host + let mut reader = std::io::BufReader::new(&mut body_read_stream); let mut body_buf = Vec::new(); - let mut c_buf: [u8; 1] = [0; 1]; - loop { - body_fd.read(&mut c_buf).unwrap(); - if c_buf[0] == 0 { - break; - } else { - body_buf.push(c_buf[0]); - } - } + reader.read_to_end(&mut body_buf).unwrap(); - let request: http::Request = wrapper + let body = axum::body::Body::from(body_buf); + + let request = wrapper .into_request_builder() - .body(body_buf.into()) + .body(axum::body::boxed(body)) .unwrap(); println!("inner router received request: {:?}", &request); @@ -79,10 +87,11 @@ pub extern "C" fn __SHUTTLE_Axum_call( // write response parts parts_fd.write_all(&response_parts).unwrap(); + // file descriptor 5 for writing http body to host + let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) }; + // write body if there is one if let Some(body) = futures_executor::block_on(body.data()) { - body_fd.write_all(body.unwrap().as_ref()).unwrap(); + body_write_stream.write_all(body.unwrap().as_ref()).unwrap(); } - // signal to the reader that end of file has been reached - body_fd.write(&[0]).unwrap(); }