From e851b66bac4160fdf81e28f37e2673eb85c6823d Mon Sep 17 00:00:00 2001 From: Anthony Dodd Date: Thu, 30 Jan 2020 16:16:35 -0600 Subject: [PATCH] WIP: Added async-trait as an optional dep. For now, included only when postgres is enabled. Extension traits are now being used for PgConnection, PgPoolConnection & PgPool for listen/notify functionality. Only two extension traits were introduced. Only a single trait method is present on the extension traits and it works for single or multi channel listening setups. Automatic reconnect behavior is implemented for PgPool based listeners. All logic has been cut over to the `recv` impls for the PgListener variants. --- Cargo.lock | 12 ++ sqlx-core/Cargo.toml | 3 +- sqlx-core/src/postgres/listen.rs | 305 +++++++++++++++++++------------ sqlx-core/src/postgres/mod.rs | 2 +- 4 files changed, 200 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 36fc825ce6..9a8a49f321 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,6 +102,16 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "async-trait" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.13 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "atty" version = "0.2.14" @@ -1377,6 +1387,7 @@ dependencies = [ "async-native-tls 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "async-stream 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "async-trait 0.1.22 (registry+https://github.com/rust-lang/crates.io-index)", "base64 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1844,6 +1855,7 @@ dependencies = [ "checksum async-stream 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "58982858be7540a465c790b95aaea6710e5139bf8956b1d1344d014fa40100b0" "checksum async-stream-impl 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "393356ed99aa7bff0ac486dde592633b83ab02bd254d8c209d5b9f1d0f533480" "checksum async-task 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "a9f534e76ca33eaa82bc8da5adb1b9e94a16f6fa217b78e9b400094dbbf844f9" +"checksum async-trait 0.1.22 (registry+https://github.com/rust-lang/crates.io-index)" = "c8df72488e87761e772f14ae0c2480396810e51b2c2ade912f97f0f7e5b95e3c" "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" "checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" "checksum autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index a488b712fb..b0b9ae56fc 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -15,7 +15,7 @@ authors = [ [features] default = [ "runtime-async-std" ] unstable = [] -postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ] +postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "async-trait" ] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] tls = [ "async-native-tls" ] runtime-async-std = [ "async-native-tls/runtime-async-std", "async-std" ] @@ -24,6 +24,7 @@ runtime-tokio = [ "async-native-tls/runtime-tokio", "tokio" ] [dependencies] async-native-tls = { version = "0.3.2", default-features = false, optional = true } async-std = { version = "1.4.0", optional = true } +async-trait = { version = "0.1.22", optional = true } tokio = { version = "0.2.9", default-features = false, features = [ "dns", "fs", "time", "tcp" ], optional = true } async-stream = { version = "0.2.0", default-features = false } base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] } diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs index 822d11bb3a..c1457fd6d6 100644 --- a/sqlx-core/src/postgres/listen.rs +++ b/sqlx-core/src/postgres/listen.rs @@ -1,187 +1,243 @@ use std::ops::DerefMut; +use async_trait::async_trait; use futures_core::future::BoxFuture; -use futures_core::stream::BoxStream; use crate::connection::Connection; -use crate::describe::Describe; use crate::executor::Executor; use crate::pool::PoolConnection; use crate::postgres::protocol::{Message, NotificationResponse}; -use crate::postgres::{PgArguments, PgConnection, PgPool, PgRow, Postgres}; +use crate::postgres::{PgConnection, PgPool}; use crate::Result; type PgPoolConnection = PoolConnection; -impl PgConnection { - /// Register this connection as a listener on the specified channel. - /// - /// If an error is returned here, the connection will be dropped. - pub async fn listen(mut self, channel: impl AsRef) -> Result> { - let cmd = build_listen_all_query(&[channel]); - let _ = self.send(cmd.as_str()).await?; - Ok(PgListener::new(self)) - } - - /// Register this connection as a listener on all of the specified channels. - /// - /// If an error is returned here, the connection will be dropped. - pub async fn listen_all( - mut self, - channels: impl IntoIterator>, - ) -> Result> { - let cmd = build_listen_all_query(channels); - let _ = self.send(cmd.as_str()).await?; - Ok(PgListener::new(self)) - } +/// Extension methods for Postgres connections. +#[async_trait] +pub trait PgConnectionExt { + async fn listen(self, channels: &[&str]) -> PgListener; } -impl PgPool { - /// Fetch a new connection from the pool and register it as a listener on the specified channel. - pub async fn listen(&self, channel: impl AsRef) -> Result> { - let mut conn = self.acquire().await?; - let cmd = build_listen_all_query(&[channel]); - let _ = conn.send(cmd.as_str()).await?; - Ok(PgListener::new(conn)) +#[async_trait] +impl PgConnectionExt for PgConnection { + /// Register this connection as a listener on the specified channels. + async fn listen(mut self, channels: &[&str]) -> PgListener { + PgListener::new(Some(self), channels, None) } +} - /// Fetch a new connection from the pool and register it as a listener on the specified channels. - pub async fn listen_all( - &self, - channels: impl IntoIterator>, - ) -> Result> { - let mut conn = self.acquire().await?; - let cmd = build_listen_all_query(channels); - let _ = conn.send(cmd.as_str()).await?; - Ok(PgListener::new(conn)) +#[async_trait] +impl PgConnectionExt for PgPoolConnection { + /// Register this connection as a listener on the specified channels. + async fn listen(mut self, channels: &[&str]) -> PgListener { + PgListener::new(Some(self), channels, None) } } -impl PgPoolConnection { - /// Register this connection as a listener on the specified channel. - /// - /// If an error is returned here, the connection will be dropped. - pub async fn listen(mut self, channel: impl AsRef) -> Result> { - let cmd = build_listen_all_query(&[channel]); - let _ = self.send(cmd.as_str()).await?; - Ok(PgListener::new(self)) - } +/// Extension methods for Postgres connection pools. +#[async_trait] +pub trait PgPoolExt { + async fn listen(&self, channels: &[&str]) -> Result>; +} - /// Register this connection as a listener on all of the specified channels. +#[async_trait] +impl PgPoolExt for PgPool { + /// Fetch a new connection from the pool and register it as a listener on the specified channel. /// - /// If an error is returned here, the connection will be dropped. - pub async fn listen_all( - mut self, - channels: impl IntoIterator>, - ) -> Result> { - let cmd = build_listen_all_query(channels); - let _ = self.send(cmd.as_str()).await?; - Ok(PgListener::new(self)) + /// If the underlying connection ever dies, a new connection will be acquired from the pool, + /// and listening will resume as normal. + async fn listen(&self, channels: &[&str]) -> Result> { + Ok(PgListener::new(None, channels, Some(self.clone()))) } } /// A stream of async database notifications. /// /// Notifications will always correspond to the channel(s) specified this object is created. -pub struct PgListener(C); +pub struct PgListener { + needs_to_send_listen_cmd: bool, + connection: Option, + channels: Vec, + pool: Option, +} impl PgListener { /// Construct a new instance. - pub(self) fn new(conn: C) -> Self { - Self(conn) + pub(self) fn new(connection: Option, channels: &[&str], pool: Option) -> Self { + let channels = channels.iter().map(|chan| String::from(*chan)).collect(); + Self { + needs_to_send_listen_cmd: true, + connection, + channels, + pool, + } } } -impl PgListener -where - C: DerefMut, -{ - /// Get the next async notification from the database. - pub async fn next(&mut self) -> Result { +impl PgListener { + /// Receives the next notification available from any of the subscribed channels. + /// + /// When a `PgListener` is created from `PgPool.listen(..)`, the `PgListener` will perform + /// automatic reconnects to the database using the original `PgPool` and will submit a + /// `LISTEN` command to the database using the same originally specified channels. As such, + /// this routine will never return `None` when called on a `PgListener` created from a `PgPool`. + /// + /// However, if a `PgListener` instance is created outside of the context of a `PgPool`, then + /// this routine will return `None` when the underlying connection dies. At that point, any + /// further calls to this routine will also return `None`. + pub async fn recv(&mut self) -> Option> { loop { - match (&mut self.0).receive().await? { - Some(Message::NotificationResponse(notification)) => return Ok(notification.into()), - Some(msg) => { - return Err(protocol_err!( + // Ensure we have an active connection to work with. + let conn = match &mut self.connection { + Some(conn) => conn, + None => match self.get_new_connection().await { + // A new connection has been established, bind it and loop. + Ok(Some(conn)) => { + self.connection = Some(conn); + continue; + } + // No pool is present on this listener, return None. + Ok(None) => return None, + // We have a pool to work with, but some error has come up. Return the error. + // The next call to `recv` will build a new connection if available. + Err(err) => return Some(Err(err)), + }, + }; + // Ensure the current connection has properly registered all listener channels. + if self.needs_to_send_listen_cmd { + if let Err(err) = send_listen_query(conn, &self.channels).await { + // If we've encountered an error here, test the connection, drop it if needed, + // and return the error. The next call to recv will build a new connection if possible. + if let Err(_) = conn.ping().await { + self.close_conn().await; + } + return Some(Err(err)); + } + self.needs_to_send_listen_cmd = false; + } + // Await a notification from the DB. + match conn.receive().await { + // We've received an async notification, return it. + Ok(Some(Message::NotificationResponse(notification))) => { + return Some(Ok(notification.into())) + } + // Protocol error, return the error. + Ok(Some(msg)) => { + return Some(Err(protocol_err!( "unexpected message received from database {:?}", msg ) - .into()) + .into())) } - None => continue, + // The connection is dead, ensure that it is dropped, update self state, and loop to try again. + Ok(None) => { + self.close_conn().await; + self.needs_to_send_listen_cmd = true; + continue; + } + // An error has come up, return it. + Err(err) => return Some(Err(err)), } } } -} -impl PgListener -where - C: Connection, -{ - /// Close this listener stream and its underlying connection. - pub async fn close(self) -> BoxFuture<'static, Result<()>> { - self.0.close() + /// Fetch a new connection from the connection pool, if a connection pool is available. + /// + /// Errors here are transient. `Ok(None)` indicates that no pool is available. + async fn get_new_connection(&mut self) -> Result> { + let pool = match &self.pool { + Some(pool) => pool, + None => return Ok(None), + }; + Ok(Some(pool.acquire().await?)) } -} - -impl std::ops::Deref for PgListener { - type Target = C; - fn deref(&self) -> &Self::Target { - &self.0 + /// Close and drop the current connection. + async fn close_conn(&mut self) { + if let Some(conn) = self.connection.take() { + let _ = conn.close().await; + } } } -impl std::ops::DerefMut for PgListener { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 +impl PgListener { + /// Receives the next notification available from any of the subscribed channels. + /// + /// If the underlying connection ever dies, this routine will return `None`. Any further calls + /// to this routine will also return `None`. If automatic reconnect behavior is needed, use + /// `PgPool.listen(..)`, which will automatically establish a new connection from the pool and + /// resusbcribe to all channels. + pub async fn recv(&mut self) -> Option> { + loop { + // Ensure we have an active connection to work with. + let mut conn = match &mut self.connection { + Some(conn) => conn, + None => return None, // This will never practically be hit, but let's make Rust happy. + }; + // Ensure the current connection has properly registered all listener channels. + if self.needs_to_send_listen_cmd { + if let Err(err) = send_listen_query(&mut conn, &self.channels).await { + // If we've encountered an error here, test the connection. If the connection + // is good, we return the error. Else, we return `None` as the connection is dead. + if let Err(_) = conn.ping().await { + return None; + } + return Some(Err(err)); + } + self.needs_to_send_listen_cmd = false; + } + // Await a notification from the DB. + match conn.receive().await { + // We've received an async notification, return it. + Ok(Some(Message::NotificationResponse(notification))) => { + return Some(Ok(notification.into())) + } + // Protocol error, return the error. + Ok(Some(msg)) => { + return Some(Err(protocol_err!( + "unexpected message received from database {:?}", + msg + ) + .into())) + } + // The connection is dead, return None. + Ok(None) => return None, + // An error has come up, return it. + Err(err) => return Some(Err(err)), + } + } } } -impl> crate::Executor for PgListener { - type Database = super::Postgres; - - fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, Result<()>> { - Box::pin(self.0.send(query)) - } - - fn execute<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - args: PgArguments, - ) -> BoxFuture<'e, Result> { - Box::pin(self.0.execute(query, args)) - } - - fn fetch<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - args: PgArguments, - ) -> BoxStream<'e, Result> { - self.0.fetch(query, args) - } - - fn describe<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - ) -> BoxFuture<'e, Result>> { - Box::pin(self.0.describe(query)) +impl PgListener +where + C: Connection, +{ + /// Close this listener stream and its underlying connection. + pub async fn close(self) -> BoxFuture<'static, Result<()>> { + match self.connection { + Some(conn) => conn.close(), + None => Box::pin(futures_util::future::ok(())), + } } } /// An asynchronous message sent from the database. #[derive(Debug)] #[non_exhaustive] -pub struct NotifyMessage { +pub struct PgNotification { + /// The PID of the database process which sent this notification. + pub pid: u32, /// The channel of the notification, which can be thought of as a topic. pub channel: String, /// The payload of the notification. pub payload: String, } -impl From> for NotifyMessage { +impl From> for PgNotification { fn from(src: Box) -> Self { Self { + pid: src.pid, channel: src.channel_name, payload: src.message, } @@ -198,6 +254,15 @@ fn build_listen_all_query(channels: impl IntoIterator>) - }) } +/// Send the structure listen query to the database. +async fn send_listen_query>( + conn: &mut C, + channels: impl IntoIterator>, +) -> Result<()> { + let cmd = build_listen_all_query(channels); + conn.send(cmd.as_str()).await +} + #[cfg(test)] mod tests { use super::*; diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 33ae63e6d6..a3d2554e84 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -4,7 +4,7 @@ pub use arguments::PgArguments; pub use connection::PgConnection; pub use database::Postgres; pub use error::PgError; -pub use listen::{NotifyMessage, PgListener}; +pub use listen::{PgConnectionExt, PgListener, PgNotification, PgPoolExt}; pub use row::PgRow; pub use types::PgTypeInfo;