Skip to content

Commit

Permalink
Shutdown the server if any of the tasks crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Dec 13, 2024
1 parent f2221d3 commit 990b992
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 20 deletions.
4 changes: 2 additions & 2 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ impl Options {
shutdown.hard_shutdown_token(),
));

shutdown.run().await;
let exit_code = shutdown.run().await;

Ok(ExitCode::SUCCESS)
Ok(exit_code)
}
}
20 changes: 17 additions & 3 deletions crates/cli/src/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::time::Duration;
use std::{process::ExitCode, time::Duration};

use tokio::signal::unix::{Signal, SignalKind};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
Expand Down Expand Up @@ -74,14 +74,22 @@ impl ShutdownManager {
}

/// Run until we finish completely shutting down.
pub async fn run(mut self) {
pub async fn run(mut self) -> ExitCode {
// Wait for a first signal and trigger the soft shutdown
tokio::select! {
let likely_crashed = tokio::select! {
() = self.soft_shutdown_token.cancelled() => {
tracing::warn!("Another task triggered a shutdown, it likely crashed! Shutting down");
true
},

_ = self.sigterm.recv() => {
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
false
},

_ = self.sigint.recv() => {
tracing::info!("Shutdown signal received (SIGINT), shutting down");
false
},
};

Expand Down Expand Up @@ -112,5 +120,11 @@ impl ShutdownManager {
self.task_tracker().wait().await;

tracing::info!("All tasks are done, exitting");

if likely_crashed {
ExitCode::FAILURE
} else {
ExitCode::SUCCESS
}
}
}
4 changes: 4 additions & 0 deletions crates/handlers/src/activity_tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ impl ActivityTracker {
interval: std::time::Duration,
cancellation_token: CancellationToken,
) {
// This guard on the shutdown token is to ensure that if this task crashes for
// any reason, the server will shut down
let _guard = cancellation_token.clone().drop_guard();

loop {
tokio::select! {
biased;
Expand Down
4 changes: 4 additions & 0 deletions crates/handlers/src/activity_tracker/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ impl Worker {
mut receiver: tokio::sync::mpsc::Receiver<Message>,
cancellation_token: CancellationToken,
) {
// This guard on the shutdown token is to ensure that if this task crashes for
// any reason, the server will shut down
let _guard = cancellation_token.clone().drop_guard();

loop {
let message = tokio::select! {
// Because we want the cancellation token to trigger only once,
Expand Down
16 changes: 10 additions & 6 deletions crates/listener/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ pub async fn run_servers<S, B>(
B::Data: Send,
B::Error: std::error::Error + Send + Sync + 'static,
{
// This guard on the shutdown token is to ensure that if this task crashes for
// any reason, the server will shut down
let _guard = soft_shutdown_token.clone().drop_guard();

// Create a stream of accepted connections out of the listeners
let mut accept_stream: SelectAll<_> = listeners
.into_iter()
Expand Down Expand Up @@ -360,7 +364,7 @@ pub async fn run_servers<S, B>(
connection_tasks.spawn(conn);
},
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Err(e)) => tracing::error!("Join error: {e}"),
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
Expand All @@ -369,8 +373,8 @@ pub async fn run_servers<S, B>(
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
Expand Down Expand Up @@ -412,7 +416,7 @@ pub async fn run_servers<S, B>(
connection_tasks.spawn(conn);
}
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
Some(Err(e)) => tracing::error!("Join error: {e}"),
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
Expand All @@ -421,8 +425,8 @@ pub async fn run_servers<S, B>(
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
match res {
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
Some(Err(e)) => tracing::error!("Join error: {e}"),
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
None => tracing::error!("Join set was polled even though it was empty"),
}
},
Expand Down
9 changes: 1 addition & 8 deletions crates/tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,7 @@ pub async fn init(
mas_storage::queue::CleanupExpiredTokensJob,
);

task_tracker.spawn(async move {
if let Err(e) = worker.run().await {
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to run new queue"
);
}
});
task_tracker.spawn(worker.run());

Ok(())
}
17 changes: 16 additions & 1 deletion crates/tasks/src/new_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ pub struct QueueWorker {
am_i_leader: bool,
last_heartbeat: DateTime<Utc>,
cancellation_token: CancellationToken,
cancellation_guard: tokio_util::sync::DropGuard,
state: State,
schedules: Vec<ScheduleDefinition>,
tracker: JobTracker,
Expand Down Expand Up @@ -240,6 +241,10 @@ impl QueueWorker {
tracing::info!("Registered worker");
let now = clock.now();

// We put a cancellation drop guard in the structure, so that when it gets
// dropped, we're sure to cancel the token
let cancellation_guard = cancellation_token.clone().drop_guard();

Ok(Self {
rng,
clock,
Expand All @@ -248,6 +253,7 @@ impl QueueWorker {
am_i_leader: false,
last_heartbeat: now,
cancellation_token,
cancellation_guard,
state,
schedules: Vec::new(),
tracker: JobTracker::default(),
Expand Down Expand Up @@ -285,7 +291,16 @@ impl QueueWorker {
self
}

pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
pub async fn run(mut self) {
if let Err(e) = self.run_inner().await {
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to run new queue"
);
}
}

async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
self.setup_schedules().await?;

while !self.cancellation_token.is_cancelled() {
Expand Down

0 comments on commit 990b992

Please sign in to comment.