Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add a layer that catches panics #366

Merged
merged 4 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ timeout = ["tower/timeout"]
limit = ["tower/limit"]
## Support filtering jobs based on a predicate
filter = ["tower/filter"]
## Captures panics in executions and convert them to errors
catch-panic = ["dep:backtrace"]
## Compatibility with async-std and smol runtimes
async-std-comp = ["async-std"]
## Compatibility with tokio and actix runtimes
Expand All @@ -46,6 +48,7 @@ layers = [
"timeout",
"limit",
"filter",
"catch-panic",
]

docsrs = ["document-features"]
Expand Down Expand Up @@ -134,6 +137,7 @@ pin-project-lite = "0.2.14"
uuid = { version = "1.8", optional = true }
ulid = { version = "1", optional = true }
serde = { version = "1.0", features = ["derive"] }
backtrace = { version = "0.3", optional = true }

[dependencies.tracing]
default-features = false
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async fn produce_route_jobs(storage: &RedisStorage<Email>) -> Result<()> {
- _timeout_ — Support timeouts on jobs
- _limit_ — 💪 Limit the amount of jobs
- _filter_ — Support filtering jobs based on a predicate
- _catch-panic_ - Catch panics that occur during execution

## Storage Comparison

Expand Down
2 changes: 1 addition & 1 deletion examples/basics/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
thiserror = "1"
tokio = { version = "1", features = ["full"] }
apalis = { path = "../../", features = ["limit", "tokio-comp"] }
apalis = { path = "../../", features = ["limit", "tokio-comp", "catch-panic"] }
apalis-sql = { path = "../../packages/apalis-sql" }
serde = "1"
tracing-subscriber = "0.3.11"
Expand Down
7 changes: 6 additions & 1 deletion examples/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod service;

use std::time::Duration;

use apalis::{layers::tracing::TraceLayer, prelude::*};
use apalis::{
layers::{catch_panic::CatchPanicLayer, tracing::TraceLayer},
prelude::*,
};
use apalis_sql::sqlite::{SqlitePool, SqliteStorage};

use email_service::Email;
Expand Down Expand Up @@ -96,6 +99,8 @@ async fn main() -> Result<(), std::io::Error> {
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
WorkerBuilder::new("tasty-banana")
// This handles any panics that may occur in any of the layers below
.layer(CatchPanicLayer::new())
.layer(TraceLayer::new())
.layer(LogLayer::new("some-log-example"))
// Add shared context to all jobs executed by this worker
Expand Down
181 changes: 181 additions & 0 deletions src/layers/catch_panic/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use std::fmt;
use std::future::Future;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};

use apalis_core::error::Error;
use apalis_core::request::Request;
use backtrace::Backtrace;
use tower::Layer;
use tower::Service;

/// Apalis Layer that catches panics in the service.
#[derive(Clone, Debug)]
pub struct CatchPanicLayer;

impl CatchPanicLayer {
/// Creates a new `CatchPanicLayer`.
pub fn new() -> Self {
CatchPanicLayer
}
}

impl Default for CatchPanicLayer {
fn default() -> Self {
Self::new()
}
}

impl<S> Layer<S> for CatchPanicLayer {
type Service = CatchPanicService<S>;

fn layer(&self, service: S) -> Self::Service {
CatchPanicService { service }
}
}

/// Apalis Service that catches panics.
#[derive(Clone, Debug)]
pub struct CatchPanicService<S> {
service: S,
}

impl<S, J, Res> Service<Request<J>> for CatchPanicService<S>
where
S: Service<Request<J>, Response = Res, Error = Error>,
{
type Response = S::Response;
type Error = S::Error;
type Future = CatchPanicFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}

fn call(&mut self, request: Request<J>) -> Self::Future {
CatchPanicFuture {
future: self.service.call(request),
}
}
}

pin_project_lite::pin_project! {
/// A wrapper that catches panics during execution
pub struct CatchPanicFuture<F> {
#[pin]
future: F,

}
}

/// An error generated from a panic
#[derive(Debug, Clone)]
pub struct PanicError(pub String, pub Backtrace);

impl std::error::Error for PanicError {}

impl fmt::Display for PanicError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "PanicError: {}, Backtrace: {:?}", self.0, self.1)
}
}

impl<F, Res> Future for CatchPanicFuture<F>
where
F: Future<Output = Result<Res, Error>>,
{
type Output = Result<Res, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

match catch_unwind(AssertUnwindSafe(|| this.future.poll(cx))) {
Ok(res) => res,
Err(e) => {
let panic_info = if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
Poll::Ready(Err(Error::Failed(Box::new(PanicError(
panic_info,
Backtrace::new(),
)))))
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::task::{Context, Poll};
use tower::Service;

#[derive(Clone, Debug)]
struct TestJob;

#[derive(Clone)]
struct TestService;

impl Service<Request<TestJob>> for TestService {
type Response = usize;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Request<TestJob>) -> Self::Future {
Box::pin(async { Ok(42) })
}
}

#[tokio::test]
async fn test_catch_panic_layer() {
let layer = CatchPanicLayer::new();
let mut service = layer.layer(TestService);

let request = Request::new(TestJob);
let response = service.call(request).await;

assert!(response.is_ok());
}

#[tokio::test]
async fn test_catch_panic_layer_panics() {
struct PanicService;

impl Service<Request<TestJob>> for PanicService {
type Response = usize;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Request<TestJob>) -> Self::Future {
Box::pin(async { None.unwrap() })
}
}

let layer = CatchPanicLayer::new();
let mut service = layer.layer(PanicService);

let request = Request::new(TestJob);
let response = service.call(request).await;

assert!(response.is_err());

assert_eq!(
response.unwrap_err().to_string()[0..87],
*"Task Failed: PanicError: called `Option::unwrap()` on a `None` value, Backtrace: 0: "
);
}
}
5 changes: 5 additions & 0 deletions src/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ pub mod limit {
#[cfg(feature = "timeout")]
#[cfg_attr(docsrs, doc(cfg(feature = "timeout")))]
pub use tower::timeout::TimeoutLayer;

/// catch panic middleware for apalis
#[cfg(feature = "catch-panic")]
#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))]
pub mod catch_panic;
Loading