From d295d859d496f9193fb4cffd364f3958bcde70a9 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Fri, 8 Mar 2024 17:12:08 +0800 Subject: [PATCH] feat: Add retry with context (#72) Close https://github.com/Xuanwo/backon/issues/49 ```rust use anyhow::anyhow; use anyhow::Result; use backon::ExponentialBuilder; use backon::RetryableWithContext; struct Test; impl Test { async fn hello(&mut self) -> Result { Err(anyhow!("not retryable")) } } #[tokio::main] async fn main() -> Result<()> { let mut test = Test; // (Test, Result) let (_, result) = { |mut v: Test| async { let res = v.hello().await; (v, res) } } .retry(&ExponentialBuilder::default()) .context(test) .await; Ok(()) } ``` Signed-off-by: Xuanwo --- src/lib.rs | 3 + src/retry_with_context.rs | 370 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 373 insertions(+) create mode 100644 src/retry_with_context.rs diff --git a/src/lib.rs b/src/lib.rs index 9c544eb..8bc147a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,3 +95,6 @@ pub use retry::Retryable; mod blocking_retry; pub use blocking_retry::BlockingRetry; pub use blocking_retry::BlockingRetryable; + +mod retry_with_context; +pub use retry_with_context::RetryableWithContext; diff --git a/src/retry_with_context.rs b/src/retry_with_context.rs new file mode 100644 index 0000000..f25e801 --- /dev/null +++ b/src/retry_with_context.rs @@ -0,0 +1,370 @@ +use std::future::Future; +use std::pin::{pin, Pin}; +use std::task::Context; +use std::task::Poll; +use std::time::Duration; + +use futures_core::ready; +use pin_project::pin_project; + +use crate::backoff::BackoffBuilder; +use crate::Backoff; + +/// RetryableWithContext will add retry support for functions that produces a futures with results +/// and context. +/// +/// That means all types that implement `FnMut(Ctx) -> impl Future)>` +/// will be able to use `retry`. +/// +/// This will allow users to pass a context to the function and return it back while retry finish. +/// +/// # Example +/// +/// Without context, we could meet errors like the following: +/// +/// ```shell +/// error: captured variable cannot escape `FnMut` closure body +/// --> src/retry.rs:404:27 +/// | +/// 400 | let mut test = Test; +/// | -------- variable defined here +/// ... +/// 404 | let result = { || async { test.hello().await } } +/// | - ^^^^^^^^----^^^^^^^^^^^^^^^^ +/// | | | | +/// | | | variable captured here +/// | | returns an `async` block that contains a reference to a captured variable, which then escapes the closure body +/// | inferred to be a `FnMut` closure +/// | +/// = note: `FnMut` closures only have access to their captured variables while they are executing... +/// = note: ...therefore, they cannot allow references to captured variables to escape +/// ``` +/// +/// But with context support, we can implement in this way: +/// +/// ```no_run +/// use anyhow::anyhow; +/// use anyhow::Result; +/// use backon::ExponentialBuilder; +/// use backon::RetryableWithContext; +/// +/// struct Test; +/// +/// impl Test { +/// async fn hello(&mut self) -> Result { +/// Err(anyhow!("not retryable")) +/// } +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<()> { +/// let mut test = Test; +/// +/// // (Test, Result) +/// let (_, result) = { +/// |mut v: Test| async { +/// let res = v.hello().await; +/// (v, res) +/// } +/// } +/// .retry(&ExponentialBuilder::default()) +/// .context(test) +/// .await; +/// +/// Ok(()) +/// } +/// ``` +pub trait RetryableWithContext< + B: BackoffBuilder, + T, + E, + Ctx, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, +> +{ + /// Generate a new retry + fn retry(self, builder: &B) -> Retry; +} + +impl RetryableWithContext for FutureFn +where + B: BackoffBuilder, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, +{ + fn retry(self, builder: &B) -> Retry { + Retry::new(self, builder.build()) + } +} + +/// Retry struct generated by [`Retryable`]. +#[pin_project] +pub struct Retry< + B: Backoff, + T, + E, + Ctx, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, + RF = fn(&E) -> bool, + NF = fn(&E, Duration), +> { + backoff: B, + retryable: RF, + notify: NF, + future_fn: FutureFn, + + #[pin] + state: State, +} + +impl Retry +where + B: Backoff, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, +{ + /// Create a new retry. + /// + /// # Notes + /// + /// `context` must be set by `context` method before calling `await`. + fn new(future_fn: FutureFn, backoff: B) -> Self { + Retry { + backoff, + retryable: |_: &E| true, + notify: |_: &E, _: Duration| {}, + future_fn, + state: State::Idle(None), + } + } +} + +impl Retry +where + B: Backoff, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, + RF: FnMut(&E) -> bool, + NF: FnMut(&E, Duration), +{ + /// Set the context for retrying. + pub fn context(self, context: Ctx) -> Retry { + Retry { + backoff: self.backoff, + retryable: self.retryable, + notify: self.notify, + future_fn: self.future_fn, + state: State::Idle(Some(context)), + } + } + + /// Set the conditions for retrying. + /// + /// If not specified, we treat all errors as retryable. + /// + /// # Examples + /// + /// ```no_run + /// use anyhow::Result; + /// use backon::ExponentialBuilder; + /// use backon::Retryable; + /// + /// async fn fetch() -> Result { + /// Ok(reqwest::get("https://www.rust-lang.org") + /// .await? + /// .text() + /// .await?) + /// } + /// + /// #[tokio::main(flavor = "current_thread")] + /// async fn main() -> Result<()> { + /// let content = fetch + /// .retry(&ExponentialBuilder::default()) + /// .when(|e| e.to_string() == "EOF") + /// .await?; + /// println!("fetch succeeded: {}", content); + /// + /// Ok(()) + /// } + /// ``` + pub fn when bool>( + self, + retryable: RN, + ) -> Retry { + Retry { + backoff: self.backoff, + retryable, + notify: self.notify, + future_fn: self.future_fn, + state: self.state, + } + } + + /// Set to notify for everything retrying. + /// + /// If not specified, this is a no-op. + /// + /// # Examples + /// + /// ```no_run + /// use std::time::Duration; + /// + /// use anyhow::Result; + /// use backon::ExponentialBuilder; + /// use backon::Retryable; + /// + /// async fn fetch() -> Result { + /// Ok(reqwest::get("https://www.rust-lang.org") + /// .await? + /// .text() + /// .await?) + /// } + /// + /// #[tokio::main(flavor = "current_thread")] + /// async fn main() -> Result<()> { + /// let content = fetch + /// .retry(&ExponentialBuilder::default()) + /// .notify(|err: &anyhow::Error, dur: Duration| { + /// println!("retrying error {:?} with sleeping {:?}", err, dur); + /// }) + /// .await?; + /// println!("fetch succeeded: {}", content); + /// + /// Ok(()) + /// } + /// ``` + pub fn notify( + self, + notify: NN, + ) -> Retry { + Retry { + backoff: self.backoff, + retryable: self.retryable, + notify, + future_fn: self.future_fn, + state: self.state, + } + } +} + +/// State maintains internal state of retry. +/// +/// # Notes +/// +/// `tokio::time::Sleep` is a very struct that occupy 640B, so we wrap it +/// into a `Pin>` to avoid this enum too large. +#[pin_project(project = StateProject)] +enum State)>> { + Idle(Option), + Polling(#[pin] Fut), + // TODO: we need to support other sleeper + Sleeping((Option, Pin>)), +} + +impl Future for Retry +where + B: Backoff, + Fut: Future)>, + FutureFn: FnMut(Ctx) -> Fut, + RF: FnMut(&E) -> bool, + NF: FnMut(&E, Duration), +{ + type Output = (Ctx, Result); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + let state = this.state.as_mut().project(); + match state { + StateProject::Idle(ctx) => { + let ctx = ctx.take().expect("context must be valid"); + let fut = (this.future_fn)(ctx); + this.state.set(State::Polling(fut)); + continue; + } + StateProject::Polling(fut) => { + let (ctx, res) = ready!(fut.poll(cx)); + match res { + Ok(v) => return Poll::Ready((ctx, Ok(v))), + Err(err) => { + // If input error is not retryable, return error directly. + if !(this.retryable)(&err) { + return Poll::Ready((ctx, Err(err))); + } + match this.backoff.next() { + None => return Poll::Ready((ctx, Err(err))), + Some(dur) => { + (this.notify)(&err, dur); + this.state.set(State::Sleeping(( + Some(ctx), + Box::pin(tokio::time::sleep(dur)), + ))); + continue; + } + } + } + } + } + StateProject::Sleeping((ctx, sl)) => { + ready!(pin!(sl).poll(cx)); + let ctx = ctx.take().expect("context must be valid"); + this.state.set(State::Idle(Some(ctx))); + continue; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use anyhow::anyhow; + use tokio::sync::Mutex; + + use super::*; + use crate::exponential::ExponentialBuilder; + use anyhow::Result; + + struct Test; + + impl Test { + async fn hello(&mut self) -> Result { + Err(anyhow!("not retryable")) + } + } + + #[tokio::test] + async fn test_retry_with_not_retryable_error() -> Result<()> { + let error_times = Mutex::new(0); + + let test = Test; + + let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)); + + let (_, result) = { + |mut v: Test| async { + let mut x = error_times.lock().await; + *x += 1; + + let res = v.hello().await; + (v, res) + } + } + .retry(&backoff) + .context(test) + // Only retry If error message is `retryable` + .when(|e| e.to_string() == "retryable") + .await; + + assert!(result.is_err()); + assert_eq!("not retryable", result.unwrap_err().to_string()); + // `f` always returns error "not retryable", so it should be executed + // only once. + assert_eq!(*error_times.lock().await, 1); + Ok(()) + } +}