Skip to content

Commit

Permalink
feat: Implement before_connect callback to modify connect options.
Browse files Browse the repository at this point in the history
Allows the user to see and maybe modify the connect options before
each attempt to connect to a database. May be used in a number of
ways, e.g.:
 - adding jitter to connection lifetime
 - validating/setting a per-connection password
 - using a custom server discovery process
  • Loading branch information
fiadliel committed Mar 5, 2024
1 parent 664dbdf commit 94ccd05
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
16 changes: 15 additions & 1 deletion sqlx-core/src/pool/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crossbeam_queue::ArrayQueue;

use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};

use std::borrow::Cow;
use std::cmp;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
Expand Down Expand Up @@ -300,18 +301,31 @@ impl<DB: Database> PoolInner<DB> {

let mut backoff = Duration::from_millis(10);
let max_backoff = deadline_as_timeout::<DB>(deadline)? / 5;
let mut num_attempts: u32 = 0;

loop {
let timeout = deadline_as_timeout::<DB>(deadline)?;
num_attempts += 1;

// clone the connect options arc so it can be used without holding the RwLockReadGuard
// across an async await point
let connect_options = self
let connect_options_arc = self
.connect_options
.read()
.expect("write-lock holder panicked")
.clone();

let connect_options = if let Some(callback) = &self.options.before_connect {
callback(connect_options_arc.as_ref(), num_attempts)
.await
.map_err(|error| {
tracing::error!(%error, "error returned from before_connect");
error
})?
} else {
Cow::Borrowed(connect_options_arc.as_ref())
};

// result here is `Result<Result<C, Error>, TimeoutError>`
// if this block does not return, sleep for the backoff timeout and try again
match crate::rt::timeout(timeout, connect_options.connect()).await {
Expand Down
63 changes: 63 additions & 0 deletions sqlx-core/src/pool/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::error::Error;
use crate::pool::inner::PoolInner;
use crate::pool::Pool;
use futures_core::future::BoxFuture;
use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -43,6 +44,18 @@ use std::time::{Duration, Instant};
/// the perspectives of both API designer and consumer.
pub struct PoolOptions<DB: Database> {
pub(crate) test_before_acquire: bool,
pub(crate) before_connect: Option<
Arc<
dyn Fn(
&<DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'_, Result<Cow<'_, <DB::Connection as Connection>::Options>, Error>>
+ 'static
+ Send
+ Sync,
>,
>,
pub(crate) after_connect: Option<
Arc<
dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>>
Expand Down Expand Up @@ -90,6 +103,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
fn clone(&self) -> Self {
PoolOptions {
test_before_acquire: self.test_before_acquire,
before_connect: self.before_connect.clone(),
after_connect: self.after_connect.clone(),
before_acquire: self.before_acquire.clone(),
after_release: self.after_release.clone(),
Expand Down Expand Up @@ -136,6 +150,7 @@ impl<DB: Database> PoolOptions<DB> {
pub fn new() -> Self {
Self {
// User-specifiable routines
before_connect: None,
after_connect: None,
before_acquire: None,
after_release: None,
Expand Down Expand Up @@ -292,6 +307,54 @@ impl<DB: Database> PoolOptions<DB> {
self
}

/// Perform an asynchronous action before connecting to the database.
///
/// This operation is performed on every attempt to connect, including retries. The
/// current `ConnectOptions` is passed, and this may be passed unchanged, or modified
/// after cloning. The current connection attempt is passed as the second parameter
/// (starting at 1).
///
/// If the operation returns with an error, then the connection attempt fails without
/// attempting further retries. The operation therefore may need to implement error
/// handling and/or value caching to avoid failing the connection attempt.
///
/// # Example: Per-Request Authentication
/// This callback may be used to modify values in the database's `ConnectOptions`, before
/// connecting to the database.
///
/// This example is written for PostgreSQL but can likely be adapted to other databases.
///
/// ```no_run
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// use std::borrow::Cow;
/// use sqlx::Executor;
/// use sqlx::postgres::PgPoolOptions;
///
/// let pool = PgPoolOptions::new()
/// .after_connect(move |opts, _num_attempts| Box::pin(async move {

Check failure on line 334 in sqlx-core/src/pool/options.rs

View workflow job for this annotation

GitHub Actions / Unit Test (async-std, native-tls)

expected `{async block@sqlx-core/src/pool/options.rs:9:56: 11:6}` to be a future that resolves to `Result<(), Error>`, but it resolves to `Result<Cow<'_, _>, _>`

Check failure on line 334 in sqlx-core/src/pool/options.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, rustls)

expected `{async block@sqlx-core/src/pool/options.rs:9:56: 11:6}` to be a future that resolves to `Result<(), Error>`, but it resolves to `Result<Cow<'_, _>, _>`
/// Ok(Cow::Owned(opts.clone().password("abc")))

Check failure on line 335 in sqlx-core/src/pool/options.rs

View workflow job for this annotation

GitHub Actions / Unit Test (async-std, native-tls)

no method named `clone` found for mutable reference `&mut PgConnection` in the current scope

Check failure on line 335 in sqlx-core/src/pool/options.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, rustls)

no method named `clone` found for mutable reference `&mut PgConnection` in the current scope
/// }))
/// .connect("postgres:// …").await?;
/// # Ok(())
/// # }
/// ```
///
/// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
pub fn before_connect<F>(mut self, callback: F) -> Self
where
for<'c> F: Fn(
&'c <DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'c, crate::Result<Cow<'c, <DB::Connection as Connection>::Options>>>
+ 'static
+ Send
+ Sync,
{
self.before_connect = Some(Arc::new(callback));
self
}

/// Perform an asynchronous action after connecting to the database.
///
/// If the operation returns with an error then the error is logged, the connection is closed
Expand Down

0 comments on commit 94ccd05

Please sign in to comment.