Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return streaming body from wasm router #558

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 21 additions & 31 deletions codegen/src/next/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,12 @@ impl ToTokens for App {
let Self { endpoints } = self;

let app = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
async fn __app(request: http::Request<axum::body::BoxBody>,) -> axum::response::Response
{
use tower_service::Service;

let mut router = axum::Router::new()
#(#endpoints)*
.into_service();
#(#endpoints)*;

let response = router.call(request).await.unwrap();

Expand All @@ -268,12 +265,12 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
fd_5: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");
println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };
Expand All @@ -283,28 +280,22 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };
// file descriptor 4 for reading http body into wasm
let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut reader = std::io::BufReader::new(&mut body_read_stream);
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
reader.read_to_end(&mut body_buf).unwrap();

let request: http::Request<axum::body::Body> = wrapper
let body = axum::body::Body::from(body_buf);

let request = wrapper
.into_request_builder()
.body(body_buf.into())
.body(axum::body::boxed(body))
.unwrap();

println!("inner router received request: {:?}", &request);
let res = futures_executor::block_on(__app(request));
let res = handle_request(request);

let (parts, mut body) = res.into_parts();

Expand All @@ -314,12 +305,13 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream {
// write response parts
parts_fd.write_all(&response_parts).unwrap();

// file descriptor 5 for writing http body to host
let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) };

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
body_write_stream.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}
)
}
Expand Down Expand Up @@ -367,16 +359,14 @@ mod tests {

let actual = quote!(#app);
let expected = quote!(
async fn __app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
async fn __app(
request: http::Request<axum::body::BoxBody>,
) -> axum::response::Response {
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::post(goodbye))
.into_service();
.route("/goodbye", axum::routing::post(goodbye));

let response = router.call(request).await.unwrap();

Expand Down
1 change: 1 addition & 0 deletions runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ uuid = { workspace = true, features = ["v4"] }
wasi-common = "4.0.0"
wasmtime = "4.0.0"
wasmtime-wasi = "4.0.0"
futures = "0.3.25"

[dependencies.shuttle-common]
workspace = true
Expand Down
88 changes: 61 additions & 27 deletions runtime/src/axum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::sync::Mutex;

use async_trait::async_trait;
use cap_std::os::unix::net::UnixStream;
use futures::TryStreamExt;
use hyper::body::HttpBody;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response};
use shuttle_common::wasm::{RequestWrapper, ResponseWrapper};
Expand Down Expand Up @@ -197,18 +199,22 @@ impl RouterInner {
.unwrap();

let (mut parts_stream, parts_client) = UnixStream::pair().unwrap();
let (mut body_stream, body_client) = UnixStream::pair().unwrap();
let (mut body_write_stream, body_write_client) = UnixStream::pair().unwrap();
let (body_read_stream, body_read_client) = UnixStream::pair().unwrap();

let parts_client = WasiUnixStream::from_cap_std(parts_client);
let body_client = WasiUnixStream::from_cap_std(body_client);
let body_write_client = WasiUnixStream::from_cap_std(body_write_client);
let body_read_client = WasiUnixStream::from_cap_std(body_read_client);

store
.data_mut()
.insert_file(3, Box::new(parts_client), FileCaps::all());

store
.data_mut()
.insert_file(4, Box::new(body_client), FileCaps::all());
.insert_file(4, Box::new(body_write_client), FileCaps::all());
store
.data_mut()
.insert_file(5, Box::new(body_read_client), FileCaps::all());

let (parts, body) = req.into_parts();

Expand All @@ -218,22 +224,37 @@ impl RouterInner {
// write request parts
parts_stream.write_all(&request_rmp).unwrap();

// write body
body_stream
.write_all(hyper::body::to_bytes(body).await.unwrap().as_ref())
.unwrap();
// signal to the receiver that end of file has been reached
body_stream.write_all(&[0]).unwrap();
// To protect our server, reject requests with bodies larger than
// 64kbs of data.
let body_size = body.size_hint().upper().unwrap_or(u64::MAX);

if body_size > 1024 * 64 {
let response = Response::builder()
.status(hyper::http::StatusCode::PAYLOAD_TOO_LARGE)
.body(Body::empty())
.unwrap();

// Return early if body is too big
return Ok(response);
}

let body_bytes = hyper::body::to_bytes(body).await.unwrap();

// write body to axum
body_write_stream.write_all(body_bytes.as_ref()).unwrap();

// drop stream to signal EOF
drop(body_write_stream);

println!("calling inner Router");
self.linker
.get(&mut store, "axum", "__SHUTTLE_Axum_call")
.unwrap()
.into_func()
.unwrap()
.typed::<(RawFd, RawFd), ()>(&store)
.typed::<(RawFd, RawFd, RawFd), ()>(&store)
.unwrap()
.call(&mut store, (3, 4))
.call(&mut store, (3, 4, 5))
.unwrap();

// read response parts from host
Expand All @@ -242,22 +263,12 @@ impl RouterInner {
// deserialize response parts from rust messagepack
let wrapper: ResponseWrapper = rmps::from_read(reader).unwrap();

// read response body from wasm router
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_stream.read_exact(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
// read response body from wasm and stream it to our hyper server
let reader = BufReader::new(body_read_stream);
let stream = futures::stream::iter(reader.bytes()).try_chunks(2);
let body = hyper::Body::wrap_stream(stream);

let response: Response<Body> = wrapper
.into_response_builder()
.body(body_buf.into())
.unwrap();
let response: Response<Body> = wrapper.into_response_builder().body(body).unwrap();

Ok(response)
}
Expand Down Expand Up @@ -382,5 +393,28 @@ pub mod tests {
let res = inner.clone().handle_request(request).await.unwrap();

assert_eq!(res.status(), StatusCode::NOT_FOUND);

// POST /uppercase
let request: Request<Body> = Request::builder()
.method(Method::POST)
.version(Version::HTTP_11)
.header("test", HeaderValue::from_static("invalid"))
.uri("https://axum-wasm.example/uppercase")
.body("this should be uppercased".into())
.unwrap();

let res = inner.clone().handle_request(request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
&hyper::body::to_bytes(res.into_body())
.await
.unwrap()
.iter()
.cloned()
.collect::<Vec<u8>>()
.as_ref(),
b"THIS SHOULD BE UPPERCASED"
);
}
}
1 change: 1 addition & 0 deletions tmp/axum-wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-executor = "0.3.21"
http = "0.2.7"
tower-service = "0.3.1"
rmp-serde = { version = "1.1.1" }
futures = "0.3.25"

[dependencies.shuttle-common]
path = "../../common"
Expand Down
63 changes: 36 additions & 27 deletions tmp/axum-wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
pub fn handle_request<B>(req: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
use axum::{
body::BoxBody,
extract::BodyStream,
response::{IntoResponse, Response},
};
use futures::TryStreamExt;

pub fn handle_request(req: http::Request<BoxBody>) -> axum::response::Response {
futures_executor::block_on(app(req))
}

async fn app<B>(request: http::Request<B>) -> axum::response::Response
where
B: axum::body::HttpBody + Send + 'static,
{
async fn app(request: http::Request<BoxBody>) -> axum::response::Response {
use tower_service::Service;

let mut router = axum::Router::new()
.route("/hello", axum::routing::get(hello))
.route("/goodbye", axum::routing::get(goodbye));
.route("/goodbye", axum::routing::get(goodbye))
.route("/uppercase", axum::routing::post(uppercase));

let response = router.call(request).await.unwrap();

Expand All @@ -28,17 +30,29 @@ async fn goodbye() -> &'static str {
"Goodbye, World!"
}

// Map the bytes of the body stream to uppercase and return the stream directly.
async fn uppercase(body: BodyStream) -> impl IntoResponse {
let chunk_stream = body.map_ok(|chunk| {
chunk
.iter()
.map(|byte| byte.to_ascii_uppercase())
.collect::<Vec<u8>>()
});
Response::new(axum::body::StreamBody::new(chunk_stream))
}

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn __SHUTTLE_Axum_call(
fd_3: std::os::wasi::prelude::RawFd,
fd_4: std::os::wasi::prelude::RawFd,
fd_5: std::os::wasi::prelude::RawFd,
) {
use axum::body::HttpBody;
use std::io::{Read, Write};
use std::os::wasi::io::FromRawFd;

println!("inner handler awoken; interacting with fd={fd_3},{fd_4}");
println!("inner handler awoken; interacting with fd={fd_3},{fd_4},{fd_5}");

// file descriptor 3 for reading and writing http parts
let mut parts_fd = unsafe { std::fs::File::from_raw_fd(fd_3) };
Expand All @@ -48,24 +62,18 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// deserialize request parts from rust messagepack
let wrapper: shuttle_common::wasm::RequestWrapper = rmp_serde::from_read(reader).unwrap();

// file descriptor 4 for reading and writing http body
let mut body_fd = unsafe { std::fs::File::from_raw_fd(fd_4) };
// file descriptor 4 for reading http body into wasm
let mut body_read_stream = unsafe { std::fs::File::from_raw_fd(fd_4) };

// read body from host
let mut reader = std::io::BufReader::new(&mut body_read_stream);
let mut body_buf = Vec::new();
let mut c_buf: [u8; 1] = [0; 1];
loop {
body_fd.read(&mut c_buf).unwrap();
if c_buf[0] == 0 {
break;
} else {
body_buf.push(c_buf[0]);
}
}
reader.read_to_end(&mut body_buf).unwrap();

let request: http::Request<axum::body::Body> = wrapper
let body = axum::body::Body::from(body_buf);

let request = wrapper
.into_request_builder()
.body(body_buf.into())
.body(axum::body::boxed(body))
.unwrap();

println!("inner router received request: {:?}", &request);
Expand All @@ -79,10 +87,11 @@ pub extern "C" fn __SHUTTLE_Axum_call(
// write response parts
parts_fd.write_all(&response_parts).unwrap();

// file descriptor 5 for writing http body to host
let mut body_write_stream = unsafe { std::fs::File::from_raw_fd(fd_5) };

// write body if there is one
if let Some(body) = futures_executor::block_on(body.data()) {
body_fd.write_all(body.unwrap().as_ref()).unwrap();
body_write_stream.write_all(body.unwrap().as_ref()).unwrap();
}
// signal to the reader that end of file has been reached
body_fd.write(&[0]).unwrap();
}