Skip to content

Commit

Permalink
Merge pull request #705 from andreiltd/graceful-shutdown
Browse files Browse the repository at this point in the history
Add graceful server shutdown
  • Loading branch information
jprendes authored Oct 31, 2024
2 parents 5ece90d + 2a748b4 commit 63fba0d
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 29 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ serial_test = "2"
tracing = "0.1"
hyper = "1.5.0"
tokio = { version = "1.40.0", default-features = false }
tokio-util = { version = "0.7", default-features = false }

# wasmtime
wasmtime = { version = "25.0.2", features = ["async"] }
Expand Down
8 changes: 7 additions & 1 deletion crates/containerd-shim-wasm/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::time::Duration;

use anyhow::{bail, Result};
pub use containerd_shim_wasm_test_modules as modules;
use libc::SIGINT;
use libc::{SIGINT, SIGTERM};
use oci_spec::runtime::{
get_default_namespaces, LinuxBuilder, LinuxNamespace, LinuxNamespaceType, ProcessBuilder,
RootBuilder, SpecBuilder,
Expand Down Expand Up @@ -223,6 +223,12 @@ where
Ok(self)
}

pub fn terminate(&self) -> Result<&Self> {
log::info!("sending SIGTERM");
self.instance.kill(SIGTERM as u32)?;
Ok(self)
}

pub fn wait(&self, timeout: Duration) -> Result<(u32, String, String)> {
let dir = self.tempdir.path();

Expand Down
3 changes: 2 additions & 1 deletion crates/containerd-shim-wasmtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ edition.workspace = true
[dependencies]
anyhow = { workspace = true }
containerd-shim-wasm = { workspace = true, features = ["opentelemetry"] }
libc = { workspace = true }
log = { workspace = true }
hyper = { workspace = true }
tokio = { workspace = true, features = ["signal", "macros"] }
libc = { workspace = true }
tokio-util = { workspace = true, features = ["rt"] }

wasmtime = { workspace = true }
wasmtime-wasi = { workspace = true }
Expand Down
72 changes: 67 additions & 5 deletions crates/containerd-shim-wasmtime/src/http_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{bail, Result};
use containerd_shim_wasm::container::RuntimeContext;
use hyper::server::conn::http1;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use wasmtime::component::ResourceTable;
use wasmtime::Store;
use wasmtime_wasi_http::bindings::http::types::Scheme;
Expand All @@ -26,9 +30,44 @@ const DEFAULT_BACKLOG: u32 = 100;

type Request = hyper::Request<hyper::body::Incoming>;

fn is_connection_error(e: &std::io::Error) -> bool {
matches!(
e.kind(),
std::io::ErrorKind::ConnectionRefused
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::ConnectionReset
)
}

// [From axum](https://github.com/tokio-rs/axum/blob/280d16a61059f57230819a79b15aa12a263e8cca/axum/src/serve.rs#L425)
async fn tcp_accept(listener: &TcpListener) -> Option<TcpStream> {
match listener.accept().await {
Ok((stream, _addr)) => Some(stream),
Err(e) => {
if is_connection_error(&e) {
return None;
}

// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
log::error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
None
}
}
}

pub(crate) async fn serve_conn(
ctx: &impl RuntimeContext,
instance: ProxyPre<WasiPreview2Ctx>,
cancel: CancellationToken,
) -> Result<()> {
let mut env = envs_from_ctx(ctx).into_iter().collect::<HashMap<_, _>>();

Expand Down Expand Up @@ -59,20 +98,32 @@ pub(crate) async fn serve_conn(
socket.bind(addr)?;

let listener = socket.listen(backlog)?;
let tracker = TaskTracker::new();

log::info!("Serving HTTP on http://{}/", listener.local_addr()?);

let env = env.into_iter().collect();
let handler = Arc::new(ProxyHandler::new(instance, env));
let handler = Arc::new(ProxyHandler::new(instance, env, tracker.clone()));

loop {
let (stream, _) = listener.accept().await?;
let stream = tokio::select! {
conn = tcp_accept(&listener) => {
match conn {
Some(conn) => conn,
None => continue,
}
}
_ = cancel.cancelled() => {
break;
}
};

log::debug!("New connection");

let stream = TokioIo::new(stream);
let h = handler.clone();

tokio::spawn(async {
tracker.spawn(async {
if let Err(e) = http1::Builder::new()
.keep_alive(true)
.serve_connection(
Expand All @@ -85,19 +136,30 @@ pub(crate) async fn serve_conn(
}
});
}

tracker.close();
tracker.wait().await;

Ok(())
}

struct ProxyHandler {
instance_pre: ProxyPre<WasiPreview2Ctx>,
next_id: AtomicU64,
env: Vec<(String, String)>,
tracker: TaskTracker,
}

impl ProxyHandler {
fn new(instance_pre: ProxyPre<WasiPreview2Ctx>, env: Vec<(String, String)>) -> Self {
fn new(
instance_pre: ProxyPre<WasiPreview2Ctx>,
env: Vec<(String, String)>,
tracker: TaskTracker,
) -> Self {
ProxyHandler {
instance_pre,
env,
tracker,
next_id: AtomicU64::from(0),
}
}
Expand Down Expand Up @@ -138,7 +200,7 @@ impl ProxyHandler {
let out = store.data_mut().new_response_outparam(sender)?;
let proxy = self.instance_pre.instantiate_async(&mut store).await?;

let task = tokio::spawn(async move {
let task = self.tracker.spawn(async move {
if let Err(e) = proxy
.wasi_http_incoming_handler()
.call_handle(store, req, out)
Expand Down
36 changes: 30 additions & 6 deletions crates/containerd-shim-wasmtime/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use containerd_shim_wasm::container::{
Engine, Entrypoint, Instance, RuntimeContext, Stdio, WasmBinaryType,
};
use containerd_shim_wasm::sandbox::WasmLayer;
use tokio_util::sync::CancellationToken;
use wasi_preview1::WasiP1Ctx;
use wasi_preview2::bindings::Command;
use wasmtime::component::types::ComponentItem;
Expand Down Expand Up @@ -55,6 +56,7 @@ impl<'a> ComponentTarget<'a> {
#[derive(Clone)]
pub struct WasmtimeEngine<T: WasiConfig> {
engine: wasmtime::Engine,
cancel: CancellationToken,
config_type: PhantomData<T>,
}

Expand All @@ -81,6 +83,7 @@ impl<T: WasiConfig> Default for WasmtimeEngine<T> {
engine: wasmtime::Engine::new(&config)
.context("failed to create wasmtime engine")
.unwrap(),
cancel: CancellationToken::new(),
config_type: PhantomData,
}
}
Expand Down Expand Up @@ -250,7 +253,8 @@ where
let instance = ProxyPre::new(pre)?;

log::info!("starting HTTP server");
serve_conn(ctx, instance).await
let cancel = self.cancel.clone();
serve_conn(ctx, instance, cancel).await
}
ComponentTarget::Command => {
let wasi_ctx = WasiPreview2Ctx::new(ctx)?;
Expand Down Expand Up @@ -303,12 +307,32 @@ where

wasmtime_wasi::runtime::in_tokio(async move {
tokio::select! {
status = self.execute_component_async(ctx, component, func, stdio) => { status }
sig = wait_for_signal() => { sig }
status = self.execute_component_async(ctx, component, func, stdio) => {
status
}
status = self.handle_signals() => {
status
}
}
})
}

async fn handle_signals(&self) -> Result<i32> {
match wait_for_signal().await? {
libc::SIGINT => {
// Request graceful shutdown;
self.cancel.cancel();
}
sig => {
// On other signal, terminate the process without waiting for spawned tasks to finish.
return Ok(128 + sig);
}
}

// On a second SIGINT, terminate the process as well
wait_for_signal().await
}

fn execute(
&self,
ctx: &impl RuntimeContext,
Expand Down Expand Up @@ -397,9 +421,9 @@ async fn wait_for_signal() -> Result<i32> {
let mut sigterm = signal(SignalKind::terminate())?;

tokio::select! {
_ = sigquit.recv() => { Ok(128 + libc::SIGINT) }
_ = sigterm.recv() => { Ok(128 + libc::SIGTERM) }
_ = tokio::signal::ctrl_c() => { Ok(128 + libc::SIGINT) }
_ = sigquit.recv() => { Ok(libc::SIGINT) }
_ = sigterm.recv() => { Ok(libc::SIGTERM) }
_ = tokio::signal::ctrl_c() => { Ok(libc::SIGINT) }
}
}
#[cfg(not(unix))]
Expand Down
57 changes: 41 additions & 16 deletions crates/containerd-shim-wasmtime/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,27 +268,13 @@ fn test_wasip2_component() -> anyhow::Result<()> {
#[test]
#[serial]
fn test_wasip2_component_http_proxy() -> anyhow::Result<()> {
const MAX_ATTEMPTS: u32 = 10;
const BACKOFF_DURATION: Duration = Duration::from_secs(1);

let srv = WasiTest::<WasiInstance>::builder()?
.with_wasm(HELLO_WASI_HTTP)?
.with_host_network()
.build()?;

let srv = srv.start()?;

let mut attempts = 0;
let response = loop {
match reqwest::blocking::get("http://127.0.0.1:8080") {
Ok(resp) => break Ok(resp),
Err(err) if attempts == MAX_ATTEMPTS => break Err(err),
Err(_) => {
std::thread::sleep(BACKOFF_DURATION);
attempts += 1;
}
}
};
let response = http_get();

let response = response.expect("Server did not start in time");
assert!(response.status().is_success());
Expand All @@ -297,7 +283,46 @@ fn test_wasip2_component_http_proxy() -> anyhow::Result<()> {
assert_eq!(body, "Hello, this is your first wasi:http/proxy world!\n");

let (exit_code, _, _) = srv.ctrl_c()?.wait(Duration::from_secs(5))?;
assert_eq!(exit_code, 128 + 2);
assert_eq!(exit_code, 0);

Ok(())
}

// Test that the shim can terminate component targeting wasi:http/proxy by sending SIGTERM.
#[test]
#[serial]
fn test_wasip2_component_http_proxy_force_shutdown() -> anyhow::Result<()> {
let srv = WasiTest::<WasiInstance>::builder()?
.with_wasm(FAULTY_WASI_HTTP)?
.with_host_network()
.build()?;

let srv = srv.start()?;
assert!(http_get().unwrap().status().is_success());

// Send SIGTERM
let (exit_code, _, _) = srv.terminate()?.wait(Duration::from_secs(5))?;
// The exit code indicates that the process did not exit cleanly
assert_eq!(exit_code, 128 + libc::SIGTERM as u32);

Ok(())
}

// Helper method to make a `GET` request
fn http_get() -> reqwest::Result<reqwest::blocking::Response> {
const MAX_ATTEMPTS: u32 = 10;
const BACKOFF_DURATION: Duration = Duration::from_secs(1);

let mut attempts = 0;

loop {
match reqwest::blocking::get("http://127.0.0.1:8080") {
Ok(resp) => break Ok(resp),
Err(err) if attempts == MAX_ATTEMPTS => break Err(err),
Err(_) => {
std::thread::sleep(BACKOFF_DURATION);
attempts += 1;
}
}
}
}

0 comments on commit 63fba0d

Please sign in to comment.