Skip to content

Commit

Permalink
feat: return streaming body from wasm router (#558)
Browse files Browse the repository at this point in the history
* feat: stream body to and from router

* fix: reader to stream hack didn't work in wasm

I added a post endpoint and test, which proved that my hacky stream from BufReader didn't work on the wasm side

* refactor: test string

* feat: update codegen with axum-wasm changes

* refactor: clean up

* refactor: typo

* feat: guard against large request bodies
  • Loading branch information
oddgrd authored Jan 4, 2023
1 parent 1487ddf commit 9db7f90
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 85 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 21 additions & 31 deletions codegen/src/next/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,12 @@ impl ToTokens for App {
let Self { endpoints } = self;

let app = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
async fn __app(request: http::Request<axum::body::BoxBody>,) -> axum::response::Response
{
use tower_service::Service;

let mut router = axum::Router::new()
#(#endpoints)*
.into_service();
#(#endpoints)*;

let response = router.call(request).await.unwrap();

Expand All @@ -268,12 +265,12 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
fd_5: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");
println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };
Expand All @@ -283,28 +280,22 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };
// file descriptor 4 for reading http body into wasm
let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut reader = std::io::BufReader::new(&mut body_read_stream);
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
reader.read_to_end(&mut body_buf).unwrap();

let request: http::Request<axum::body::Body> = wrapper
let body = axum::body::Body::from(body_buf);

let request = wrapper
.into_request_builder()
.body(body_buf.into())
.body(axum::body::boxed(body))
.unwrap();

println!("inner router received request: {:?}", &request);
let res = futures_executor::block_on(__app(request));
let res = handle_request(request);

let (parts, mut body) = res.into_parts();

Expand All @@ -314,12 +305,13 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
// write response parts
parts_fd.write_all(&response_parts).unwrap();

// file descriptor 5 for writing http body to host
let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) };

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
body_write_stream.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}
)
}
Expand Down Expand Up @@ -367,16 +359,14 @@ mod tests {

let actual = quote!(#app);
let expected = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
async fn __app(
request: http::Request<axum::body::BoxBody>,
) -> axum::response::Response {
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::post(goodbye))
.into_service();
.route("/goodbye", axum::routing::post(goodbye));

let response = router.call(request).await.unwrap();

Expand Down
1 change: 1 addition & 0 deletions runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ uuid = { workspace = true, features = ["v4"] }
wasi-common = "4.0.0"
wasmtime = "4.0.0"
wasmtime-wasi = "4.0.0"
futures = "0.3.25"

[dependencies.shuttle-common]
workspace = true
Expand Down
88 changes: 61 additions & 27 deletions runtime/src/axum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::sync::Mutex;

use async_trait::async_trait;
use cap_std::os::unix::net::UnixStream;
use futures::TryStreamExt;
use hyper::body::HttpBody;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response};
use shuttle_common::wasm::{RequestWrapper, ResponseWrapper};
Expand Down Expand Up @@ -197,18 +199,22 @@ impl RouterInner {
.unwrap();

let (mut parts_stream, parts_client) = UnixStream::pair().unwrap();
let (mut body_stream, body_client) = UnixStream::pair().unwrap();
let (mut body_write_stream, body_write_client) = UnixStream::pair().unwrap();
let (body_read_stream, body_read_client) = UnixStream::pair().unwrap();

let parts_client = WasiUnixStream::from_cap_std(parts_client);
let body_client = WasiUnixStream::from_cap_std(body_client);
let body_write_client = WasiUnixStream::from_cap_std(body_write_client);
let body_read_client = WasiUnixStream::from_cap_std(body_read_client);

store
.data_mut()
.insert_file(3, Box::new(parts_client), FileCaps::all());

store
.data_mut()
.insert_file(4, Box::new(body_client), FileCaps::all());
.insert_file(4, Box::new(body_write_client), FileCaps::all());
store
.data_mut()
.insert_file(5, Box::new(body_read_client), FileCaps::all());

let (parts, body) = req.into_parts();

Expand All @@ -218,22 +224,37 @@ impl RouterInner {
// write request parts
parts_stream.write_all(&request_rmp).unwrap();

// write body
body_stream
.write_all(hyper::body::to_bytes(body).await.unwrap().as_ref())
.unwrap();
// signal to the receiver that end of file has been reached
body_stream.write_all(&[0]).unwrap();
// To protect our server, reject requests with bodies larger than
// 64kbs of data.
let body_size = body.size_hint().upper().unwrap_or(u64::MAX);

if body_size > 1024 * 64 {
let response = Response::builder()
.status(hyper::http::StatusCode::PAYLOAD_TOO_LARGE)
.body(Body::empty())
.unwrap();

// Return early if body is too big
return Ok(response);
}

let body_bytes = hyper::body::to_bytes(body).await.unwrap();

// write body to axum
body_write_stream.write_all(body_bytes.as_ref()).unwrap();

// drop stream to signal EOF
drop(body_write_stream);

println!("calling inner Router");
self.linker
.get(&mut store, "axum", "__SHUTTLE_Axum_call")
.unwrap()
.into_func()
.unwrap()
.typed::<(RawFd, RawFd), ()>(&store)
.typed::<(RawFd, RawFd, RawFd), ()>(&store)
.unwrap()
.call(&mut store, (3, 4))
.call(&mut store, (3, 4, 5))
.unwrap();

// read response parts from host
Expand All @@ -242,22 +263,12 @@ impl RouterInner {
// deserialize response parts from rust messagepack
let wrapper: ResponseWrapper = rmps::from_read(reader).unwrap();

// read response body from wasm router
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_stream.read_exact(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
// read response body from wasm and stream it to our hyper server
let reader = BufReader::new(body_read_stream);
let stream = futures::stream::iter(reader.bytes()).try_chunks(2);
let body = hyper::Body::wrap_stream(stream);

let response: Response<Body> = wrapper
.into_response_builder()
.body(body_buf.into())
.unwrap();
let response: Response<Body> = wrapper.into_response_builder().body(body).unwrap();

Ok(response)
}
Expand Down Expand Up @@ -382,5 +393,28 @@ pub mod tests {
let res = inner.clone().handle_request(request).await.unwrap();

assert_eq!(res.status(), StatusCode::NOT_FOUND);

// POST /uppercase
let request: Request<Body> = Request::builder()
.method(Method::POST)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("invalid"))
.uri("https://axum-wasm.example/uppercase")
.body("this should be uppercased".into())
.unwrap();

let res = inner.clone().handle_request(request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
&hyper::body::to_bytes(res.into_body())
.await
.unwrap()
.iter()
.cloned()
.collect::<Vec<u8>>()
.as_ref(),
b"THIS SHOULD BE UPPERCASED"
);
}
}
1 change: 1 addition & 0 deletions tmp/axum-wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-executor = "0.3.21"
http = "0.2.7"
tower-service = "0.3.1"
rmp-serde = { version = "1.1.1" }
futures = "0.3.25"

[dependencies.shuttle-common]
path = "../../common"
Expand Down
63 changes: 36 additions & 27 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
pub fn handle_request<B>(req: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use axum::{
body::BoxBody,
extract::BodyStream,
response::{IntoResponse, Response},
};
use futures::TryStreamExt;

pub fn handle_request(req: http::Request<BoxBody>) -> axum::response::Response {
futures_executor::block_on(app(req))
}

async fn app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
async fn app(request: http::Request<BoxBody>) -> axum::response::Response {
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::get(goodbye));
.route("/goodbye", axum::routing::get(goodbye))
.route("/uppercase", axum::routing::post(uppercase));

let response = router.call(request).await.unwrap();

Expand All @@ -28,17 +30,29 @@ async fn goodbye() -> &'static str {
"Goodbye, World!"
}

// Map the bytes of the body stream to uppercase and return the stream directly.
async fn uppercase(body: BodyStream) -> impl IntoResponse {
let chunk_stream = body.map_ok(|chunk| {
chunk
.iter()
.map(|byte| byte.to_ascii_uppercase())
.collect::<Vec<u8>>()
});
Response::new(axum::body::StreamBody::new(chunk_stream))
}

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
fd_5: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");
println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };
Expand All @@ -48,24 +62,18 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };
// file descriptor 4 for reading http body into wasm
let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut reader = std::io::BufReader::new(&mut body_read_stream);
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
reader.read_to_end(&mut body_buf).unwrap();

let request: http::Request<axum::body::Body> = wrapper
let body = axum::body::Body::from(body_buf);

let request = wrapper
.into_request_builder()
.body(body_buf.into())
.body(axum::body::boxed(body))
.unwrap();

println!("inner router received request: {:?}", &request);
Expand All @@ -79,10 +87,11 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// write response parts
parts_fd.write_all(&response_parts).unwrap();

// file descriptor 5 for writing http body to host
let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) };

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
body_write_stream.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}

0 comments on commit 9db7f90

Please sign in to comment.