From b46ea89c42441187d15cd70e651907e29d6338c8 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Fri, 26 Aug 2022 13:46:04 -0700 Subject: [PATCH 001/126] Add hostaddr support --- tokio-postgres/src/config.rs | 70 +++++++++++++++++++++++++++++++++++ tokio-postgres/src/connect.rs | 23 +++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 2c29d629c..f29eed2b1 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -12,6 +12,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -90,6 +91,17 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -117,6 +129,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -153,6 +169,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -178,6 +195,7 @@ impl Config { application_name: None, ssl_mode: SslMode::Prefer, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, keepalives: true, @@ -288,6 +306,11 @@ impl Config { &self.host } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[String] { + self.hostaddr.deref() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -300,6 +323,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { + self.hostaddr.push(hostaddr.to_string()); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -418,6 +450,11 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + self.hostaddr(hostaddr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -542,6 +579,7 @@ impl fmt::Debug for Config { .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("keepalives", &self.keepalives) @@ -922,3 +960,35 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + + assert_eq!(1, 1); + } + + #[test] + fn test_empty_hostaddrs() { + let s = + "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; + let config = s.parse::().unwrap(); + assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 88faafe6b..e8ac29b42 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -23,6 +23,15 @@ where return Err(Error::config("invalid number of ports".into())); } + if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + let msg = format!( + "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", + config.hostaddr.len(), + config.host.len(), + ); + return Err(Error::config(msg.into())); + } + let mut error = None; for (i, host) in config.host.iter().enumerate() { let port = config @@ -32,18 +41,28 @@ where .copied() .unwrap_or(5432); + // The value of host is always used as the hostname for TLS validation. let hostname = match host { Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Host::Unix(_) => "", }; - let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - match connect_once(host, port, tls, config).await { + // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. + let hostaddr = match host { + Host::Tcp(_hostname) => match config.hostaddr.get(i) { + Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), + _ => host.clone(), + }, + #[cfg(unix)] + Host::Unix(_v) => host.clone(), + }; + + match connect_once(&hostaddr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } From 3c9315e3200f5eb99bb5a9b5998aca555951d691 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:40:57 -0700 Subject: [PATCH 002/126] IpAddr + try hostaddr first --- tokio-postgres/src/config.rs | 36 ++++++++++-------- tokio-postgres/src/connect.rs | 61 +++++++++++++++++++------------ tokio-postgres/tests/test/main.rs | 52 ++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0c62b5030..34accdbe8 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; @@ -98,7 +99,9 @@ pub enum Host { /// - or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications /// with time constraints. However, a host name is required for verify-full SSL certificate verification. -/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; /// * If `host` is specified without `hostaddr`, a host name lookup occurs; /// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. /// The value for `host` is ignored unless the authentication method requires it, @@ -174,7 +177,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, - pub(crate) hostaddr: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -317,7 +320,7 @@ impl Config { } /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[String] { + pub fn get_hostaddrs(&self) -> &[IpAddr] { self.hostaddr.deref() } @@ -337,8 +340,8 @@ impl Config { /// /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { - self.hostaddr.push(hostaddr.to_string()); + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); self } @@ -489,7 +492,10 @@ impl Config { } "hostaddr" => { for hostaddr in value.split(',') { - self.hostaddr(hostaddr); + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); } } "port" => { @@ -1016,6 +1022,8 @@ impl<'a> UrlParser<'a> { #[cfg(test)] mod tests { + use std::net::IpAddr; + use crate::{config::Host, Config}; #[test] @@ -1032,16 +1040,14 @@ mod tests { config.get_hosts(), ); - assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); assert_eq!(1, 1); } - - #[test] - fn test_empty_hostaddrs() { - let s = - "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; - let config = s.parse::().unwrap(); - assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); - } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index c36677234..ee1dc1c76 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -5,8 +5,8 @@ use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; use std::task::Poll; +use std::{cmp, io}; pub async fn connect( mut tls: T, @@ -15,25 +15,35 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { - return Err(Error::config("invalid number of ports".into())); - } - - if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { let msg = format!( - "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", - config.hostaddr.len(), + "number of hosts ({}) is different from number of hostaddrs ({})", config.host.len(), + config.hostaddr.len(), ); return Err(Error::config(msg.into())); } + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { + return Err(Error::config("invalid number of ports".into())); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in 0..num_hosts { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -42,27 +52,30 @@ where .unwrap_or(5432); // The value of host is always used as the hostname for TLS validation. + // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter let hostname = match host { - Host::Tcp(host) => host.as_str(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", + Some(Host::Tcp(host)) => host.as_str(), + _ => "", }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. - let hostaddr = match host { - Host::Tcp(_hostname) => match config.hostaddr.get(i) { - Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), - _ => host.clone(), - }, - #[cfg(unix)] - Host::Unix(_v) => host.clone(), + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => { + if let Some(host) = host { + host.clone() + } else { + // This is unreachable. + return Err(Error::config("both host and hostaddr are empty".into())); + } + } }; - match connect_once(&hostaddr, port, tls, config).await { + match connect_once(&addr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..387c90d7c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,6 +147,58 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; From e30bff65a35d1240f8b920c49569a40563712e5d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:11 -0700 Subject: [PATCH 003/126] also update postgres --- postgres/src/config.rs | 33 +++++++++++++++++++++++++++++++++ tokio-postgres/src/config.rs | 1 + 2 files changed, 34 insertions(+) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index b541ec846..a754ff91f 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,6 +6,7 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -39,6 +40,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -67,6 +81,10 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -204,6 +222,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -214,6 +233,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -226,6 +250,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 34accdbe8..923da2985 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -302,6 +302,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { From 6c49a452feb273430d0091de83961ad65ffb9102 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:47 -0700 Subject: [PATCH 004/126] fmt --- postgres/src/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a754ff91f..921566b66 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -83,7 +83,7 @@ use tokio_postgres::{Error, Socket}; /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write /// ``` -/// +/// /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` @@ -236,7 +236,7 @@ impl Config { /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. pub fn get_hostaddrs(&self) -> &[IpAddr] { self.config.get_hostaddrs() - } + } /// Adds a Unix socket host to the configuration. /// From 42fef24973dff5450b294df21e94e665fe4d996d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:09:53 -0700 Subject: [PATCH 005/126] explicitly handle host being None --- tokio-postgres/src/connect.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ee1dc1c76..63574516c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -51,14 +51,17 @@ where .copied() .unwrap_or(5432); - // The value of host is always used as the hostname for TLS validation. - // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter + // The value of host is used as the hostname for TLS validation, + // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.as_str(), - _ => "", + Some(Host::Tcp(host)) => host.clone(), + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + #[cfg(unix)] + Some(Host::Unix(_)) => "".to_string(), + None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(&hostname) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, From 9b34d74df143527602a18b1564b554647dbf5eaf Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:18:36 -0700 Subject: [PATCH 006/126] add negative test --- tokio-postgres/src/config.rs | 6 ++++++ tokio-postgres/src/connect.rs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 923da2985..e5bed8ddf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -1051,4 +1051,10 @@ mod tests { assert_eq!(1, 1); } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 63574516c..888f9cf8a 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -55,7 +55,7 @@ where // if it's not present, use the value of hostaddr. let hostname = match host { Some(Host::Tcp(host)) => host.clone(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Some(Host::Unix(_)) => "".to_string(), None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), From 8ac10ff1de52281592d5bdd75e109d995ca33a2c Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Tue, 30 Aug 2022 22:10:19 -0700 Subject: [PATCH 007/126] move test to runtime --- tokio-postgres/tests/test/main.rs | 52 ---------------------------- tokio-postgres/tests/test/runtime.rs | 52 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 387c90d7c..0ab4a7bab 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,58 +147,6 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } -#[tokio::test] -async fn host_only_ok() { - let _ = tokio_postgres::connect( - "host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_only_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_and_host_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_mismatch() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_both_missing() { - let _ = tokio_postgres::connect( - "port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; From 3697f6b63c67073925e1db4d5bb74f1a4dc8c3f3 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Fri, 26 Aug 2022 13:46:04 -0700 Subject: [PATCH 008/126] Add hostaddr support --- tokio-postgres/src/config.rs | 70 +++++++++++++++++++++++++++++++++++ tokio-postgres/src/connect.rs | 23 +++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 5b364ec06..0c62b5030 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -91,6 +92,17 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -122,6 +134,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -158,6 +174,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -188,6 +205,7 @@ impl Config { application_name: None, ssl_mode: SslMode::Prefer, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, keepalives: true, @@ -298,6 +316,11 @@ impl Config { &self.host } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[String] { + self.hostaddr.deref() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -310,6 +333,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { + self.hostaddr.push(hostaddr.to_string()); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -455,6 +487,11 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + self.hostaddr(hostaddr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -593,6 +630,7 @@ impl fmt::Debug for Config { .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("keepalives", &self.keepalives) @@ -975,3 +1013,35 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + + assert_eq!(1, 1); + } + + #[test] + fn test_empty_hostaddrs() { + let s = + "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; + let config = s.parse::().unwrap(); + assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 97a00c812..c36677234 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -23,6 +23,15 @@ where return Err(Error::config("invalid number of ports".into())); } + if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + let msg = format!( + "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", + config.hostaddr.len(), + config.host.len(), + ); + return Err(Error::config(msg.into())); + } + let mut error = None; for (i, host) in config.host.iter().enumerate() { let port = config @@ -32,18 +41,28 @@ where .copied() .unwrap_or(5432); + // The value of host is always used as the hostname for TLS validation. let hostname = match host { Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Host::Unix(_) => "", }; - let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - match connect_once(host, port, tls, config).await { + // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. + let hostaddr = match host { + Host::Tcp(_hostname) => match config.hostaddr.get(i) { + Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), + _ => host.clone(), + }, + #[cfg(unix)] + Host::Unix(_v) => host.clone(), + }; + + match connect_once(&hostaddr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } From 48874dc5753e33f49508ba986d7f1d7bc74b4a74 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:40:57 -0700 Subject: [PATCH 009/126] IpAddr + try hostaddr first --- tokio-postgres/src/config.rs | 36 ++++++++++-------- tokio-postgres/src/connect.rs | 61 +++++++++++++++++++------------ tokio-postgres/tests/test/main.rs | 52 ++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0c62b5030..34accdbe8 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,6 +13,7 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; @@ -98,7 +99,9 @@ pub enum Host { /// - or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications /// with time constraints. However, a host name is required for verify-full SSL certificate verification. -/// Note that `host` is always required regardless of whether `hostaddr` is present. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; /// * If `host` is specified without `hostaddr`, a host name lookup occurs; /// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. /// The value for `host` is ignored unless the authentication method requires it, @@ -174,7 +177,7 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, - pub(crate) hostaddr: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) keepalives: bool, @@ -317,7 +320,7 @@ impl Config { } /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[String] { + pub fn get_hostaddrs(&self) -> &[IpAddr] { self.hostaddr.deref() } @@ -337,8 +340,8 @@ impl Config { /// /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config { - self.hostaddr.push(hostaddr.to_string()); + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); self } @@ -489,7 +492,10 @@ impl Config { } "hostaddr" => { for hostaddr in value.split(',') { - self.hostaddr(hostaddr); + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); } } "port" => { @@ -1016,6 +1022,8 @@ impl<'a> UrlParser<'a> { #[cfg(test)] mod tests { + use std::net::IpAddr; + use crate::{config::Host, Config}; #[test] @@ -1032,16 +1040,14 @@ mod tests { config.get_hosts(), ); - assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),); + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); assert_eq!(1, 1); } - - #[test] - fn test_empty_hostaddrs() { - let s = - "user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2"; - let config = s.parse::().unwrap(); - assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),); - } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index c36677234..ee1dc1c76 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -5,8 +5,8 @@ use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; use std::task::Poll; +use std::{cmp, io}; pub async fn connect( mut tls: T, @@ -15,25 +15,35 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { - return Err(Error::config("invalid number of ports".into())); - } - - if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { let msg = format!( - "invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})", - config.hostaddr.len(), + "number of hosts ({}) is different from number of hostaddrs ({})", config.host.len(), + config.hostaddr.len(), ); return Err(Error::config(msg.into())); } + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { + return Err(Error::config("invalid number of ports".into())); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in 0..num_hosts { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -42,27 +52,30 @@ where .unwrap_or(5432); // The value of host is always used as the hostname for TLS validation. + // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter let hostname = match host { - Host::Tcp(host) => host.as_str(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", + Some(Host::Tcp(host)) => host.as_str(), + _ => "", }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - // If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection. - let hostaddr = match host { - Host::Tcp(_hostname) => match config.hostaddr.get(i) { - Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()), - _ => host.clone(), - }, - #[cfg(unix)] - Host::Unix(_v) => host.clone(), + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => { + if let Some(host) = host { + host.clone() + } else { + // This is unreachable. + return Err(Error::config("both host and hostaddr are empty".into())); + } + } }; - match connect_once(&hostaddr, port, tls, config).await { + match connect_once(&addr, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..387c90d7c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,6 +147,58 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; From d97bed635ef3fe21a3d9dbef0945e57ab2baf8ba Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:11 -0700 Subject: [PATCH 010/126] also update postgres --- postgres/src/config.rs | 33 +++++++++++++++++++++++++++++++++ tokio-postgres/src/config.rs | 1 + 2 files changed, 34 insertions(+) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index b541ec846..a754ff91f 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,6 +6,7 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -39,6 +40,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// - or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -67,6 +81,10 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -204,6 +222,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -214,6 +233,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -226,6 +250,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 34accdbe8..923da2985 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -302,6 +302,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { From 1a9c1d4ff3e25b7bef01f05c3e396b2eec1564d9 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sat, 27 Aug 2022 11:55:47 -0700 Subject: [PATCH 011/126] fmt --- postgres/src/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a754ff91f..921566b66 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -83,7 +83,7 @@ use tokio_postgres::{Error, Socket}; /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write /// ``` -/// +/// /// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` @@ -236,7 +236,7 @@ impl Config { /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. pub fn get_hostaddrs(&self) -> &[IpAddr] { self.config.get_hostaddrs() - } + } /// Adds a Unix socket host to the configuration. /// From 58149dacf6f4633a3c2b24cda442623bd2abb08d Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:09:53 -0700 Subject: [PATCH 012/126] explicitly handle host being None --- tokio-postgres/src/connect.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ee1dc1c76..63574516c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -51,14 +51,17 @@ where .copied() .unwrap_or(5432); - // The value of host is always used as the hostname for TLS validation. - // postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter + // The value of host is used as the hostname for TLS validation, + // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.as_str(), - _ => "", + Some(Host::Tcp(host)) => host.clone(), + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + #[cfg(unix)] + Some(Host::Unix(_)) => "".to_string(), + None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(&hostname) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, From 7a648ad0cb911cb9144c0db441399f3189d28b3b Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Sun, 28 Aug 2022 12:18:36 -0700 Subject: [PATCH 013/126] add negative test --- tokio-postgres/src/config.rs | 6 ++++++ tokio-postgres/src/connect.rs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 923da2985..e5bed8ddf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -1051,4 +1051,10 @@ mod tests { assert_eq!(1, 1); } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 63574516c..888f9cf8a 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -55,7 +55,7 @@ where // if it's not present, use the value of hostaddr. let hostname = match host { Some(Host::Tcp(host)) => host.clone(), - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter Some() + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Some(Host::Unix(_)) => "".to_string(), None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), From a70a7c36c74bfeaf1e171dc2572fddd30d182179 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Tue, 30 Aug 2022 22:10:19 -0700 Subject: [PATCH 014/126] move test to runtime --- tokio-postgres/tests/test/main.rs | 52 ---------------------------- tokio-postgres/tests/test/runtime.rs | 52 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 387c90d7c..0ab4a7bab 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -147,58 +147,6 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } -#[tokio::test] -async fn host_only_ok() { - let _ = tokio_postgres::connect( - "host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_only_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_and_host_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_mismatch() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_both_missing() { - let _ = tokio_postgres::connect( - "port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - #[tokio::test] async fn pipelined_prepare() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; From 071dfa3f3b217a32b1e2ab3db9e6ab5132f2fcd1 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Sun, 26 Mar 2023 20:33:29 +1100 Subject: [PATCH 015/126] added a rename_all container attribute for enums and structs --- postgres-derive-test/src/composites.rs | 43 +++++++ postgres-derive-test/src/enums.rs | 29 +++++ postgres-derive/src/case.rs | 158 +++++++++++++++++++++++++ postgres-derive/src/composites.rs | 26 ++-- postgres-derive/src/enums.rs | 13 +- postgres-derive/src/fromsql.rs | 9 +- postgres-derive/src/lib.rs | 1 + postgres-derive/src/overrides.rs | 32 ++++- postgres-derive/src/tosql.rs | 9 +- 9 files changed, 299 insertions(+), 21 deletions(-) create mode 100644 postgres-derive/src/case.rs diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index a1b76345f..50a22790d 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -89,6 +89,49 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")] + struct InventoryItem { + name: String, + supplier_id: i32, + #[postgres(name = "Price")] + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + \"NAME\" TEXT, + \"SUPPLIER_ID\" INT, + \"Price\" DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + #[test] fn wrong_name() { #[derive(FromSql, ToSql, Debug, PartialEq)] diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index a7039ca05..e44f37616 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -53,6 +53,35 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood", rename_all = "snake_case")] + enum Mood { + Sad, + #[postgres(name = "okay")] + Ok, + Happy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::Sad, "'sad'"), + (Mood::Ok, "'okay'"), + (Mood::Happy, "'happy'"), + ], + ); +} + #[test] fn wrong_name() { #[derive(Debug, ToSql, FromSql, PartialEq)] diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs new file mode 100644 index 000000000..b128990c5 --- /dev/null +++ b/postgres-derive/src/case.rs @@ -0,0 +1,158 @@ +#[allow(deprecated, unused_imports)] +use std::ascii::AsciiExt; + +use self::RenameRule::*; + +/// The different possible ways to change case of fields in a struct, or variants in an enum. +#[allow(clippy::enum_variant_names)] +#[derive(Copy, Clone, PartialEq)] +pub enum RenameRule { + /// Rename direct children to "lowercase" style. + LowerCase, + /// Rename direct children to "UPPERCASE" style. + UpperCase, + /// Rename direct children to "PascalCase" style, as typically used for + /// enum variants. + PascalCase, + /// Rename direct children to "camelCase" style. + CamelCase, + /// Rename direct children to "snake_case" style, as commonly used for + /// fields. + SnakeCase, + /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly + /// used for constants. + ScreamingSnakeCase, + /// Rename direct children to "kebab-case" style. + KebabCase, + /// Rename direct children to "SCREAMING-KEBAB-CASE" style. + ScreamingKebabCase, +} + +pub static RENAME_RULES: &[(&str, RenameRule)] = &[ + ("lowercase", LowerCase), + ("UPPERCASE", UpperCase), + ("PascalCase", PascalCase), + ("camelCase", CamelCase), + ("snake_case", SnakeCase), + ("SCREAMING_SNAKE_CASE", ScreamingSnakeCase), + ("kebab-case", KebabCase), + ("SCREAMING-KEBAB-CASE", ScreamingKebabCase), +]; + +impl RenameRule { + /// Apply a renaming rule to an enum variant, returning the version expected in the source. + pub fn apply_to_variant(&self, variant: &str) -> String { + match *self { + PascalCase => variant.to_owned(), + LowerCase => variant.to_ascii_lowercase(), + UpperCase => variant.to_ascii_uppercase(), + CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], + SnakeCase => { + let mut snake = String::new(); + for (i, ch) in variant.char_indices() { + if i > 0 && ch.is_uppercase() { + snake.push('_'); + } + snake.push(ch.to_ascii_lowercase()); + } + snake + } + ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), + KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), + ScreamingKebabCase => ScreamingSnakeCase + .apply_to_variant(variant) + .replace('_', "-"), + } + } + + /// Apply a renaming rule to a struct field, returning the version expected in the source. + pub fn apply_to_field(&self, field: &str) -> String { + match *self { + LowerCase | SnakeCase => field.to_owned(), + UpperCase => field.to_ascii_uppercase(), + PascalCase => { + let mut pascal = String::new(); + let mut capitalize = true; + for ch in field.chars() { + if ch == '_' { + capitalize = true; + } else if capitalize { + pascal.push(ch.to_ascii_uppercase()); + capitalize = false; + } else { + pascal.push(ch); + } + } + pascal + } + CamelCase => { + let pascal = PascalCase.apply_to_field(field); + pascal[..1].to_ascii_lowercase() + &pascal[1..] + } + ScreamingSnakeCase => field.to_ascii_uppercase(), + KebabCase => field.replace('_', "-"), + ScreamingKebabCase => ScreamingSnakeCase.apply_to_field(field).replace('_', "-"), + } + } +} + +#[test] +fn rename_variants() { + for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ + ( + "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "VeryTasty", + "verytasty", + "VERYTASTY", + "veryTasty", + "very_tasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("A", "a", "A", "a", "a", "A", "a", "A"), + ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(LowerCase.apply_to_variant(original), lower); + assert_eq!(UpperCase.apply_to_variant(original), upper); + assert_eq!(PascalCase.apply_to_variant(original), original); + assert_eq!(CamelCase.apply_to_variant(original), camel); + assert_eq!(SnakeCase.apply_to_variant(original), snake); + assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); + assert_eq!(KebabCase.apply_to_variant(original), kebab); + assert_eq!( + ScreamingKebabCase.apply_to_variant(original), + screaming_kebab + ); + } +} + +#[test] +fn rename_fields() { + for &(original, upper, pascal, camel, screaming, kebab, screaming_kebab) in &[ + ( + "outcome", "OUTCOME", "Outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "very_tasty", + "VERY_TASTY", + "VeryTasty", + "veryTasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("a", "A", "A", "a", "A", "a", "A"), + ("z42", "Z42", "Z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(UpperCase.apply_to_field(original), upper); + assert_eq!(PascalCase.apply_to_field(original), pascal); + assert_eq!(CamelCase.apply_to_field(original), camel); + assert_eq!(SnakeCase.apply_to_field(original), original); + assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); + assert_eq!(KebabCase.apply_to_field(original), kebab); + assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); + } +} diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index 15bfabc13..dcff2c581 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -4,7 +4,7 @@ use syn::{ TypeParamBound, }; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Field { pub name: String, @@ -13,18 +13,26 @@ pub struct Field { } impl Field { - pub fn parse(raw: &syn::Field) -> Result { + pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { let overrides = Overrides::extract(&raw.attrs)?; - let ident = raw.ident.as_ref().unwrap().clone(); - Ok(Field { - name: overrides.name.unwrap_or_else(|| { + + // field level name override takes precendence over container level rename_all override + let name = match overrides.name { + Some(n) => n, + None => { let name = ident.to_string(); - match name.strip_prefix("r#") { - Some(name) => name.to_string(), - None => name, + let stripped = name.strip_prefix("r#").map(String::from).unwrap_or(name); + + match rename_all { + Some(rule) => rule.apply_to_field(&stripped), + None => stripped, } - }), + } + }; + + Ok(Field { + name, ident, type_: raw.ty.clone(), }) diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3c6bc7113..d99eca1c4 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -1,6 +1,6 @@ use syn::{Error, Fields, Ident}; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Variant { pub ident: Ident, @@ -8,7 +8,7 @@ pub struct Variant { } impl Variant { - pub fn parse(raw: &syn::Variant) -> Result { + pub fn parse(raw: &syn::Variant, rename_all: Option) -> Result { match raw.fields { Fields::Unit => {} _ => { @@ -18,11 +18,16 @@ impl Variant { )) } } - let overrides = Overrides::extract(&raw.attrs)?; + + // variant level name override takes precendence over container level rename_all override + let name = overrides.name.unwrap_or_else(|| match rename_all { + Some(rule) => rule.apply_to_variant(&raw.ident.to_string()), + None => raw.ident.to_string(), + }); Ok(Variant { ident: raw.ident.clone(), - name: overrides.name.unwrap_or_else(|| raw.ident.to_string()), + name, }) } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index bb87ded5f..3736e01e9 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -24,7 +24,10 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -51,7 +54,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -75,7 +78,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "FromSql", &fields), diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs index 98e6add24..b849096c9 100644 --- a/postgres-derive/src/lib.rs +++ b/postgres-derive/src/lib.rs @@ -7,6 +7,7 @@ use proc_macro::TokenStream; use syn::parse_macro_input; mod accepts; +mod case; mod composites; mod enums; mod fromsql; diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index ddb37688b..3918446a2 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -1,8 +1,11 @@ use syn::punctuated::Punctuated; use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token}; +use crate::case::{RenameRule, RENAME_RULES}; + pub struct Overrides { pub name: Option, + pub rename_all: Option, pub transparent: bool, } @@ -10,6 +13,7 @@ impl Overrides { pub fn extract(attrs: &[Attribute]) -> Result { let mut overrides = Overrides { name: None, + rename_all: None, transparent: false, }; @@ -28,7 +32,9 @@ impl Overrides { for item in nested { match item { Meta::NameValue(meta) => { - if !meta.path.is_ident("name") { + let name_override = meta.path.is_ident("name"); + let rename_all_override = meta.path.is_ident("rename_all"); + if !name_override && !rename_all_override { return Err(Error::new_spanned(&meta.path, "unknown override")); } @@ -41,7 +47,29 @@ impl Overrides { } }; - overrides.name = Some(value); + if name_override { + overrides.name = Some(value); + } else if rename_all_override { + let rename_rule = RENAME_RULES + .iter() + .find(|rule| rule.0 == value) + .map(|val| val.1) + .ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule.0)) + .collect::>() + .join(", ") + ), + ) + })?; + + overrides.rename_all = Some(rename_rule); + } } Meta::Path(path) => { if !path.is_ident("transparent") { diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index e51acc7fd..1e91df4f6 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -22,7 +22,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -47,7 +50,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -69,7 +72,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "ToSql", &fields), From bc8ad8aee69f14e367de2f42c8d3a61c1d9c144b Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Mon, 27 Mar 2023 18:22:53 +1100 Subject: [PATCH 016/126] Distinguish between field and container attributes when parsing --- postgres-derive/src/composites.rs | 2 +- postgres-derive/src/enums.rs | 2 +- postgres-derive/src/fromsql.rs | 2 +- postgres-derive/src/overrides.rs | 8 +++++++- postgres-derive/src/tosql.rs | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index dcff2c581..b6aad8ab3 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -14,7 +14,7 @@ pub struct Field { impl Field { pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { - let overrides = Overrides::extract(&raw.attrs)?; + let overrides = Overrides::extract(&raw.attrs, false)?; let ident = raw.ident.as_ref().unwrap().clone(); // field level name override takes precendence over container level rename_all override diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index d99eca1c4..3e4b5045f 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -18,7 +18,7 @@ impl Variant { )) } } - let overrides = Overrides::extract(&raw.attrs)?; + let overrides = Overrides::extract(&raw.attrs, false)?; // variant level name override takes precendence over container level rename_all override let name = overrides.name.unwrap_or_else(|| match rename_all { diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index 3736e01e9..4deb23ed2 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -15,7 +15,7 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 3918446a2..7f28375bc 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -10,7 +10,7 @@ pub struct Overrides { } impl Overrides { - pub fn extract(attrs: &[Attribute]) -> Result { + pub fn extract(attrs: &[Attribute], container_attr: bool) -> Result { let mut overrides = Overrides { name: None, rename_all: None, @@ -34,6 +34,12 @@ impl Overrides { Meta::NameValue(meta) => { let name_override = meta.path.is_ident("name"); let rename_all_override = meta.path.is_ident("rename_all"); + if !container_attr && rename_all_override { + return Err(Error::new_spanned( + &meta.path, + "rename_all is a container attribute", + )); + } if !name_override && !rename_all_override { return Err(Error::new_spanned(&meta.path, "unknown override")); } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index 1e91df4f6..dbeeb16c3 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -13,7 +13,7 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( From d509b3bc52df9cf0d7f1f2ac5ac64b0bfc643160 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Mon, 27 Mar 2023 18:45:05 +1100 Subject: [PATCH 017/126] Replaced case conversion with heck --- postgres-derive/Cargo.toml | 1 + postgres-derive/src/case.rs | 138 ++++++++++--------------------- postgres-derive/src/enums.rs | 2 +- postgres-derive/src/overrides.rs | 30 +++---- 4 files changed, 60 insertions(+), 111 deletions(-) diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 8470bc8a9..cfc8829f4 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -15,3 +15,4 @@ test = false syn = "2.0" proc-macro2 = "1.0" quote = "1.0" +heck = "0.4" \ No newline at end of file diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs index b128990c5..20ecc8eed 100644 --- a/postgres-derive/src/case.rs +++ b/postgres-derive/src/case.rs @@ -1,6 +1,11 @@ #[allow(deprecated, unused_imports)] use std::ascii::AsciiExt; +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTrainCase, + ToUpperCamelCase, +}; + use self::RenameRule::*; /// The different possible ways to change case of fields in a struct, or variants in an enum. @@ -26,78 +31,56 @@ pub enum RenameRule { KebabCase, /// Rename direct children to "SCREAMING-KEBAB-CASE" style. ScreamingKebabCase, + + /// Rename direct children to "Train-Case" style. + TrainCase, } -pub static RENAME_RULES: &[(&str, RenameRule)] = &[ - ("lowercase", LowerCase), - ("UPPERCASE", UpperCase), - ("PascalCase", PascalCase), - ("camelCase", CamelCase), - ("snake_case", SnakeCase), - ("SCREAMING_SNAKE_CASE", ScreamingSnakeCase), - ("kebab-case", KebabCase), - ("SCREAMING-KEBAB-CASE", ScreamingKebabCase), +pub const RENAME_RULES: &[&str] = &[ + "lowercase", + "UPPERCASE", + "PascalCase", + "camelCase", + "snake_case", + "SCREAMING_SNAKE_CASE", + "kebab-case", + "SCREAMING-KEBAB-CASE", + "Train-Case", ]; impl RenameRule { - /// Apply a renaming rule to an enum variant, returning the version expected in the source. - pub fn apply_to_variant(&self, variant: &str) -> String { - match *self { - PascalCase => variant.to_owned(), - LowerCase => variant.to_ascii_lowercase(), - UpperCase => variant.to_ascii_uppercase(), - CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], - SnakeCase => { - let mut snake = String::new(); - for (i, ch) in variant.char_indices() { - if i > 0 && ch.is_uppercase() { - snake.push('_'); - } - snake.push(ch.to_ascii_lowercase()); - } - snake - } - ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), - KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), - ScreamingKebabCase => ScreamingSnakeCase - .apply_to_variant(variant) - .replace('_', "-"), + pub fn from_str(rule: &str) -> Option { + match rule { + "lowercase" => Some(LowerCase), + "UPPERCASE" => Some(UpperCase), + "PascalCase" => Some(PascalCase), + "camelCase" => Some(CamelCase), + "snake_case" => Some(SnakeCase), + "SCREAMING_SNAKE_CASE" => Some(ScreamingSnakeCase), + "kebab-case" => Some(KebabCase), + "SCREAMING-KEBAB-CASE" => Some(ScreamingKebabCase), + "Train-Case" => Some(TrainCase), + _ => None, } } - - /// Apply a renaming rule to a struct field, returning the version expected in the source. - pub fn apply_to_field(&self, field: &str) -> String { + /// Apply a renaming rule to an enum or struct field, returning the version expected in the source. + pub fn apply_to_field(&self, variant: &str) -> String { match *self { - LowerCase | SnakeCase => field.to_owned(), - UpperCase => field.to_ascii_uppercase(), - PascalCase => { - let mut pascal = String::new(); - let mut capitalize = true; - for ch in field.chars() { - if ch == '_' { - capitalize = true; - } else if capitalize { - pascal.push(ch.to_ascii_uppercase()); - capitalize = false; - } else { - pascal.push(ch); - } - } - pascal - } - CamelCase => { - let pascal = PascalCase.apply_to_field(field); - pascal[..1].to_ascii_lowercase() + &pascal[1..] - } - ScreamingSnakeCase => field.to_ascii_uppercase(), - KebabCase => field.replace('_', "-"), - ScreamingKebabCase => ScreamingSnakeCase.apply_to_field(field).replace('_', "-"), + LowerCase => variant.to_lowercase(), + UpperCase => variant.to_uppercase(), + PascalCase => variant.to_upper_camel_case(), + CamelCase => variant.to_lower_camel_case(), + SnakeCase => variant.to_snake_case(), + ScreamingSnakeCase => variant.to_shouty_snake_case(), + KebabCase => variant.to_kebab_case(), + ScreamingKebabCase => variant.to_shouty_kebab_case(), + TrainCase => variant.to_train_case(), } } } #[test] -fn rename_variants() { +fn rename_field() { for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ ( "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", @@ -115,42 +98,11 @@ fn rename_variants() { ("A", "a", "A", "a", "a", "A", "a", "A"), ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), ] { - assert_eq!(LowerCase.apply_to_variant(original), lower); - assert_eq!(UpperCase.apply_to_variant(original), upper); - assert_eq!(PascalCase.apply_to_variant(original), original); - assert_eq!(CamelCase.apply_to_variant(original), camel); - assert_eq!(SnakeCase.apply_to_variant(original), snake); - assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); - assert_eq!(KebabCase.apply_to_variant(original), kebab); - assert_eq!( - ScreamingKebabCase.apply_to_variant(original), - screaming_kebab - ); - } -} - -#[test] -fn rename_fields() { - for &(original, upper, pascal, camel, screaming, kebab, screaming_kebab) in &[ - ( - "outcome", "OUTCOME", "Outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", - ), - ( - "very_tasty", - "VERY_TASTY", - "VeryTasty", - "veryTasty", - "VERY_TASTY", - "very-tasty", - "VERY-TASTY", - ), - ("a", "A", "A", "a", "A", "a", "A"), - ("z42", "Z42", "Z42", "z42", "Z42", "z42", "Z42"), - ] { + assert_eq!(LowerCase.apply_to_field(original), lower); assert_eq!(UpperCase.apply_to_field(original), upper); - assert_eq!(PascalCase.apply_to_field(original), pascal); + assert_eq!(PascalCase.apply_to_field(original), original); assert_eq!(CamelCase.apply_to_field(original), camel); - assert_eq!(SnakeCase.apply_to_field(original), original); + assert_eq!(SnakeCase.apply_to_field(original), snake); assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); assert_eq!(KebabCase.apply_to_field(original), kebab); assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3e4b5045f..9a6dfa926 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -22,7 +22,7 @@ impl Variant { // variant level name override takes precendence over container level rename_all override let name = overrides.name.unwrap_or_else(|| match rename_all { - Some(rule) => rule.apply_to_variant(&raw.ident.to_string()), + Some(rule) => rule.apply_to_field(&raw.ident.to_string()), None => raw.ident.to_string(), }); Ok(Variant { diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 7f28375bc..99faeebb7 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -56,23 +56,19 @@ impl Overrides { if name_override { overrides.name = Some(value); } else if rename_all_override { - let rename_rule = RENAME_RULES - .iter() - .find(|rule| rule.0 == value) - .map(|val| val.1) - .ok_or_else(|| { - Error::new_spanned( - &meta.value, - format!( - "invalid rename_all rule, expected one of: {}", - RENAME_RULES - .iter() - .map(|rule| format!("\"{}\"", rule.0)) - .collect::>() - .join(", ") - ), - ) - })?; + let rename_rule = RenameRule::from_str(&value).ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule)) + .collect::>() + .join(", ") + ), + ) + })?; overrides.rename_all = Some(rename_rule); } From f4b181a20180f1853351be53a32865b6209d0ab4 Mon Sep 17 00:00:00 2001 From: jaydenelliott Date: Tue, 28 Mar 2023 22:25:50 +1100 Subject: [PATCH 018/126] Rename_all attribute documentation --- postgres-derive-test/src/enums.rs | 10 +++++----- postgres-derive/src/fromsql.rs | 4 ++-- postgres-derive/src/tosql.rs | 4 ++-- postgres-types/src/lib.rs | 31 +++++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index e44f37616..36d428437 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -58,15 +58,15 @@ fn rename_all_overrides() { #[derive(Debug, ToSql, FromSql, PartialEq)] #[postgres(name = "mood", rename_all = "snake_case")] enum Mood { - Sad, + VerySad, #[postgres(name = "okay")] Ok, - Happy, + VeryHappy, } let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); conn.execute( - "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')", + "CREATE TYPE pg_temp.mood AS ENUM ('very_sad', 'okay', 'very_happy')", &[], ) .unwrap(); @@ -75,9 +75,9 @@ fn rename_all_overrides() { &mut conn, "mood", &[ - (Mood::Sad, "'sad'"), + (Mood::VerySad, "'very_sad'"), (Mood::Ok, "'okay'"), - (Mood::Happy, "'happy'"), + (Mood::VeryHappy, "'very_happy'"), ], ); } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index 4deb23ed2..a9150411a 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -17,10 +17,10 @@ use crate::overrides::Overrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index dbeeb16c3..ec7602312 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -15,10 +15,10 @@ use crate::overrides::Overrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..5fca049a7 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -125,6 +125,37 @@ //! Happy, //! } //! ``` +//! +//! Alternatively, the `#[postgres(rename_all = "...")]` attribute can be used to rename all fields or variants +//! with the chosen casing convention. This will not affect the struct or enum's type name. Note that +//! `#[postgres(name = "...")]` takes precendence when used in conjunction with `#[postgres(rename_all = "...")]`: +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(name = "mood", rename_all = "snake_case")] +//! enum Mood { +//! VerySad, // very_sad +//! #[postgres(name = "ok")] +//! Ok, // ok +//! VeryHappy, // very_happy +//! } +//! ``` +//! +//! The following case conventions are supported: +//! - `"lowercase"` +//! - `"UPPERCASE"` +//! - `"PascalCase"` +//! - `"camelCase"` +//! - `"snake_case"` +//! - `"SCREAMING_SNAKE_CASE"` +//! - `"kebab-case"` +//! - `"SCREAMING-KEBAB-CASE"` +//! - `"Train-Case"` + #![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] From b19fdd4b7ecab1e30e56f55dc95de8d53f9d14da Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Thu, 30 Mar 2023 19:30:40 -0400 Subject: [PATCH 019/126] Fix postgres-protocol constraint Closes #1012 --- tokio-postgres/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index e5451e2a2..4dc93e3a2 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -53,7 +53,7 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } From 45d51d708c645f0ebbd3d0dcf5f3eaad3d461916 Mon Sep 17 00:00:00 2001 From: Niklas Hallqvist Date: Tue, 4 Apr 2023 14:27:45 +0200 Subject: [PATCH 020/126] OpenBSD misses some TCP keepalive options --- tokio-postgres/src/keepalive.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 74f453985..24d8d2c0e 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -12,12 +12,12 @@ impl From<&KeepaliveConfig> for TcpKeepalive { fn from(keepalive_config: &KeepaliveConfig) -> Self { let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); - #[cfg(not(any(target_os = "redox", target_os = "solaris")))] + #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "openbsd")))] if let Some(interval) = keepalive_config.interval { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows")))] + #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows", target_os = "openbsd")))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } From e59a16524190db45eead594c61b6a9012ad3a3b9 Mon Sep 17 00:00:00 2001 From: Niklas Hallqvist Date: Tue, 4 Apr 2023 15:43:39 +0200 Subject: [PATCH 021/126] rustfmt --- tokio-postgres/src/keepalive.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 24d8d2c0e..c409eb0ea 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -17,7 +17,12 @@ impl From<&KeepaliveConfig> for TcpKeepalive { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows", target_os = "openbsd")))] + #[cfg(not(any( + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "openbsd" + )))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } From a67fe643a9dc483530ba1df5cf09e3dfdec90c98 Mon Sep 17 00:00:00 2001 From: Basti Ortiz <39114273+BastiDood@users.noreply.github.com> Date: Fri, 7 Apr 2023 21:39:37 +0800 Subject: [PATCH 022/126] refactor(types): simplify `<&str as ToSql>::to_sql` --- postgres-types/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..c34fbe66d 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1012,10 +1012,10 @@ impl ToSql for Vec { impl<'a> ToSql for &'a str { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match *ty { - ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), - ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), - ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), + match ty.name() { + "ltree" => types::ltree_to_sql(self, w), + "lquery" => types::lquery_to_sql(self, w), + "ltxtquery" => types::ltxtquery_to_sql(self, w), _ => types::text_to_sql(self, w), } Ok(IsNull::No) From 98abdf9fa25a2e908fd62c5961655e00989fafa2 Mon Sep 17 00:00:00 2001 From: Basti Ortiz <39114273+BastiDood@users.noreply.github.com> Date: Fri, 7 Apr 2023 21:43:25 +0800 Subject: [PATCH 023/126] refactor(types): prefer `matches!` macro for readability --- postgres-types/src/lib.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index c34fbe66d..291e069da 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1022,18 +1022,10 @@ impl<'a> ToSql for &'a str { } fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty - if (ty.name() == "citext" - || ty.name() == "ltree" - || ty.name() == "lquery" - || ty.name() == "ltxtquery") => - { - true - } - _ => false, - } + matches!( + *ty, + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN + ) || matches!(ty.name(), "citext" | "ltree" | "lquery" | "ltxtquery") } to_sql_checked!(); From e71335ee43978311b2c1f253afef6c92abdaac88 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 1 May 2023 19:33:49 -0400 Subject: [PATCH 024/126] fix serialization of oidvector --- postgres-types/src/lib.rs | 8 +++++++- tokio-postgres/src/connect_socket.rs | 4 +++- tokio-postgres/tests/test/types/mod.rs | 11 +++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 291e069da..c4c448c4a 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -910,9 +910,15 @@ impl<'a, T: ToSql> ToSql for &'a [T] { _ => panic!("expected array type"), }; + // Arrays are normally one indexed by default but oidvector *requires* zero indexing + let lower_bound = match *ty { + Type::OID_VECTOR => 0, + _ => 1, + }; + let dimension = ArrayDimension { len: downcast(self.len())?, - lower_bound: 1, + lower_bound, }; types::array_to_sql( diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 9b3d31d72..1204ca1ff 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -14,7 +14,9 @@ pub(crate) async fn connect_socket( host: &Host, port: u16, connect_timeout: Option, - tcp_user_timeout: Option, + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< + Duration, + >, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { match host { diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 452d149fe..0f1d38242 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -739,3 +739,14 @@ async fn ltxtquery_any() { ) .await; } + +#[tokio::test] +async fn oidvector() { + test_type( + "oidvector", + // NB: postgres does not support empty oidarrays! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0u32, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +} From d92b3b0a63e7abba41d56cebd06356d1a50db879 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 1 May 2023 19:45:54 -0400 Subject: [PATCH 025/126] Fix int2vector serialization --- postgres-types/src/lib.rs | 4 ++-- tokio-postgres/tests/test/types/mod.rs | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index c4c448c4a..b03c389a9 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -910,9 +910,9 @@ impl<'a, T: ToSql> ToSql for &'a [T] { _ => panic!("expected array type"), }; - // Arrays are normally one indexed by default but oidvector *requires* zero indexing + // Arrays are normally one indexed by default but oidvector and int2vector *require* zero indexing let lower_bound = match *ty { - Type::OID_VECTOR => 0, + Type::OID_VECTOR | Type::INT2_VECTOR => 0, _ => 1, }; diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 0f1d38242..f1a44da08 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -750,3 +750,14 @@ async fn oidvector() { ) .await; } + +#[tokio::test] +async fn int2vector() { + test_type( + "int2vector", + // NB: postgres does not support empty int2vectors! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0i16, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +} From 80adf0448b95548dabd8354ae6988f801e7a5965 Mon Sep 17 00:00:00 2001 From: Ibiyemi Abiodun Date: Sun, 7 May 2023 13:37:52 -0400 Subject: [PATCH 026/126] allow `BorrowToSql` for non-static `Box` --- postgres-types/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 291e069da..6517b4a95 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1178,17 +1178,17 @@ impl BorrowToSql for &dyn ToSql { } } -impl sealed::Sealed for Box {} +impl<'a> sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() } } -impl sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> sealed::Sealed for Box {} +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() From 066b466f4443d0d51c6b1d409f3a2c93019ca27e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 7 May 2023 13:48:50 -0400 Subject: [PATCH 027/126] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8044b2f47..8e91c6faf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.64.0 + version: 1.65.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From 40954901a422838800a0f99608bf0ab308e5e9aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 14:01:30 +0000 Subject: [PATCH 028/126] Update criterion requirement from 0.4 to 0.5 Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version. - [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md) - [Commits](https://github.com/bheisler/criterion.rs/compare/0.4.0...0.5.0) --- updated-dependencies: - dependency-name: criterion dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- postgres/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index e0b2a249d..044bb91e1 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -45,5 +45,5 @@ tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" [dev-dependencies] -criterion = "0.4" +criterion = "0.5" tokio = { version = "1.0", features = ["rt-multi-thread"] } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 4dc93e3a2..b5c6d0ae6 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -61,7 +61,7 @@ tokio-util = { version = "0.7", features = ["codec"] } [dev-dependencies] futures-executor = "0.3" -criterion = "0.4" +criterion = "0.5" env_logger = "0.10" tokio = { version = "1.0", features = [ "macros", From 64bf779f7c91524b820e60226a6b8c8075d2dfa4 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sat, 3 Jun 2023 09:18:58 -0400 Subject: [PATCH 029/126] feat: add support for wasm Adds support for compiling to WASM environments that provide JS via wasm-bindgen. Because there's no standardized socket API the caller must provide a connection that implements AsyncRead/AsyncWrite to connect_raw. --- Cargo.toml | 1 + postgres-protocol/Cargo.toml | 3 +++ tokio-postgres/Cargo.toml | 4 +++- tokio-postgres/src/config.rs | 42 ++++++++++++++++++++++++++---------- tokio-postgres/src/lib.rs | 1 + 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..80a7739c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "codegen", "postgres", diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index e32211369..1c6422e7d 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -19,3 +19,6 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2.9", features = ["js"] } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index b5c6d0ae6..af0e6dee0 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -55,10 +55,12 @@ pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } -socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +socket2 = { version = "0.5", features = ["all"] } + [dev-dependencies] futures-executor = "0.3" criterion = "0.5" diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a8aa7a9f5..2b2be08ef 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -3,6 +3,7 @@ #[cfg(feature = "runtime")] use crate::connect::connect; use crate::connect_raw::connect_raw; +#[cfg(not(target_arch = "wasm32"))] use crate::keepalive::KeepaliveConfig; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -165,6 +166,7 @@ pub struct Config { pub(crate) connect_timeout: Option, pub(crate) tcp_user_timeout: Option, pub(crate) keepalives: bool, + #[cfg(not(target_arch = "wasm32"))] pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, @@ -179,11 +181,6 @@ impl Default for Config { impl Config { /// Creates a new configuration. pub fn new() -> Config { - let keepalive_config = KeepaliveConfig { - idle: Duration::from_secs(2 * 60 * 60), - interval: None, - retries: None, - }; Config { user: None, password: None, @@ -196,7 +193,12 @@ impl Config { connect_timeout: None, tcp_user_timeout: None, keepalives: true, - keepalive_config, + #[cfg(not(target_arch = "wasm32"))] + keepalive_config: KeepaliveConfig { + idle: Duration::from_secs(2 * 60 * 60), + interval: None, + retries: None, + }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, } @@ -377,6 +379,7 @@ impl Config { /// Sets the amount of idle time before a keepalive packet is sent on the connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { self.keepalive_config.idle = keepalives_idle; self @@ -384,6 +387,7 @@ impl Config { /// Gets the configured amount of idle time before a keepalive packet will /// be sent on the connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_idle(&self) -> Duration { self.keepalive_config.idle } @@ -392,12 +396,14 @@ impl Config { /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { self.keepalive_config.interval = Some(keepalives_interval); self } /// Gets the time interval between TCP keepalive probes. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_interval(&self) -> Option { self.keepalive_config.interval } @@ -405,12 +411,14 @@ impl Config { /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { self.keepalive_config.retries = Some(keepalives_retries); self } /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_retries(&self) -> Option { self.keepalive_config.retries } @@ -503,12 +511,14 @@ impl Config { self.tcp_user_timeout(Duration::from_secs(timeout as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives" => { let keepalives = value .parse::() .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?; self.keepalives(keepalives != 0); } + #[cfg(not(target_arch = "wasm32"))] "keepalives_idle" => { let keepalives_idle = value .parse::() @@ -517,6 +527,7 @@ impl Config { self.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_interval" => { let keepalives_interval = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_interval"))) @@ -525,6 +536,7 @@ impl Config { self.keepalives_interval(Duration::from_secs(keepalives_interval as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_retries" => { let keepalives_retries = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_retries"))) @@ -614,7 +626,8 @@ impl fmt::Debug for Config { } } - f.debug_struct("Config") + let mut config_dbg = &mut f.debug_struct("Config"); + config_dbg = config_dbg .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) .field("dbname", &self.dbname) @@ -625,10 +638,17 @@ impl fmt::Debug for Config { .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) - .field("keepalives", &self.keepalives) - .field("keepalives_idle", &self.keepalive_config.idle) - .field("keepalives_interval", &self.keepalive_config.interval) - .field("keepalives_retries", &self.keepalive_config.retries) + .field("keepalives", &self.keepalives); + + #[cfg(not(target_arch = "wasm32"))] + { + config_dbg = config_dbg + .field("keepalives_idle", &self.keepalive_config.idle) + .field("keepalives_interval", &self.keepalive_config.interval) + .field("keepalives_retries", &self.keepalive_config.retries); + } + + config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) .finish() diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..2bb410187 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -163,6 +163,7 @@ mod copy_in; mod copy_out; pub mod error; mod generic_client; +#[cfg(not(target_arch = "wasm32"))] mod keepalive; mod maybe_tls_stream; mod portal; From 2230e88533acccf5632b2d43aff315c88a2507a2 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sat, 3 Jun 2023 17:32:48 -0400 Subject: [PATCH 030/126] add CI job for checking wasm Adds a CI job for ensuring the tokio-postgres crate builds on the wasm32-unknown-unknown target without the default features. --- .github/workflows/ci.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e91c6faf..46f97e48f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,33 @@ jobs: key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - run: cargo clippy --all --all-targets + check-wasm32: + name: check-wasm32 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - run: rustup target add wasm32-unknown-unknown + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + test: name: test runs-on: ubuntu-latest From edc7fdecfb9f81b923bfe904edefd41e7076fa8c Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Sun, 4 Jun 2023 13:02:03 -0400 Subject: [PATCH 031/126] gate wasm support behind feature flag --- Cargo.toml | 1 - postgres-protocol/Cargo.toml | 8 +++++--- tokio-postgres/Cargo.toml | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 80a7739c8..4752836a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,4 @@ [workspace] -resolver = "2" members = [ "codegen", "postgres", diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 1c6422e7d..ad609f6fa 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -8,6 +8,10 @@ license = "MIT/Apache-2.0" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" +[features] +default = [] +js = ["getrandom/js"] + [dependencies] base64 = "0.21" byteorder = "1.0" @@ -19,6 +23,4 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" - -[target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2.9", features = ["js"] } +getrandom = { version = "0.2", optional = true } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index af0e6dee0..12d8a66fd 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -40,6 +40,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] with-uuid-1 = ["postgres-types/with-uuid-1"] with-time-0_2 = ["postgres-types/with-time-0_2"] with-time-0_3 = ["postgres-types/with-time-0_3"] +js = ["postgres-protocol/js"] [dependencies] async-trait = "0.1" From 1f8fb7a16c131ed50a46fc139838327e8a604775 Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Wed, 7 Jun 2023 21:17:54 -0400 Subject: [PATCH 032/126] ignore dev deps in wasm ci --- .github/workflows/ci.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 46f97e48f..99cf652d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master - + clippy: name: clippy runs-on: ubuntu-latest @@ -72,7 +72,12 @@ jobs: with: path: target key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + - run: | + # Hack: wasm support currently relies on not having tokio with features like socket enabled. With resolver 1 + # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. + + sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml + cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features test: name: test From 635bac4665d4a744a523e6d843f67ffed33b6cff Mon Sep 17 00:00:00 2001 From: Zeb Piasecki Date: Fri, 9 Jun 2023 11:15:06 -0400 Subject: [PATCH 033/126] specify js feature for wasm ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 99cf652d2..0064369c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,7 +77,7 @@ jobs: # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml - cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features + cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js test: name: test From 6f19bb9000bd5e53cd7613f0f96a24c3657533b6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 10 Jun 2023 10:21:34 -0400 Subject: [PATCH 034/126] clean up wasm32 test --- .github/workflows/ci.yml | 9 ++------- Cargo.toml | 1 + 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0064369c9..ebe0f600f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,13 +71,8 @@ jobs: - uses: actions/cache@v3 with: path: target - key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - - run: | - # Hack: wasm support currently relies on not having tokio with features like socket enabled. With resolver 1 - # dev dependencies can add unwanted dependencies to the build, so we'll hackily disable them for this check. - - sed -i 's/\[dev-dependencies]/[ignore-dependencies]/g' ./tokio-postgres/Cargo.toml - cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js + key: check-wasm32-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js test: name: test diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..16e3739dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "postgres-types", "tokio-postgres", ] +resolver = "2" [profile.release] debug = 2 From 258fe68f193b7951e20f244ecbbf664d7629f0eb Mon Sep 17 00:00:00 2001 From: Vinicius Hirschle Date: Sat, 29 Apr 2023 21:52:01 -0300 Subject: [PATCH 035/126] feat(derive): add `#[postgres(allow_mismatch)]` --- .../compile-fail/invalid-allow-mismatch.rs | 31 ++++++++ .../invalid-allow-mismatch.stderr | 43 +++++++++++ postgres-derive-test/src/enums.rs | 72 ++++++++++++++++++- postgres-derive/src/accepts.rs | 42 ++++++----- postgres-derive/src/fromsql.rs | 22 +++++- postgres-derive/src/overrides.rs | 22 +++++- postgres-derive/src/tosql.rs | 22 +++++- postgres-types/src/lib.rs | 23 +++++- 8 files changed, 250 insertions(+), 27 deletions(-) create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:16:1 + | +16 | / #[postgres(allow_mismatch)] +17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32); + | |_______________________________________________^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:20:1 + | +20 | / #[postgres(allow_mismatch)] +21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32); + | |_________________________________________________^ + +error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)] + --> src/compile-fail/invalid-allow-mismatch.rs:24:25 + | +24 | #[postgres(transparent, allow_mismatch)] + | ^^^^^^^^^^^^^^ + +error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)] + --> src/compile-fail/invalid-allow-mismatch.rs:28:28 + | +28 | #[postgres(allow_mismatch, transparent)] + | ^^^^^^^^^^^ diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index 36d428437..f3e6c488c 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -1,5 +1,5 @@ use crate::test_type; -use postgres::{Client, NoTls}; +use postgres::{error::DbError, Client, NoTls}; use postgres_types::{FromSql, ToSql, WrongType}; use std::error::Error; @@ -131,3 +131,73 @@ fn missing_variant() { let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); assert!(err.source().unwrap().is::()); } + +#[test] +fn allow_mismatch_enums() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Bar); +} + +#[test] +fn missing_enum_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn + .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_and_renaming() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo", allow_mismatch)] + enum Foo { + #[postgres(name = "bar")] + Bar, + #[postgres(name = "buz")] + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Buz); +} + +#[test] +fn wrong_name_and_allow_mismatch() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs index 63473863a..a68538dcc 100644 --- a/postgres-derive/src/accepts.rs +++ b/postgres-derive/src/accepts.rs @@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream { } } -pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream { +pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream { let num_variants = variants.len(); let variant_names = variants.iter().map(|v| &v.name); - quote! { - if type_.name() != #name { - return false; + if allow_mismatch { + quote! { + type_.name() == #name } + } else { + quote! { + if type_.name() != #name { + return false; + } - match *type_.kind() { - ::postgres_types::Kind::Enum(ref variants) => { - if variants.len() != #num_variants { - return false; - } - - variants.iter().all(|v| { - match &**v { - #( - #variant_names => true, - )* - _ => false, + match *type_.kind() { + ::postgres_types::Kind::Enum(ref variants) => { + if variants.len() != #num_variants { + return false; } - }) + + variants.iter().all(|v| { + match &**v { + #( + #variant_names => true, + )* + _ => false, + } + }) + } + _ => false, } - _ => false, } } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index a9150411a..d3ac47f4f 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -48,6 +48,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )) } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -57,7 +77,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 99faeebb7..d50550bee 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -7,6 +7,7 @@ pub struct Overrides { pub name: Option, pub rename_all: Option, pub transparent: bool, + pub allow_mismatch: bool, } impl Overrides { @@ -15,6 +16,7 @@ impl Overrides { name: None, rename_all: None, transparent: false, + allow_mismatch: false, }; for attr in attrs { @@ -74,11 +76,25 @@ impl Overrides { } } Meta::Path(path) => { - if !path.is_ident("transparent") { + if path.is_ident("transparent") { + if overrides.allow_mismatch { + return Err(Error::new_spanned( + path, + "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]", + )); + } + overrides.transparent = true; + } else if path.is_ident("allow_mismatch") { + if overrides.transparent { + return Err(Error::new_spanned( + path, + "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]", + )); + } + overrides.allow_mismatch = true; + } else { return Err(Error::new_spanned(path, "unknown override")); } - - overrides.transparent = true; } bad => return Err(Error::new_spanned(bad, "unknown attribute")), } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index ec7602312..81d4834bf 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -44,6 +44,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -53,7 +73,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index edd723977..cb82e2f93 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -138,7 +138,6 @@ //! #[derive(Debug, ToSql, FromSql)] //! #[postgres(name = "mood", rename_all = "snake_case")] //! enum Mood { -//! VerySad, // very_sad //! #[postgres(name = "ok")] //! Ok, // ok //! VeryHappy, // very_happy @@ -155,10 +154,28 @@ //! - `"kebab-case"` //! - `"SCREAMING-KEBAB-CASE"` //! - `"Train-Case"` - +//! +//! ## Allowing Enum Mismatches +//! +//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum +//! variants between the Rust and Postgres types. +//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! #[postgres(allow_mismatch)] +//! enum Mood { +//! Happy, +//! Meh, +//! } +//! ``` #![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] - use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; use std::any::type_name; From b09e9cc6426728a9df665992a6a1e8cb2c4afbec Mon Sep 17 00:00:00 2001 From: Andrew Baxter Date: Thu, 20 Jul 2023 22:54:19 +0900 Subject: [PATCH 036/126] Add to_sql for bytes Cow as well --- postgres-types/src/lib.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index edd723977..34c8cc0b8 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1035,6 +1035,18 @@ impl ToSql for Box<[T]> { to_sql_checked!(); } +impl<'a> ToSql for Cow<'a, [u8]> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&self.as_ref(), ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[u8] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + impl ToSql for Vec { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { <&[u8] as ToSql>::to_sql(&&**self, ty, w) From 34c8dc9d1957f6b663c4236217ec7134ad1d3c5b Mon Sep 17 00:00:00 2001 From: andrew <> Date: Thu, 20 Jul 2023 23:30:27 +0900 Subject: [PATCH 037/126] Fixes --- postgres-types/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 34c8cc0b8..1f56c468f 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1037,7 +1037,7 @@ impl ToSql for Box<[T]> { impl<'a> ToSql for Cow<'a, [u8]> { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - <&str as ToSql>::to_sql(&self.as_ref(), ty, w) + <&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w) } fn accepts(ty: &Type) -> bool { From f7a264473d8ba78a280f1fe173ecb9f3662be7f3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 22 Jul 2023 20:40:47 -0400 Subject: [PATCH 038/126] align hostaddr tls behavior with documentation --- tokio-postgres/src/cancel_query.rs | 14 +++++--------- tokio-postgres/src/cancel_query_raw.rs | 2 +- tokio-postgres/src/cancel_token.rs | 2 +- tokio-postgres/src/client.rs | 1 + tokio-postgres/src/config.rs | 6 +++--- tokio-postgres/src/connect.rs | 25 ++++++++++++++----------- tokio-postgres/src/connect_raw.rs | 2 +- tokio-postgres/src/connect_tls.rs | 9 +++++++-- 8 files changed, 33 insertions(+), 28 deletions(-) diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index d869b5824..8e35a4224 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::{Host, SslMode}; +use crate::config::SslMode; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -24,14 +24,10 @@ where } }; - let hostname = match &config.host { - Host::Tcp(host) => &**host, - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", - }; - let tls = tls - .make_tls_connect(hostname) + let tls = config + .hostname + .map(|s| tls.make_tls_connect(&s)) + .transpose() .map_err(|e| Error::tls(e.into()))?; let socket = connect_socket::connect_socket( diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index c89dc581f..cae887183 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -8,7 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; pub async fn cancel_query_raw( stream: S, mode: SslMode, - tls: T, + tls: Option, process_id: i32, secret_key: i32, ) -> Result<(), Error> diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index d048a3c82..9671de726 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -54,7 +54,7 @@ impl CancelToken { cancel_query_raw::cancel_query_raw( stream, self.ssl_mode, - tls, + Some(tls), self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 8b7df4e87..ac486813e 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -154,6 +154,7 @@ impl InnerClient { #[derive(Clone)] pub(crate) struct SocketConfig { pub host: Host, + pub hostname: Option, pub port: u16, pub connect_timeout: Option, pub tcp_user_timeout: Option, diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b18e3b8af..c88c5ff35 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -97,9 +97,9 @@ pub enum Host { /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. +/// or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// with time constraints. However, a host name is required for TLS certificate verification. /// Specifically: /// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. /// The connection attempt will fail if the authentication method requires a host name; @@ -645,7 +645,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, tls, self).await + connect_raw(stream, Some(tls), self).await } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 32a0a76b9..abb1a0118 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -52,16 +52,17 @@ where .unwrap_or(5432); // The value of host is used as the hostname for TLS validation, - // if it's not present, use the value of hostaddr. let hostname = match host { - Some(Host::Tcp(host)) => host.clone(), + Some(Host::Tcp(host)) => Some(host.clone()), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] - Some(Host::Unix(_)) => "".to_string(), - None => hostaddr.map_or("".to_string(), |ipaddr| ipaddr.to_string()), + Some(Host::Unix(_)) => None, + None => None, }; - let tls = tls - .make_tls_connect(&hostname) + let tls = hostname + .as_ref() + .map(|s| tls.make_tls_connect(s)) + .transpose() .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, @@ -78,7 +79,7 @@ where } }; - match connect_once(&addr, port, tls, config).await { + match connect_once(addr, hostname, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -88,16 +89,17 @@ where } async fn connect_once( - host: &Host, + host: Host, + hostname: Option, port: u16, - tls: T, + tls: Option, config: &Config, ) -> Result<(Client, Connection), Error> where T: TlsConnect, { let socket = connect_socket( - host, + &host, port, config.connect_timeout, config.tcp_user_timeout, @@ -151,7 +153,8 @@ where } client.set_socket_config(SocketConfig { - host: host.clone(), + host, + hostname, port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index d97636221..2db6a66b9 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -80,7 +80,7 @@ where pub async fn connect_raw( stream: S, - tls: T, + tls: Option, config: &Config, ) -> Result<(Client, Connection), Error> where diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 5ef21ac5c..d75dcde90 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -10,7 +10,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls( mut stream: S, mode: SslMode, - tls: T, + tls: Option, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -18,7 +18,11 @@ where { match mode { SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), - SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { + SslMode::Prefer + if tls + .as_ref() + .map_or(false, |tls| !tls.can_connect(ForcePrivateApi)) => + { return Ok(MaybeTlsStream::Raw(stream)) } SslMode::Prefer | SslMode::Require => {} @@ -40,6 +44,7 @@ where } let stream = tls + .ok_or_else(|| Error::tls("no hostname provided for TLS handshake".into()))? .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; From b57574598ec0985d9b471144fe038886b6d8b92a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 22 Jul 2023 21:09:08 -0400 Subject: [PATCH 039/126] fix test --- tokio-postgres/src/cancel_query.rs | 10 +++++----- tokio-postgres/src/cancel_query_raw.rs | 5 +++-- tokio-postgres/src/cancel_token.rs | 3 ++- tokio-postgres/src/config.rs | 2 +- tokio-postgres/src/connect.rs | 11 +++++------ tokio-postgres/src/connect_raw.rs | 5 +++-- tokio-postgres/src/connect_tls.rs | 14 +++++++------- 7 files changed, 26 insertions(+), 24 deletions(-) diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 8e35a4224..4a7766d60 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -24,11 +24,10 @@ where } }; - let tls = config - .hostname - .map(|s| tls.make_tls_connect(&s)) - .transpose() + let tls = tls + .make_tls_connect(config.hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; + let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( &config.host, @@ -39,5 +38,6 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) + .await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index cae887183..41aafe7d9 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -8,7 +8,8 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; pub async fn cancel_query_raw( stream: S, mode: SslMode, - tls: Option, + tls: T, + has_hostname: bool, process_id: i32, secret_key: i32, ) -> Result<(), Error> @@ -16,7 +17,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index 9671de726..c925ce0ca 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -54,7 +54,8 @@ impl CancelToken { cancel_query_raw::cancel_query_raw( stream, self.ssl_mode, - Some(tls), + tls, + true, self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index c88c5ff35..a7fa19312 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -645,7 +645,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, Some(tls), self).await + connect_raw(stream, tls, true, self).await } } diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index abb1a0118..441ad1238 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -59,10 +59,8 @@ where Some(Host::Unix(_)) => None, None => None, }; - let tls = hostname - .as_ref() - .map(|s| tls.make_tls_connect(s)) - .transpose() + let tls = tls + .make_tls_connect(hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, @@ -92,7 +90,7 @@ async fn connect_once( host: Host, hostname: Option, port: u16, - tls: Option, + tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where @@ -110,7 +108,8 @@ where }, ) .await?; - let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + let has_hostname = hostname.is_some(); + let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 2db6a66b9..254ca9f0c 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -80,14 +80,15 @@ where pub async fn connect_raw( stream: S, - tls: Option, + tls: T, + has_hostname: bool, config: &Config, ) -> Result<(Client, Connection), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls).await?; + let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index d75dcde90..2b1229125 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -10,7 +10,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls( mut stream: S, mode: SslMode, - tls: Option, + tls: T, + has_hostname: bool, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -18,11 +19,7 @@ where { match mode { SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), - SslMode::Prefer - if tls - .as_ref() - .map_or(false, |tls| !tls.can_connect(ForcePrivateApi)) => - { + SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { return Ok(MaybeTlsStream::Raw(stream)) } SslMode::Prefer | SslMode::Require => {} @@ -43,8 +40,11 @@ where } } + if !has_hostname { + return Err(Error::tls("no hostname provided for TLS handshake".into())); + } + let stream = tls - .ok_or_else(|| Error::tls("no hostname provided for TLS handshake".into()))? .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; From 3346858dd26b20d63eaae8f3db86773b6896b4c3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:52:56 -0400 Subject: [PATCH 040/126] Implement load balancing --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/cancel_query.rs | 2 +- tokio-postgres/src/client.rs | 14 ++++- tokio-postgres/src/config.rs | 43 +++++++++++++ tokio-postgres/src/connect.rs | 93 +++++++++++++++++++++------- tokio-postgres/src/connect_socket.rs | 65 +++++++------------ 6 files changed, 149 insertions(+), 69 deletions(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 12d8a66fd..12c4bd689 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -58,6 +58,7 @@ postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +rand = "0.8.5" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 4a7766d60..078d4b8b6 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -30,7 +30,7 @@ where let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( - &config.host, + &config.addr, config.port, config.connect_timeout, config.tcp_user_timeout, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ac486813e..2185d2146 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,4 @@ use crate::codec::{BackendMessages, FrontendMessage}; -#[cfg(feature = "runtime")] -use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; @@ -27,6 +25,8 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +use std::net::IpAddr; +use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -153,7 +153,7 @@ impl InnerClient { #[cfg(feature = "runtime")] #[derive(Clone)] pub(crate) struct SocketConfig { - pub host: Host, + pub addr: Addr, pub hostname: Option, pub port: u16, pub connect_timeout: Option, @@ -161,6 +161,14 @@ pub(crate) struct SocketConfig { pub keepalive: Option, } +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) enum Addr { + Tcp(IpAddr), + #[cfg(unix)] + Unix(PathBuf), +} + /// An asynchronous PostgreSQL client. /// /// The client is one half of what is returned when a connection is established. Users interact with the database diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a7fa19312..87d77d35a 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -60,6 +60,16 @@ pub enum ChannelBinding { Require, } +/// Load balancing configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum LoadBalanceHosts { + /// Make connection attempts to hosts in the order provided. + Disable, + /// Make connection attempts to hosts in a random order. + Random, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -129,6 +139,12 @@ pub enum Host { /// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel /// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. /// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -190,6 +206,7 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) load_balance_hosts: LoadBalanceHosts, } impl Default for Config { @@ -222,6 +239,7 @@ impl Config { }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + load_balance_hosts: LoadBalanceHosts::Disable, } } @@ -489,6 +507,19 @@ impl Config { self.channel_binding } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.load_balance_hosts = load_balance_hosts; + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.load_balance_hosts + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -612,6 +643,18 @@ impl Config { }; self.channel_binding(channel_binding); } + "load_balance_hosts" => { + let load_balance_hosts = match value { + "disable" => LoadBalanceHosts::Disable, + "random" => LoadBalanceHosts::Random, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "load_balance_hosts", + )))) + } + }; + self.load_balance_hosts(load_balance_hosts); + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 441ad1238..ca57b9cdd 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1,12 +1,14 @@ -use crate::client::SocketConfig; -use crate::config::{Host, TargetSessionAttrs}; +use crate::client::{Addr, SocketConfig}; +use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; -use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::tls::MakeTlsConnect; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use rand::seq::SliceRandom; use std::task::Poll; use std::{cmp, io}; +use tokio::net; pub async fn connect( mut tls: T, @@ -40,8 +42,13 @@ where return Err(Error::config("invalid number of ports".into())); } + let mut indices = (0..num_hosts).collect::>(); + if config.load_balance_hosts == LoadBalanceHosts::Random { + indices.shuffle(&mut rand::thread_rng()); + } + let mut error = None; - for i in 0..num_hosts { + for i in indices { let host = config.host.get(i); let hostaddr = config.hostaddr.get(i); let port = config @@ -59,25 +66,15 @@ where Some(Host::Unix(_)) => None, None => None, }; - let tls = tls - .make_tls_connect(hostname.as_deref().unwrap_or("")) - .map_err(|e| Error::tls(e.into()))?; // Try to use the value of hostaddr to establish the TCP connection, // fallback to host if hostaddr is not present. let addr = match hostaddr { Some(ipaddr) => Host::Tcp(ipaddr.to_string()), - None => { - if let Some(host) = host { - host.clone() - } else { - // This is unreachable. - return Err(Error::config("both host and hostaddr are empty".into())); - } - } + None => host.cloned().unwrap(), }; - match connect_once(addr, hostname, port, tls, config).await { + match connect_host(addr, hostname, port, &mut tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -86,18 +83,66 @@ where Err(error.unwrap()) } -async fn connect_once( +async fn connect_host( host: Host, hostname: Option, port: u16, - tls: T, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + match host { + Host::Tcp(host) => { + let mut addrs = net::lookup_host((&*host, port)) + .await + .map_err(Error::connect)? + .collect::>(); + + if config.load_balance_hosts == LoadBalanceHosts::Random { + addrs.shuffle(&mut rand::thread_rng()); + } + + let mut last_err = None; + for addr in addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + #[cfg(unix)] + Host::Unix(path) => { + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + } + } +} + +async fn connect_once( + addr: Addr, + hostname: Option<&str>, + port: u16, + tls: &mut T, config: &Config, ) -> Result<(Client, Connection), Error> where - T: TlsConnect, + T: MakeTlsConnect, { let socket = connect_socket( - &host, + &addr, port, config.connect_timeout, config.tcp_user_timeout, @@ -108,6 +153,10 @@ where }, ) .await?; + + let tls = tls + .make_tls_connect(hostname.unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; let has_hostname = hostname.is_some(); let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; @@ -152,8 +201,8 @@ where } client.set_socket_config(SocketConfig { - host, - hostname, + addr, + hostname: hostname.map(|s| s.to_string()), port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 1204ca1ff..082cad5dc 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,17 +1,17 @@ -use crate::config::Host; +use crate::client::Addr; use crate::keepalive::KeepaliveConfig; use crate::{Error, Socket}; use socket2::{SockRef, TcpKeepalive}; use std::future::Future; use std::io; use std::time::Duration; +use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; -use tokio::net::{self, TcpStream}; use tokio::time; pub(crate) async fn connect_socket( - host: &Host, + addr: &Addr, port: u16, connect_timeout: Option, #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< @@ -19,53 +19,32 @@ pub(crate) async fn connect_socket( >, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { - match host { - Host::Tcp(host) => { - let addrs = net::lookup_host((&**host, port)) - .await - .map_err(Error::connect)?; + match addr { + Addr::Tcp(ip) => { + let stream = + connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; - let mut last_err = None; + stream.set_nodelay(true).map_err(Error::connect)?; - for addr in addrs { - let stream = - match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { - Ok(stream) => stream, - Err(e) => { - last_err = Some(e); - continue; - } - }; - - stream.set_nodelay(true).map_err(Error::connect)?; - - let sock_ref = SockRef::from(&stream); - #[cfg(target_os = "linux")] - { - sock_ref - .set_tcp_user_timeout(tcp_user_timeout) - .map_err(Error::connect)?; - } - - if let Some(keepalive_config) = keepalive_config { - sock_ref - .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) - .map_err(Error::connect)?; - } + let sock_ref = SockRef::from(&stream); + #[cfg(target_os = "linux")] + { + sock_ref + .set_tcp_user_timeout(tcp_user_timeout) + .map_err(Error::connect)?; + } - return Ok(Socket::new_tcp(stream)); + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; } - Err(last_err.unwrap_or_else(|| { - Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })) + return Ok(Socket::new_tcp(stream)); } #[cfg(unix)] - Host::Unix(path) => { - let path = path.join(format!(".s.PGSQL.{}", port)); + Addr::Unix(dir) => { + let path = dir.join(format!(".s.PGSQL.{}", port)); let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; Ok(Socket::new_unix(socket)) } From babc8562276cb51288671530045faa094ee7f35d Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:55:27 -0400 Subject: [PATCH 041/126] clippy --- tokio-postgres/src/connect_socket.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 082cad5dc..f27131178 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -40,7 +40,7 @@ pub(crate) async fn connect_socket( .map_err(Error::connect)?; } - return Ok(Socket::new_tcp(stream)); + Ok(Socket::new_tcp(stream)) } #[cfg(unix)] Addr::Unix(dir) => { From 84aed6312fb01ffa7664290b86af5e442ed8f6e9 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Jul 2023 09:56:32 -0400 Subject: [PATCH 042/126] fix wasm build --- tokio-postgres/src/client.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2185d2146..427a05049 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -25,7 +25,9 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +#[cfg(feature = "runtime")] use std::net::IpAddr; +#[cfg(feature = "runtime")] use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; From 98814b86bbe1c0daac2f29ffd55c675199b1877a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 16:22:18 +0300 Subject: [PATCH 043/126] Set user to executing processes' user by default. This mimics the behaviour of libpq and some other libraries (see #1024). This commit uses the `whoami` crate, and thus goes as far as defaulting the user to the executing process' user name on all operating systems. --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/config.rs | 21 +++++++++++---------- tokio-postgres/src/connect_raw.rs | 9 ++------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 12c4bd689..29cf26829 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,6 +59,7 @@ postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" +whoami = "1.4.1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 87d77d35a..a94667dc9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -93,7 +93,7 @@ pub enum Host { /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) user: Option, + user: String, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -219,7 +219,7 @@ impl Config { /// Creates a new configuration. pub fn new() -> Config { Config { - user: None, + user: whoami::username(), password: None, dbname: None, options: None, @@ -245,16 +245,17 @@ impl Config { /// Sets the user to authenticate with. /// - /// Required. + /// If the user is not set, then this defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); + self.user = user.to_string(); self } - /// Gets the user to authenticate with, if one has been configured with - /// the `user` method. - pub fn get_user(&self) -> Option<&str> { - self.user.as_deref() + /// Gets the user to authenticate with. + /// If no user has been configured with the [`user`](Config::user) method, + /// then this defaults to the user executing this process. + pub fn get_user(&self) -> &str { + &self.user } /// Sets the password to authenticate with. @@ -1124,7 +1125,7 @@ mod tests { fn test_simple_parsing() { let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; let config = s.parse::().unwrap(); - assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!("pass_user", config.get_user()); assert_eq!(Some("postgres"), config.get_dbname()); assert_eq!( [ diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 254ca9f0c..bb511c47e 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -113,9 +113,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } + params.push(("user", config.get_user())); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -158,10 +156,7 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = config - .user - .as_ref() - .ok_or_else(|| Error::config("user missing".into()))?; + let user = config.get_user(); let pass = config .password .as_ref() From 4c4059a63d273b94badf1c90998ffaa7c67091c0 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 18:48:57 +0300 Subject: [PATCH 044/126] Propagate changes from `tokio-postgres` to `postgres`. --- postgres/src/config.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 2a8e63862..0e1fbde62 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -29,7 +29,7 @@ use tokio_postgres::{Error, Socket}; /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -143,15 +143,16 @@ impl Config { /// Sets the user to authenticate with. /// - /// Required. + /// If the user is not set, then this defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.config.user(user); self } - /// Gets the user to authenticate with, if one has been configured with - /// the `user` method. - pub fn get_user(&self) -> Option<&str> { + /// Gets the user to authenticate with. + /// If no user has been configured with the [`user`](Config::user) method, + /// then this defaults to the user executing this process. + pub fn get_user(&self) -> &str { self.config.get_user() } From 7a5b19a7861d784a0a743f89447d4c732ac44b90 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Aug 2023 19:09:00 +0300 Subject: [PATCH 045/126] Update Rust version in CI to 1.67.0. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ebe0f600f..9a669a40f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.65.0 + version: 1.67.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From a4543783707cc2fdbba3db4bfe1fc6168582de7e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 19:53:26 -0400 Subject: [PATCH 046/126] Restore back compat --- postgres/src/config.rs | 7 +++++-- tokio-postgres/src/config.rs | 15 +++++++++------ tokio-postgres/src/connect_raw.rs | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 0e1fbde62..1839c9cb3 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -150,9 +150,12 @@ impl Config { } /// Gets the user to authenticate with. + /// /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. - pub fn get_user(&self) -> &str { + /// then this defaults to the user executing this process. It always + /// returns `Some`. + // FIXME remove option + pub fn get_user(&self) -> Option<&str> { self.config.get_user() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a94667dc9..0da5fc689 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - user: String, + pub(crate) user: String, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -245,17 +245,20 @@ impl Config { /// Sets the user to authenticate with. /// - /// If the user is not set, then this defaults to the user executing this process. + /// Defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.user = user.to_string(); self } /// Gets the user to authenticate with. + /// /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. - pub fn get_user(&self) -> &str { - &self.user + /// then this defaults to the user executing this process. It always + /// returns `Some`. + // FIXME remove option + pub fn get_user(&self) -> Option<&str> { + Some(&self.user) } /// Sets the password to authenticate with. @@ -1125,7 +1128,7 @@ mod tests { fn test_simple_parsing() { let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; let config = s.parse::().unwrap(); - assert_eq!("pass_user", config.get_user()); + assert_eq!(Some("pass_user"), config.get_user()); assert_eq!(Some("postgres"), config.get_dbname()); assert_eq!( [ diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index bb511c47e..11cc48ef8 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -113,7 +113,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - params.push(("user", config.get_user())); + params.push(("user", &config.user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -156,7 +156,7 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = config.get_user(); + let user = &config.user; let pass = config .password .as_ref() From 496f46c8f5e8e76e0b148c7ef57dbccc11778597 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:04:18 -0400 Subject: [PATCH 047/126] Release postgres-protocol v0.6.6 --- postgres-protocol/CHANGELOG.md | 6 ++++++ postgres-protocol/Cargo.toml | 2 +- postgres-protocol/src/lib.rs | 1 - postgres-types/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md index 034fd637c..1c371675c 100644 --- a/postgres-protocol/CHANGELOG.md +++ b/postgres-protocol/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.6.6 -2023-08-19 + +### Added + +* Added the `js` feature for WASM support. + ## v0.6.5 - 2023-03-27 ### Added diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index ad609f6fa..b44994811 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.5" +version = "0.6.6" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 8b6ff508d..83d9bf55c 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -9,7 +9,6 @@ //! //! This library assumes that the `client_encoding` backend parameter has been //! set to `UTF8`. It will most likely not behave properly if that is not the case. -#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")] #![warn(missing_docs, rust_2018_idioms, clippy::all)] use byteorder::{BigEndian, ByteOrder}; diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 35cdd6e7b..686d0036d 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -30,7 +30,7 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 29cf26829..f9f49da3e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -54,7 +54,7 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } From 43e15690f492f3ae8088677fd8d5df18f73b3e85 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:11:35 -0400 Subject: [PATCH 048/126] Release postgres-derive v0.4.5 --- postgres-derive/CHANGELOG.md | 7 +++++++ postgres-derive/Cargo.toml | 2 +- postgres-types/Cargo.toml | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/postgres-derive/CHANGELOG.md b/postgres-derive/CHANGELOG.md index 22714acc2..b0075fa8e 100644 --- a/postgres-derive/CHANGELOG.md +++ b/postgres-derive/CHANGELOG.md @@ -1,5 +1,12 @@ # Change Log +## v0.4.5 - 2023-08-19 + +### Added + +* Added a `rename_all` option for enum and struct derives. +* Added an `allow_mismatch` option to disable strict enum variant checks against the Postgres type. + ## v0.4.4 - 2023-03-27 ### Changed diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 78bec3d41..51ebb5663 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-derive" -version = "0.4.4" +version = "0.4.5" authors = ["Steven Fackler "] license = "MIT/Apache-2.0" edition = "2018" diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 686d0036d..15de00702 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -31,7 +31,7 @@ with-time-0_3 = ["time-03"] bytes = "1.0" fallible-iterator = "0.2" postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } -postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" } +postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true } From 6f7ab44d5bc8548a4e7fb69d46d3b85a14101144 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:14:01 -0400 Subject: [PATCH 049/126] Release postgres-types v0.2.6 --- postgres-types/CHANGELOG.md | 15 +++++++++++++-- postgres-types/Cargo.toml | 2 +- postgres-types/src/lib.rs | 1 - 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md index 0f42f3495..72a1cbb6a 100644 --- a/postgres-types/CHANGELOG.md +++ b/postgres-types/CHANGELOG.md @@ -1,14 +1,25 @@ # Change Log +## v0.2.6 - 2023-08-19 + +### Fixed + +* Fixed serialization to `OIDVECTOR` and `INT2VECTOR`. + +### Added + +* Removed the `'static` requirement for the `impl BorrowToSql for Box`. +* Added a `ToSql` implementation for `Cow<[u8]>`. + ## v0.2.5 - 2023-03-27 -## Added +### Added * Added support for multi-range types. ## v0.2.4 - 2022-08-20 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `Box<[T]>`. * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 15de00702..193d159a1 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.5" +version = "0.2.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index d27adfe0e..52b5c773a 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -174,7 +174,6 @@ //! Meh, //! } //! ``` -#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; From 3d0a593ea610fb51b25a34087131470c94e3fe58 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:20:13 -0400 Subject: [PATCH 050/126] Release tokio-postgres v0.7.9 --- tokio-postgres/CHANGELOG.md | 13 +++++++++++++ tokio-postgres/Cargo.toml | 4 ++-- tokio-postgres/src/lib.rs | 1 - 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 3345a1d43..41a1a65d1 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,18 @@ # Change Log +## v0.7.9 + +## Fixed + +* Fixed builds on OpenBSD. + +## Added + +* Added the `js` feature for WASM support. +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + ## v0.7.8 ## Added diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index f9f49da3e..3b33cc8f6 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.8" +version = "0.7.9" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -55,7 +55,7 @@ percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } -postgres-types = { version = "0.2.4", path = "../postgres-types" } +postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 2bb410187..ff8e93ddc 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -116,7 +116,6 @@ //! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | //! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | -#![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] pub use crate::cancel_token::CancelToken; From e08a38f9f6f06a67d699209d54097fa8a567a578 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:33:21 -0400 Subject: [PATCH 051/126] sync postgres config up with tokio-postgres --- postgres/src/config.rs | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 1839c9cb3..0f936fdc4 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -13,7 +13,9 @@ use std::sync::Arc; use std::time::Duration; use tokio::runtime; #[doc(inline)] -pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs}; +pub use tokio_postgres::config::{ + ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs, +}; use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{Error, Socket}; @@ -43,9 +45,9 @@ use tokio_postgres::{Error, Socket}; /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. +/// or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// with time constraints. However, a host name is required for TLS certificate verification. /// Specifically: /// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. /// The connection attempt will fail if the authentication method requires a host name; @@ -72,6 +74,15 @@ use tokio_postgres::{Error, Socket}; /// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that /// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server /// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -80,7 +91,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// host=/var/run/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' /// ``` /// /// ```not_rust @@ -94,7 +105,7 @@ use tokio_postgres::{Error, Socket}; /// # Url /// /// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, -/// and the format accept query parameters for all of the key-value pairs described in the section above. Multiple +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple /// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, /// as the path component of the URL specifies the database name. /// @@ -105,7 +116,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql://user:password@%2Fvar%2Frun%2Fpostgresql/mydb?connect_timeout=10 +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 /// ``` /// /// ```not_rust @@ -113,7 +124,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql:///mydb?user=user&host=/var/run/postgresql +/// postgresql:///mydb?user=user&host=/var/lib/postgresql /// ``` #[derive(Clone)] pub struct Config { @@ -396,6 +407,19 @@ impl Config { self.config.get_channel_binding() } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.config.load_balance_hosts(load_balance_hosts); + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.config.get_load_balance_hosts() + } + /// Sets the notice callback. /// /// This callback will be invoked with the contents of every From f45527fe5f4f566328973097511a33d771d3f300 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:34:02 -0400 Subject: [PATCH 052/126] remove bogus docs --- postgres/src/config.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 0f936fdc4..f83244b2e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -1,6 +1,4 @@ //! Connection configuration. -//! -//! Requires the `runtime` Cargo feature (enabled by default). use crate::connection::Connection; use crate::Client; From 75cc986d8c40024eca45139edc6c366231d147ea Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 19 Aug 2023 20:37:16 -0400 Subject: [PATCH 053/126] Release postgres v0.19.6 --- postgres/CHANGELOG.md | 14 +++++++++++--- postgres/Cargo.toml | 8 +++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index b8263a04a..fe9e8dbe8 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,20 +1,28 @@ # Change Log +## v0.19.6 - 2023-08-19 + +### Added + +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + ## v0.19.5 - 2023-03-27 -## Added +### Added * Added `keepalives_interval` and `keepalives_retries` config options. * Added the `tcp_user_timeout` config option. * Added `RowIter::rows_affected`. -## Changed +### Changed * Passing an incorrect number of parameters to a query method now returns an error instead of panicking. ## v0.19.4 - 2022-08-21 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. * Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 044bb91e1..ff626f86c 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.5" +version = "0.19.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -39,11 +39,9 @@ with-time-0_3 = ["tokio-postgres/with-time-0_3"] bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } -tokio-postgres = { version = "0.7.8", path = "../tokio-postgres" } - -tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" +tokio-postgres = { version = "0.7.9", path = "../tokio-postgres" } +tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] criterion = "0.5" -tokio = { version = "1.0", features = ["rt-multi-thread"] } From cb609be758f3fb5af537f04b584a2ee0cebd5e79 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:31:22 -0400 Subject: [PATCH 054/126] Defer username default --- postgres/src/config.rs | 8 ++------ tokio-postgres/src/config.rs | 16 ++++++---------- tokio-postgres/src/connect_raw.rs | 21 +++++++++++++++------ 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index f83244b2e..a32ddc78e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -158,12 +158,8 @@ impl Config { self } - /// Gets the user to authenticate with. - /// - /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. It always - /// returns `Some`. - // FIXME remove option + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. pub fn get_user(&self) -> Option<&str> { self.config.get_user() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0da5fc689..b178eac80 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -190,7 +190,7 @@ pub enum Host { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) user: String, + pub(crate) user: Option, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, @@ -219,7 +219,7 @@ impl Config { /// Creates a new configuration. pub fn new() -> Config { Config { - user: whoami::username(), + user: None, password: None, dbname: None, options: None, @@ -247,18 +247,14 @@ impl Config { /// /// Defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = user.to_string(); + self.user = Some(user.to_string()); self } - /// Gets the user to authenticate with. - /// - /// If no user has been configured with the [`user`](Config::user) method, - /// then this defaults to the user executing this process. It always - /// returns `Some`. - // FIXME remove option + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. pub fn get_user(&self) -> Option<&str> { - Some(&self.user) + self.user.as_deref() } /// Sets the password to authenticate with. diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 11cc48ef8..f19bb50c4 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -96,8 +96,10 @@ where delayed: VecDeque::new(), }; - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; + let user = config.user.clone().unwrap_or_else(whoami::username); + + startup(&mut stream, config, &user).await?; + authenticate(&mut stream, config, &user).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); @@ -107,13 +109,17 @@ where Ok((client, connection)) } -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn startup( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - params.push(("user", &config.user)); + params.push(("user", user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -133,7 +139,11 @@ where .map_err(Error::io) } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn authenticate( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, @@ -156,7 +166,6 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = &config.user; let pass = config .password .as_ref() From b411e5c3cb71d43fc9249b5d3ca38a7213470069 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:35:48 -0400 Subject: [PATCH 055/126] clippy --- postgres-protocol/src/types/test.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 6f1851fc2..3e33b08f0 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -174,7 +174,7 @@ fn ltree_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -182,7 +182,7 @@ fn ltree_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } #[test] @@ -202,7 +202,7 @@ fn lquery_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) + assert!(lquery_from_sql(query.as_slice()).is_ok()) } #[test] @@ -210,7 +210,7 @@ fn lquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) + assert!(lquery_from_sql(query.as_slice()).is_err()) } #[test] @@ -230,7 +230,7 @@ fn ltxtquery_str() { let mut query = vec![1u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } From 016e9a3b8557c267f650090e1501d5efd00de908 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:40:01 -0400 Subject: [PATCH 056/126] avoid a silly clone --- tokio-postgres/src/connect_raw.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index f19bb50c4..19be9eb01 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -13,6 +13,7 @@ use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ScramSha256; use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol::message::frontend; +use std::borrow::Cow; use std::collections::{HashMap, VecDeque}; use std::io; use std::pin::Pin; @@ -96,7 +97,10 @@ where delayed: VecDeque::new(), }; - let user = config.user.clone().unwrap_or_else(whoami::username); + let user = config + .user + .as_deref() + .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed); startup(&mut stream, config, &user).await?; authenticate(&mut stream, config, &user).await?; From 234e20bb000ccf17d08341bd66e48d1105c3960a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:40:40 -0400 Subject: [PATCH 057/126] bump ci version --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9a669a40f..008158fb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.67.0 + version: 1.70.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From c50fcbd9fb6f0df53d2300fb429af1c6c128007f Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:45:34 -0400 Subject: [PATCH 058/126] Release tokio-postgres v0.7.10 --- tokio-postgres/CHANGELOG.md | 6 ++++++ tokio-postgres/Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 41a1a65d1..2bee9a1c4 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.7.10 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + ## v0.7.9 ## Fixed diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 3b33cc8f6..ec5e3cbec 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.9" +version = "0.7.10" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" From c5ff8cfd86e897b7c197f52684a37a4f17cecb75 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 25 Aug 2023 13:48:08 -0400 Subject: [PATCH 059/126] Release postgres v0.19.7 --- postgres/CHANGELOG.md | 6 ++++++ postgres/Cargo.toml | 4 ++-- tokio-postgres/CHANGELOG.md | 6 +++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index fe9e8dbe8..7f856b5ac 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.19.7 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + ## v0.19.6 - 2023-08-19 ### Added diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index ff626f86c..18406da9f 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.6" +version = "0.19.7" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -40,7 +40,7 @@ bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } log = "0.4" -tokio-postgres = { version = "0.7.9", path = "../tokio-postgres" } +tokio-postgres = { version = "0.7.10", path = "../tokio-postgres" } tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 2bee9a1c4..75448d130 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,12 +1,12 @@ # Change Log -## v0.7.10 +## v0.7.10 - 2023-08-25 ## Fixed * Defered default username lookup to avoid regressing `Config` behavior. -## v0.7.9 +## v0.7.9 - 2023-08-19 ## Fixed @@ -19,7 +19,7 @@ * Added support for the `load_balance_hosts` config option to randomize connection ordering. * The `user` config option now defaults to the executing process's user. -## v0.7.8 +## v0.7.8 - 2023-05-27 ## Added From b1306a4a74317ac142ae9b93445360e9597380ec Mon Sep 17 00:00:00 2001 From: ds-cbo <82801887+ds-cbo@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:31:41 +0200 Subject: [PATCH 060/126] remove rustc-serialize dependency --- postgres-types/Cargo.toml | 4 +++- postgres/src/lib.rs | 2 +- tokio-postgres/CHANGELOG.md | 6 ++++++ tokio-postgres/Cargo.toml | 3 +-- tokio-postgres/src/lib.rs | 2 +- tokio-postgres/tests/test/types/eui48_04.rs | 18 ------------------ tokio-postgres/tests/test/types/mod.rs | 2 -- 7 files changed, 12 insertions(+), 25 deletions(-) delete mode 100644 tokio-postgres/tests/test/types/eui48_04.rs diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 193d159a1..cfd083637 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -39,8 +39,10 @@ chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, "clock", ], optional = true } cidr-02 = { version = "0.2", package = "cidr", optional = true } +# eui48-04 will stop compiling and support will be removed +# See https://github.com/sfackler/rust-postgres/issues/1073 eui48-04 = { version = "0.4", package = "eui48", optional = true } -eui48-1 = { version = "1.0", package = "eui48", optional = true } +eui48-1 = { version = "1.0", package = "eui48", optional = true, default-features = false } geo-types-06 = { version = "0.6", package = "geo-types", optional = true } geo-types-0_7 = { version = "0.7", package = "geo-types", optional = true } serde-1 = { version = "1.0", package = "serde", optional = true } diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index fbe85cbde..ddf1609ad 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -55,7 +55,7 @@ //! | ------- | ----------- | ------------------ | ------- | //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | -//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no | +//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 75448d130..bd076eef9 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## Unreleased + +* Disable `rustc-serialize` compatibility of `eui48-1` dependency +* Remove tests for `eui48-04` + + ## v0.7.10 - 2023-08-25 ## Fixed diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index ec5e3cbec..bb58eb2d9 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -78,8 +78,7 @@ tokio = { version = "1.0", features = [ bit-vec-06 = { version = "0.6", package = "bit-vec" } chrono-04 = { version = "0.4", package = "chrono", default-features = false } -eui48-04 = { version = "0.4", package = "eui48" } -eui48-1 = { version = "1.0", package = "eui48" } +eui48-1 = { version = "1.0", package = "eui48", default-features = false } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } serde-1 = { version = "1.0", package = "serde" } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index ff8e93ddc..2973d33b0 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -107,7 +107,7 @@ //! | `array-impls` | Enables `ToSql` and `FromSql` trait impls for arrays | - | no | //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | -//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no | +//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | diff --git a/tokio-postgres/tests/test/types/eui48_04.rs b/tokio-postgres/tests/test/types/eui48_04.rs deleted file mode 100644 index 074faa37e..000000000 --- a/tokio-postgres/tests/test/types/eui48_04.rs +++ /dev/null @@ -1,18 +0,0 @@ -use eui48_04::MacAddress; - -use crate::types::test_type; - -#[tokio::test] -async fn test_eui48_params() { - test_type( - "MACADDR", - &[ - ( - Some(MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()), - "'12-34-56-ab-cd-ef'", - ), - (None, "NULL"), - ], - ) - .await -} diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index f1a44da08..62d54372a 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -17,8 +17,6 @@ use bytes::BytesMut; mod bit_vec_06; #[cfg(feature = "with-chrono-0_4")] mod chrono_04; -#[cfg(feature = "with-eui48-0_4")] -mod eui48_04; #[cfg(feature = "with-eui48-1")] mod eui48_1; #[cfg(feature = "with-geo-types-0_6")] From ea9e0e5cddc2e772179027e635afa11d64feea2b Mon Sep 17 00:00:00 2001 From: ds-cbo <82801887+ds-cbo@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:43:56 +0100 Subject: [PATCH 061/126] replace deprecated chrono::DateTime::from_utc --- postgres-types/src/chrono_04.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index 0ec92437d..6011b549e 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime { impl<'a> FromSql<'a> for DateTime { fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(DateTime::from_utc(naive, Utc)) + Ok(Utc::from_utc_datetime(naive)) } accepts!(TIMESTAMPTZ); From b4ebc4e7ec6ee52930bc22e2ad29b66687852623 Mon Sep 17 00:00:00 2001 From: ds-cbo <82801887+ds-cbo@users.noreply.github.com> Date: Mon, 30 Oct 2023 16:39:50 +0100 Subject: [PATCH 062/126] add missing import --- postgres-types/src/chrono_04.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index 6011b549e..f995d483c 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -1,5 +1,7 @@ use bytes::BytesMut; -use chrono_04::{DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use chrono_04::{ + DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc, +}; use postgres_protocol::types; use std::error::Error; From 19a6ef767bf6b2070ffe9efd43af514b6a31f2d2 Mon Sep 17 00:00:00 2001 From: ds-cbo <82801887+ds-cbo@users.noreply.github.com> Date: Tue, 31 Oct 2023 09:54:13 +0100 Subject: [PATCH 063/126] fix more deprecated chrono functions --- postgres-types/src/chrono_04.rs | 2 +- tokio-postgres/tests/test/types/chrono_04.rs | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index f995d483c..d599bde02 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -42,7 +42,7 @@ impl ToSql for NaiveDateTime { impl<'a> FromSql<'a> for DateTime { fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(Utc::from_utc_datetime(naive)) + Ok(Utc.from_utc_datetime(&naive)) } accepts!(TIMESTAMPTZ); diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index a8e9e5afa..b010055ba 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -53,10 +53,9 @@ async fn test_with_special_naive_date_time_params() { async fn test_date_time_params() { fn make_check(time: &str) -> (Option>, &str) { ( - Some( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), + Some(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), + )), time, ) } @@ -76,10 +75,9 @@ async fn test_date_time_params() { async fn test_with_special_date_time_params() { fn make_check(time: &str) -> (Timestamp>, &str) { ( - Timestamp::Value( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), + Timestamp::Value(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), + )), time, ) } From 863c1d6039e8fe114e48d62c0451d6eb5e4867a2 Mon Sep 17 00:00:00 2001 From: James Guthrie Date: Tue, 7 Nov 2023 22:09:39 +0100 Subject: [PATCH 064/126] fix code block --- postgres-types/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 52b5c773a..aaf145e6b 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -168,6 +168,8 @@ //! 'Happy' //! ); //! ``` +//! +//! ```rust //! #[postgres(allow_mismatch)] //! enum Mood { //! Happy, From 10edbcb46c44933417e8d2e7a1c1d63c4119beb3 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 7 Nov 2023 16:23:06 -0500 Subject: [PATCH 065/126] Update lib.rs --- postgres-types/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index aaf145e6b..2f02f6e5f 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -170,6 +170,11 @@ //! ``` //! //! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] //! #[postgres(allow_mismatch)] //! enum Mood { //! Happy, From 02bab67280f8a850b816754b29eb0364708604ec Mon Sep 17 00:00:00 2001 From: "Michael P. Jung" Date: Tue, 5 Dec 2023 13:54:20 +0100 Subject: [PATCH 066/126] Add table_oid and field_id to columns of prepared statements --- tokio-postgres/CHANGELOG.md | 2 +- tokio-postgres/src/prepare.rs | 8 +++++++- tokio-postgres/src/statement.rs | 23 +++++++++++++++++------ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index bd076eef9..9f5eb9521 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -4,7 +4,7 @@ * Disable `rustc-serialize` compatibility of `eui48-1` dependency * Remove tests for `eui48-04` - +* Add `table_oid` and `field_id` fields to `Columns` struct of prepared statements. ## v0.7.10 - 2023-08-25 diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..1ab34e2df 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -12,6 +12,7 @@ use log::debug; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::future::Future; +use std::num::{NonZeroI16, NonZeroU32}; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -95,7 +96,12 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column { + name: field.name().to_string(), + table_oid: NonZeroU32::new(field.table_oid()), + column_id: NonZeroI16::new(field.column_id()), + type_, + }; columns.push(column); } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..73d56c220 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -5,6 +5,7 @@ use crate::types::Type; use postgres_protocol::message::frontend; use std::{ fmt, + num::{NonZeroI16, NonZeroU32}, sync::{Arc, Weak}, }; @@ -66,20 +67,28 @@ impl Statement { /// Information about a column of a query. pub struct Column { - name: String, - type_: Type, + pub(crate) name: String, + pub(crate) table_oid: Option, + pub(crate) column_id: Option, + pub(crate) type_: Type, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } - } - /// Returns the name of the column. pub fn name(&self) -> &str { &self.name } + /// Returns the OID of the underlying database table. + pub fn table_oid(&self) -> Option { + self.table_oid + } + + /// Return the column ID within the underlying database table. + pub fn column_id(&self) -> Option { + self.column_id + } + /// Returns the type of the column. pub fn type_(&self) -> &Type { &self.type_ @@ -90,6 +99,8 @@ impl fmt::Debug for Column { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Column") .field("name", &self.name) + .field("table_oid", &self.table_oid) + .field("column_id", &self.column_id) .field("type", &self.type_) .finish() } From 87876150d79e637767247176e339bf01a8b32d3b Mon Sep 17 00:00:00 2001 From: "Michael P. Jung" Date: Tue, 5 Dec 2023 14:09:44 +0100 Subject: [PATCH 067/126] Simplify Debug impl of Column --- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/statement.rs | 17 +++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 1ab34e2df..0302cdb4c 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -100,7 +100,7 @@ pub async fn prepare( name: field.name().to_string(), table_oid: NonZeroU32::new(field.table_oid()), column_id: NonZeroI16::new(field.column_id()), - type_, + r#type: type_, }; columns.push(column); } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 73d56c220..fe3b6b7a1 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -4,7 +4,6 @@ use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; use std::{ - fmt, num::{NonZeroI16, NonZeroU32}, sync::{Arc, Weak}, }; @@ -66,11 +65,12 @@ impl Statement { } /// Information about a column of a query. +#[derive(Debug)] pub struct Column { pub(crate) name: String, pub(crate) table_oid: Option, pub(crate) column_id: Option, - pub(crate) type_: Type, + pub(crate) r#type: Type, } impl Column { @@ -91,17 +91,6 @@ impl Column { /// Returns the type of the column. pub fn type_(&self) -> &Type { - &self.type_ - } -} - -impl fmt::Debug for Column { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Column") - .field("name", &self.name) - .field("table_oid", &self.table_oid) - .field("column_id", &self.column_id) - .field("type", &self.type_) - .finish() + &self.r#type } } From bbc04145de7a83dfa66cb3cf4a68878da2c1cc32 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 11 Dec 2023 19:06:22 -0500 Subject: [PATCH 068/126] Update id types --- tokio-postgres/src/prepare.rs | 5 ++--- tokio-postgres/src/statement.rs | 13 +++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 0302cdb4c..07fb45694 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -12,7 +12,6 @@ use log::debug; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::future::Future; -use std::num::{NonZeroI16, NonZeroU32}; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -98,8 +97,8 @@ pub async fn prepare( let type_ = get_type(client, field.type_oid()).await?; let column = Column { name: field.name().to_string(), - table_oid: NonZeroU32::new(field.table_oid()), - column_id: NonZeroI16::new(field.column_id()), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), r#type: type_, }; columns.push(column); diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index fe3b6b7a1..c5d657738 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,10 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; -use std::{ - num::{NonZeroI16, NonZeroU32}, - sync::{Arc, Weak}, -}; +use std::sync::{Arc, Weak}; struct StatementInner { client: Weak, @@ -68,8 +65,8 @@ impl Statement { #[derive(Debug)] pub struct Column { pub(crate) name: String, - pub(crate) table_oid: Option, - pub(crate) column_id: Option, + pub(crate) table_oid: Option, + pub(crate) column_id: Option, pub(crate) r#type: Type, } @@ -80,12 +77,12 @@ impl Column { } /// Returns the OID of the underlying database table. - pub fn table_oid(&self) -> Option { + pub fn table_oid(&self) -> Option { self.table_oid } /// Return the column ID within the underlying database table. - pub fn column_id(&self) -> Option { + pub fn column_id(&self) -> Option { self.column_id } From 90c92c2ae8577a8e771333c701280485c45ad602 Mon Sep 17 00:00:00 2001 From: Troy Benson Date: Mon, 15 Jan 2024 16:33:27 +0000 Subject: [PATCH 069/126] feat(types): add default derive to json wrapper Adds a Default impl for `Json where T: Default` allowing for other structs to use the wrapper and implement Default. --- postgres-types/src/serde_json_1.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-types/src/serde_json_1.rs b/postgres-types/src/serde_json_1.rs index b98d561d1..715c33f98 100644 --- a/postgres-types/src/serde_json_1.rs +++ b/postgres-types/src/serde_json_1.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use std::io::Read; /// A wrapper type to allow arbitrary `Serialize`/`Deserialize` types to convert to Postgres JSON values. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Default, Debug, PartialEq, Eq)] pub struct Json(pub T); impl<'a, T> FromSql<'a> for Json From 2f150a7e50ee03cbccf52792b7e4507dbcef0301 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:49:51 +0000 Subject: [PATCH 070/126] Update env_logger requirement from 0.10 to 0.11 Updates the requirements on [env_logger](https://github.com/rust-cli/env_logger) to permit the latest version. - [Release notes](https://github.com/rust-cli/env_logger/releases) - [Changelog](https://github.com/rust-cli/env_logger/blob/main/CHANGELOG.md) - [Commits](https://github.com/rust-cli/env_logger/compare/v0.10.0...v0.11.0) --- updated-dependencies: - dependency-name: env_logger dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- tokio-postgres/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index bb58eb2d9..237f3d2f1 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -67,7 +67,7 @@ socket2 = { version = "0.5", features = ["all"] } [dev-dependencies] futures-executor = "0.3" criterion = "0.5" -env_logger = "0.10" +env_logger = "0.11" tokio = { version = "1.0", features = [ "macros", "net", From 7bc3deb989b3030681b742801bfeaca7f67e1e1e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 22 Jan 2024 20:49:47 -0500 Subject: [PATCH 071/126] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 008158fb0..0cc823d35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.70.0 + version: 1.71.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From a92c6eb2b65e12d7145a14cec23888d64c4b13e4 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 22 Jan 2024 20:54:11 -0500 Subject: [PATCH 072/126] Update main.rs --- tokio-postgres/tests/test/main.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..737f46631 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -303,6 +303,7 @@ async fn custom_range() { } #[tokio::test] +#[allow(clippy::get_first)] async fn simple_query() { let client = connect("user=postgres").await; From 289cf887600785e723628dcbc1f7a2267cd52917 Mon Sep 17 00:00:00 2001 From: Charles Samuels Date: Fri, 16 Feb 2024 10:55:08 -0800 Subject: [PATCH 073/126] add #[track_caller] to the Row::get() functions This small quality-of-life improvement changes these errors: thread '' panicked at /../.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-postgres-0.7.10/src/row.rs:151:25: error retrieving column 0: error deserializing column 0: a Postgres value was `NULL` to: thread '' panicked at my-program.rs:100:25: error retrieving column 0: error deserializing column 0: a Postgres value was `NULL` --- tokio-postgres/src/row.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..3c79de603 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -141,6 +141,7 @@ impl Row { /// # Panics /// /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] pub fn get<'a, I, T>(&'a self, idx: I) -> T where I: RowIndex + fmt::Display, @@ -239,6 +240,7 @@ impl SimpleQueryRow { /// # Panics /// /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] pub fn get(&self, idx: I) -> Option<&str> where I: RowIndex + fmt::Display, From 25314a91c95dc8f75062e337eb363188c63df5d4 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 17 Feb 2024 09:52:44 -0500 Subject: [PATCH 074/126] Bump CI version --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0cc823d35..641a42722 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.71.0 + version: 1.74.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 From a9ca481c88fb619c6d35f2a6b64253bb46240c5d Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 3 Mar 2024 16:37:30 +0100 Subject: [PATCH 075/126] Added ReadOnly session attr --- tokio-postgres/src/config.rs | 3 +++ tokio-postgres/tests/test/parse.rs | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b178eac80..c78346fff 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -34,6 +34,8 @@ pub enum TargetSessionAttrs { Any, /// The session must allow writes. ReadWrite, + /// The session allow only reads. + ReadOnly, } /// TLS configuration. @@ -622,6 +624,7 @@ impl Config { let target_session_attrs = match value { "any" => TargetSessionAttrs::Any, "read-write" => TargetSessionAttrs::ReadWrite, + "read-only" => TargetSessionAttrs::ReadOnly, _ => { return Err(Error::config_parse(Box::new(InvalidValue( "target_session_attrs", diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index 2c11899ca..04d422e27 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -34,6 +34,14 @@ fn settings() { .keepalives_idle(Duration::from_secs(30)) .target_session_attrs(TargetSessionAttrs::ReadWrite), ); + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-only", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::ReadOnly), + ); } #[test] From 6a01730cbfed5d9c0aa694401704e6fe7ec0c8b5 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sun, 3 Mar 2024 19:17:50 +0100 Subject: [PATCH 076/126] Added ReadOnly session attr --- tokio-postgres/src/connect.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..8189cb91c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -160,7 +160,7 @@ where let has_hostname = hostname.is_some(); let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; - if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { + if config.target_session_attrs != TargetSessionAttrs::Any { let rows = client.simple_query_raw("SHOW transaction_read_only"); pin_mut!(rows); @@ -185,11 +185,21 @@ where match next.await.transpose()? { Some(SimpleQueryMessage::Row(row)) => { - if row.try_get(0)? == Some("on") { + let read_only_result = row.try_get(0)?; + if read_only_result == Some("on") + && config.target_session_attrs == TargetSessionAttrs::ReadWrite + { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, "database does not allow writes", ))); + } else if read_only_result == Some("off") + && config.target_session_attrs == TargetSessionAttrs::ReadOnly + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database is not read only", + ))); } else { break; } From 4217553586c4ce390179a281834b8f2c3197863e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:45:26 +0000 Subject: [PATCH 077/126] Update base64 requirement from 0.21 to 0.22 Updates the requirements on [base64](https://github.com/marshallpierce/rust-base64) to permit the latest version. - [Changelog](https://github.com/marshallpierce/rust-base64/blob/master/RELEASE-NOTES.md) - [Commits](https://github.com/marshallpierce/rust-base64/compare/v0.21.0...v0.22.0) --- updated-dependencies: - dependency-name: base64 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- postgres-protocol/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index b44994811..bc83fc4e6 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -13,7 +13,7 @@ default = [] js = ["getrandom/js"] [dependencies] -base64 = "0.21" +base64 = "0.22" byteorder = "1.0" bytes = "1.0" fallible-iterator = "0.2" From 9d7c43c73955638624e75957a333fac5d9be1c02 Mon Sep 17 00:00:00 2001 From: novacrazy Date: Sun, 11 Feb 2024 03:03:19 -0600 Subject: [PATCH 078/126] Shrink query_opt/query_one codegen size very slightly --- tokio-postgres/src/client.rs | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 427a05049..d48a23a60 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -274,19 +274,9 @@ impl Client { where T: ?Sized + ToStatement, { - let stream = self.query_raw(statement, slice_iter(params)).await?; - pin_mut!(stream); - - let row = match stream.try_next().await? { - Some(row) => row, - None => return Err(Error::row_count()), - }; - - if stream.try_next().await?.is_some() { - return Err(Error::row_count()); - } - - Ok(row) + self.query_opt(statement, params) + .await + .and_then(|res| res.ok_or_else(Error::row_count)) } /// Executes a statements which returns zero or one rows, returning it. @@ -310,16 +300,22 @@ impl Client { let stream = self.query_raw(statement, slice_iter(params)).await?; pin_mut!(stream); - let row = match stream.try_next().await? { - Some(row) => row, - None => return Ok(None), - }; + let mut first = None; + + // Originally this was two calls to `try_next().await?`, + // once for the first element, and second to error if more than one. + // + // However, this new form with only one .await in a loop generates + // slightly smaller codegen/stack usage for the resulting future. + while let Some(row) = stream.try_next().await? { + if first.is_some() { + return Err(Error::row_count()); + } - if stream.try_next().await?.is_some() { - return Err(Error::row_count()); + first = Some(row); } - Ok(Some(row)) + Ok(first) } /// The maximally flexible version of [`query`]. From 97436303232127dbd448d71a50c6365bdbee083c Mon Sep 17 00:00:00 2001 From: laxjesse Date: Wed, 13 Mar 2024 11:10:58 -0400 Subject: [PATCH 079/126] use `split_once` instead of `split` to parse lsn strings [`str::split`](https://doc.rust-lang.org/std/primitive.str.html#method.split) allocates a vector and generates considerably more instructions when compiled than [`str::split_once`](https://doc.rust-lang.org/std/primitive.str.html#method.split_once). [`u64::from_str_radix(split_lo, 16)`](https://doc.rust-lang.org/std/primitive.u64.html#method.from_str_radix) will error if the `lsn_str` contains more than one `/` so this change should result in the same behavior as the current implementation despite not explicitly checking this. --- postgres-types/CHANGELOG.md | 6 ++++++ postgres-types/src/pg_lsn.rs | 18 ++++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md index 72a1cbb6a..157a2cc7d 100644 --- a/postgres-types/CHANGELOG.md +++ b/postgres-types/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## Unreleased + +### Changed + +* `FromStr` implementation for `PgLsn` no longer allocates a `Vec` when splitting an lsn string on it's `/`. + ## v0.2.6 - 2023-08-19 ### Fixed diff --git a/postgres-types/src/pg_lsn.rs b/postgres-types/src/pg_lsn.rs index f0bbf4022..f339f9689 100644 --- a/postgres-types/src/pg_lsn.rs +++ b/postgres-types/src/pg_lsn.rs @@ -33,16 +33,14 @@ impl FromStr for PgLsn { type Err = ParseLsnError; fn from_str(lsn_str: &str) -> Result { - let split: Vec<&str> = lsn_str.split('/').collect(); - if split.len() == 2 { - let (hi, lo) = ( - u64::from_str_radix(split[0], 16).map_err(|_| ParseLsnError(()))?, - u64::from_str_radix(split[1], 16).map_err(|_| ParseLsnError(()))?, - ); - Ok(PgLsn((hi << 32) | lo)) - } else { - Err(ParseLsnError(())) - } + let Some((split_hi, split_lo)) = lsn_str.split_once('/') else { + return Err(ParseLsnError(())); + }; + let (hi, lo) = ( + u64::from_str_radix(split_hi, 16).map_err(|_| ParseLsnError(()))?, + u64::from_str_radix(split_lo, 16).map_err(|_| ParseLsnError(()))?, + ); + Ok(PgLsn((hi << 32) | lo)) } } From 3836a3052065bccf53001b832a21823204bfa137 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Wed, 10 Apr 2024 17:42:13 +0200 Subject: [PATCH 080/126] Make license metadata SPDX compliant --- postgres-derive/Cargo.toml | 4 ++-- postgres-native-tls/Cargo.toml | 2 +- postgres-openssl/Cargo.toml | 2 +- postgres-protocol/Cargo.toml | 2 +- postgres-types/Cargo.toml | 2 +- postgres/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 51ebb5663..5d1604b24 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -2,7 +2,7 @@ name = "postgres-derive" version = "0.4.5" authors = ["Steven Fackler "] -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" edition = "2018" description = "An internal crate used by postgres-types" repository = "https://github.com/sfackler/rust-postgres" @@ -15,4 +15,4 @@ test = false syn = "2.0" proc-macro2 = "1.0" quote = "1.0" -heck = "0.4" \ No newline at end of file +heck = "0.4" diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index 1f2f6385d..936eeeaa4 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -3,7 +3,7 @@ name = "postgres-native-tls" version = "0.5.0" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "TLS support for tokio-postgres via native-tls" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" diff --git a/postgres-openssl/Cargo.toml b/postgres-openssl/Cargo.toml index 8671308af..b7ebd3385 100644 --- a/postgres-openssl/Cargo.toml +++ b/postgres-openssl/Cargo.toml @@ -3,7 +3,7 @@ name = "postgres-openssl" version = "0.5.0" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "TLS support for tokio-postgres via openssl" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index bc83fc4e6..a8a130495 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -4,7 +4,7 @@ version = "0.6.6" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index cfd083637..bf011251b 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -3,7 +3,7 @@ name = "postgres-types" version = "0.2.6" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "Conversions between Rust and Postgres values" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 18406da9f..2ff3c875e 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -3,7 +3,7 @@ name = "postgres" version = "0.19.7" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "A native, synchronous PostgreSQL client" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 237f3d2f1..b3e56314f 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -3,7 +3,7 @@ name = "tokio-postgres" version = "0.7.10" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "A native, asynchronous PostgreSQL client" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" From 670cd7d5802dfb3b0b6b1eadd480f5c9730bb0b0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Apr 2024 16:21:21 +0000 Subject: [PATCH 081/126] Update heck requirement from 0.4 to 0.5 --- updated-dependencies: - dependency-name: heck dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- postgres-derive/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 5d1604b24..cbae6c77b 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -15,4 +15,4 @@ test = false syn = "2.0" proc-macro2 = "1.0" quote = "1.0" -heck = "0.4" +heck = "0.5" From 3c6dbe9b8c7bfad82c646f34092e3fa1d321b723 Mon Sep 17 00:00:00 2001 From: Yuri Astrakhan Date: Wed, 1 May 2024 22:46:06 -0400 Subject: [PATCH 082/126] Avoid extra clone in config if possible Using `impl Into` instead of `&str` in a fn arg allows both `&str` and `String` as parameters - thus if the caller already has a String object that it doesn't need, it can pass it in without extra cloning. The same might be done with the password, but may require closer look. --- tokio-postgres/src/config.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index c78346fff..62b45f793 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -248,8 +248,8 @@ impl Config { /// Sets the user to authenticate with. /// /// Defaults to the user executing this process. - pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); + pub fn user(&mut self, user: impl Into) -> &mut Config { + self.user = Some(user.into()); self } @@ -277,8 +277,8 @@ impl Config { /// Sets the name of the database to connect to. /// /// Defaults to the user. - pub fn dbname(&mut self, dbname: &str) -> &mut Config { - self.dbname = Some(dbname.to_string()); + pub fn dbname(&mut self, dbname: impl Into) -> &mut Config { + self.dbname = Some(dbname.into()); self } @@ -289,8 +289,8 @@ impl Config { } /// Sets command line options used to configure the server. - pub fn options(&mut self, options: &str) -> &mut Config { - self.options = Some(options.to_string()); + pub fn options(&mut self, options: impl Into) -> &mut Config { + self.options = Some(options.into()); self } @@ -301,8 +301,8 @@ impl Config { } /// Sets the value of the `application_name` runtime parameter. - pub fn application_name(&mut self, application_name: &str) -> &mut Config { - self.application_name = Some(application_name.to_string()); + pub fn application_name(&mut self, application_name: impl Into) -> &mut Config { + self.application_name = Some(application_name.into()); self } @@ -330,7 +330,9 @@ impl Config { /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. /// There must be either no hosts, or the same number of hosts as hostaddrs. - pub fn host(&mut self, host: &str) -> &mut Config { + pub fn host(&mut self, host: impl Into) -> &mut Config { + let host = host.into(); + #[cfg(unix)] { if host.starts_with('/') { @@ -338,7 +340,7 @@ impl Config { } } - self.host.push(Host::Tcp(host.to_string())); + self.host.push(Host::Tcp(host)); self } @@ -990,7 +992,7 @@ impl<'a> UrlParser<'a> { let mut it = creds.splitn(2, ':'); let user = self.decode(it.next().unwrap())?; - self.config.user(&user); + self.config.user(user); if let Some(password) = it.next() { let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); @@ -1053,7 +1055,7 @@ impl<'a> UrlParser<'a> { }; if !dbname.is_empty() { - self.config.dbname(&self.decode(dbname)?); + self.config.dbname(self.decode(dbname)?); } Ok(()) From d5d75d3a2f064425436c08b6a8f2da2b985aab3d Mon Sep 17 00:00:00 2001 From: vsuryamurthy Date: Thu, 23 May 2024 17:18:41 +0200 Subject: [PATCH 083/126] add simple_query to GenericClient in tokio_postgres --- tokio-postgres/CHANGELOG.md | 1 + tokio-postgres/src/generic_client.rs | 35 ++++++++++++++++++---------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 9f5eb9521..775c22e34 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -5,6 +5,7 @@ * Disable `rustc-serialize` compatibility of `eui48-1` dependency * Remove tests for `eui48-04` * Add `table_oid` and `field_id` fields to `Columns` struct of prepared statements. +* Add `GenericClient::simple_query`. ## v0.7.10 - 2023-08-25 diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..d80dd3b86 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -1,6 +1,6 @@ use crate::query::RowStream; use crate::types::{BorrowToSql, ToSql, Type}; -use crate::{Client, Error, Row, Statement, ToStatement, Transaction}; +use crate::{Client, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction}; use async_trait::async_trait; mod private { @@ -12,12 +12,12 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. #[async_trait] pub trait GenericClient: private::Sealed { - /// Like `Client::execute`. + /// Like [`Client::execute`]. async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::execute_raw`. + /// Like [`Client::execute_raw`]. async fn execute_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement + Sync + Send, @@ -25,12 +25,12 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; - /// Like `Client::query`. + /// Like [`Client::query`]. async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_one`. + /// Like [`Client::query_one`]. async fn query_one( &self, statement: &T, @@ -39,7 +39,7 @@ pub trait GenericClient: private::Sealed { where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_opt`. + /// Like [`Client::query_opt`]. async fn query_opt( &self, statement: &T, @@ -48,7 +48,7 @@ pub trait GenericClient: private::Sealed { where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_raw`. + /// Like [`Client::query_raw`]. async fn query_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement + Sync + Send, @@ -56,23 +56,26 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; - /// Like `Client::prepare`. + /// Like [`Client::prepare`]. async fn prepare(&self, query: &str) -> Result; - /// Like `Client::prepare_typed`. + /// Like [`Client::prepare_typed`]. async fn prepare_typed( &self, query: &str, parameter_types: &[Type], ) -> Result; - /// Like `Client::transaction`. + /// Like [`Client::transaction`]. async fn transaction(&mut self) -> Result, Error>; - /// Like `Client::batch_execute`. + /// Like [`Client::batch_execute`]. async fn batch_execute(&self, query: &str) -> Result<(), Error>; - /// Returns a reference to the underlying `Client`. + /// Like [`Client::simple_query`]. + async fn simple_query(&self, query: &str) -> Result, Error>; + + /// Returns a reference to the underlying [`Client`]. fn client(&self) -> &Client; } @@ -156,6 +159,10 @@ impl GenericClient for Client { self.batch_execute(query).await } + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + fn client(&self) -> &Client { self } @@ -243,6 +250,10 @@ impl GenericClient for Transaction<'_> { self.batch_execute(query).await } + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + fn client(&self) -> &Client { self.client() } From fbecae11ace79376b20ae8b9a587ab577e8287cd Mon Sep 17 00:00:00 2001 From: Duarte Nunes Date: Mon, 11 Mar 2024 14:43:50 -0300 Subject: [PATCH 084/126] feat(types): add 'js' feature for wasm Enables the "js" feature of postgres-protocol. --- postgres-types/Cargo.toml | 1 + tokio-postgres/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index bf011251b..33296db2c 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -13,6 +13,7 @@ categories = ["database"] [features] derive = ["postgres-derive"] array-impls = ["array-init"] +js = ["postgres-protocol/js"] with-bit-vec-0_6 = ["bit-vec-06"] with-cidr-0_2 = ["cidr-02"] with-chrono-0_4 = ["chrono-04"] diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index b3e56314f..2e080cfb2 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -40,7 +40,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] with-uuid-1 = ["postgres-types/with-uuid-1"] with-time-0_2 = ["postgres-types/with-time-0_2"] with-time-0_3 = ["postgres-types/with-time-0_3"] -js = ["postgres-protocol/js"] +js = ["postgres-protocol/js", "postgres-types/js"] [dependencies] async-trait = "0.1" From 6cd4652bad6ac8474235c23d0e4e96cc4aa4d8db Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Tue, 28 May 2024 21:57:27 -0500 Subject: [PATCH 085/126] Add RowDescription to SimpleQueryMessage --- tokio-postgres/src/lib.rs | 6 ++++++ tokio-postgres/src/simple_query.rs | 5 +++-- tokio-postgres/tests/test/main.rs | 13 ++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 2973d33b0..d650f4db9 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -118,6 +118,10 @@ //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use std::sync::Arc; + +use simple_query::SimpleColumn; + pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; @@ -248,6 +252,8 @@ pub enum SimpleQueryMessage { /// /// The number of rows modified or selected is returned. CommandComplete(u64), + /// Column values of the proceeding row values + RowDescription(Arc<[SimpleColumn]>) } fn slice_iter<'a>( diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index bcc6d928b..4e0b7734d 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -95,14 +95,15 @@ impl Stream for SimpleQueryStream { return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); } Message::RowDescription(body) => { - let columns = body + let columns: Arc<[SimpleColumn]> = body .fields() .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) .collect::>() .map_err(Error::parse)? .into(); - *this.columns = Some(columns); + *this.columns = Some(columns.clone()); + return Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns.clone())))); } Message::DataRow(body) => { let row = match &this.columns { diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 737f46631..4fa72aec9 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -328,6 +328,13 @@ async fn simple_query() { _ => panic!("unexpected message"), } match &messages[2] { + SimpleQueryMessage::RowDescription(columns) => { + assert_eq!(columns.get(0).map(|c| c.name()), Some("id")); + assert_eq!(columns.get(1).map(|c| c.name()), Some("name")); + } + _ => panic!("unexpected message") + } + match &messages[3] { SimpleQueryMessage::Row(row) => { assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); @@ -336,7 +343,7 @@ async fn simple_query() { } _ => panic!("unexpected message"), } - match &messages[3] { + match &messages[4] { SimpleQueryMessage::Row(row) => { assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); @@ -345,11 +352,11 @@ async fn simple_query() { } _ => panic!("unexpected message"), } - match messages[4] { + match messages[5] { SimpleQueryMessage::CommandComplete(2) => {} _ => panic!("unexpected message"), } - assert_eq!(messages.len(), 5); + assert_eq!(messages.len(), 6); } #[tokio::test] From 7afead9a13d54f1c5ce9bef5eda1fb7ced26db61 Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Tue, 28 May 2024 22:08:41 -0500 Subject: [PATCH 086/126] Formatting updates --- tokio-postgres/src/lib.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index d650f4db9..6c6266736 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -118,10 +118,6 @@ //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | #![warn(rust_2018_idioms, clippy::all, missing_docs)] -use std::sync::Arc; - -use simple_query::SimpleColumn; - pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; @@ -134,7 +130,7 @@ pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; pub use crate::row::{Row, SimpleQueryRow}; -pub use crate::simple_query::SimpleQueryStream; +pub use crate::simple_query::{SimpleQueryStream, SimpleColumn}; #[cfg(feature = "runtime")] pub use crate::socket::Socket; pub use crate::statement::{Column, Statement}; @@ -145,6 +141,7 @@ pub use crate::to_statement::ToStatement; pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; use crate::types::ToSql; +use std::sync::Arc; pub mod binary_copy; mod bind; From eec06021d9ebe1c1c2fcc47666a76ce257ae2891 Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Tue, 28 May 2024 23:50:50 -0500 Subject: [PATCH 087/126] Clippy compliance --- tokio-postgres/src/simple_query.rs | 54 ++++++++++++++---------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 4e0b7734d..e84806d36 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -85,36 +85,34 @@ impl Stream for SimpleQueryStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - loop { - match ready!(this.responses.poll_next(cx)?) { - Message::CommandComplete(body) => { - let rows = extract_row_affected(&body)?; - return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))); - } - Message::EmptyQueryResponse => { - return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); - } - Message::RowDescription(body) => { - let columns: Arc<[SimpleColumn]> = body - .fields() - .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) - .collect::>() - .map_err(Error::parse)? - .into(); + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = extract_row_affected(&body)?; + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))) + } + Message::EmptyQueryResponse => { + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))) + } + Message::RowDescription(body) => { + let columns: Arc<[SimpleColumn]> = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); - *this.columns = Some(columns.clone()); - return Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns.clone())))); - } - Message::DataRow(body) => { - let row = match &this.columns { - Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, - None => return Poll::Ready(Some(Err(Error::unexpected_message()))), - }; - return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); - } - Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + *this.columns = Some(columns.clone()); + Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns.clone())))) + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))) } + Message::ReadyForQuery(_) => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } } From bd6350c2fff2201d680a1814acf7a9208f4b7ad4 Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Wed, 29 May 2024 23:32:18 -0500 Subject: [PATCH 088/126] Formatting --- tokio-postgres/src/lib.rs | 4 ++-- tokio-postgres/src/simple_query.rs | 4 +++- tokio-postgres/tests/test/main.rs | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 6c6266736..a603158fb 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -130,7 +130,7 @@ pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; pub use crate::row::{Row, SimpleQueryRow}; -pub use crate::simple_query::{SimpleQueryStream, SimpleColumn}; +pub use crate::simple_query::{SimpleColumn, SimpleQueryStream}; #[cfg(feature = "runtime")] pub use crate::socket::Socket; pub use crate::statement::{Column, Statement}; @@ -250,7 +250,7 @@ pub enum SimpleQueryMessage { /// The number of rows modified or selected is returned. CommandComplete(u64), /// Column values of the proceeding row values - RowDescription(Arc<[SimpleColumn]>) + RowDescription(Arc<[SimpleColumn]>), } fn slice_iter<'a>( diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index e84806d36..86af8e739 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -102,7 +102,9 @@ impl Stream for SimpleQueryStream { .into(); *this.columns = Some(columns.clone()); - Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns.clone())))) + Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription( + columns.clone(), + )))) } Message::DataRow(body) => { let row = match &this.columns { diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 4fa72aec9..e85960ab6 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -330,9 +330,9 @@ async fn simple_query() { match &messages[2] { SimpleQueryMessage::RowDescription(columns) => { assert_eq!(columns.get(0).map(|c| c.name()), Some("id")); - assert_eq!(columns.get(1).map(|c| c.name()), Some("name")); + assert_eq!(columns.get(1).map(|c| c.name()), Some("name")); } - _ => panic!("unexpected message") + _ => panic!("unexpected message"), } match &messages[3] { SimpleQueryMessage::Row(row) => { From f3976680c6d7004b04b3ba39f90f2956ce6d7010 Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Sun, 26 May 2024 11:05:00 -0700 Subject: [PATCH 089/126] Work with pools that don't support prepared statements Introduce a new `query_with_param_types` method that allows to specify Postgres type parameters. This obviated the need to use prepared statementsjust to obtain parameter types for a query. It then combines parse, bind, and execute in a single packet. Related: #1017, #1067 --- tokio-postgres/src/client.rs | 82 +++++++++++++++ tokio-postgres/src/generic_client.rs | 46 +++++++++ tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/query.rs | 146 ++++++++++++++++++++++++++- tokio-postgres/src/statement.rs | 13 +++ tokio-postgres/src/transaction.rs | 27 +++++ tokio-postgres/tests/test/main.rs | 106 +++++++++++++++++++ 7 files changed, 416 insertions(+), 6 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index d48a23a60..431bfa792 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -364,6 +364,88 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Like `query`, but requires the types of query parameters to be explicitly specified. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for prepare, execute, and close). Thus, + /// this is suitable in environments where prepared statements aren't supported (such as Cloudflare Workers with Hyperdrive). + /// + /// # Examples + /// + /// ```no_run + /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { + /// use tokio_postgres::types::ToSql; + /// use tokio_postgres::types::Type; + /// use futures_util::{pin_mut, TryStreamExt}; + /// + /// let rows = client.query_with_param_types( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)], + /// ).await?; + /// + /// for row in rows { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_raw_with_param_types(statement, params) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query_with_param_types`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The parameters must specify value along with their Postgres type. This allows performing + /// queries without three round trips (for prepare, execute, and close). + /// + /// [`query_with_param_types`]: #method.query_with_param_types + /// + /// # Examples + /// + /// ```no_run + /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { + /// use tokio_postgres::types::ToSql; + /// use tokio_postgres::types::Type; + /// use futures_util::{pin_mut, TryStreamExt}; + /// + /// let mut it = client.query_raw_with_param_types( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)], + /// ).await?; + /// + /// pin_mut!(it); + /// while let Some(row) = it.try_next().await? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query_raw_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result { + fn slice_iter<'a>( + s: &'a [(&'a (dyn ToSql + Sync), Type)], + ) -> impl ExactSizeIterator + 'a { + s.iter() + .map(|(param, param_type)| (*param as _, param_type.clone())) + } + + query::query_with_param_types(&self.inner, statement, slice_iter(params)).await + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..3a0b09233 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,20 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_with_param_types` + async fn query_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error>; + + /// Like `Client::query_raw_with_param_types`. + async fn query_raw_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -136,6 +150,22 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_with_param_types(statement, params).await + } + + async fn query_raw_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result { + self.query_raw_with_param_types(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -222,6 +252,22 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_with_param_types(statement, params).await + } + + async fn query_raw_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result { + self.query_raw_with_param_types(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 07fb45694..1d9bacb16 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub(crate) async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..b9cc66405 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,21 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Column, Error, Portal, Row, Statement}; use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; -use postgres_protocol::message::backend::{CommandCompleteBody, Message}; +use postgres_protocol::message::backend::{CommandCompleteBody, Message, RowDescriptionBody}; use postgres_protocol::message::frontend; +use postgres_types::Type; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -50,13 +54,125 @@ where }; let responses = start(client, buf).await?; Ok(RowStream { - statement, + statement: statement, responses, rows_affected: None, _p: PhantomPinned, }) } +enum QueryProcessingState { + Empty, + ParseCompleted, + BindCompleted, + ParameterDescribed, + Final(Vec), +} + +/// State machine for processing messages for `query_with_param_types`. +impl QueryProcessingState { + pub async fn process_message( + self, + client: &Arc, + message: Message, + ) -> Result { + match (self, message) { + (QueryProcessingState::Empty, Message::ParseComplete) => { + Ok(QueryProcessingState::ParseCompleted) + } + (QueryProcessingState::ParseCompleted, Message::BindComplete) => { + Ok(QueryProcessingState::BindCompleted) + } + (QueryProcessingState::BindCompleted, Message::ParameterDescription(_)) => { + Ok(QueryProcessingState::ParameterDescribed) + } + ( + QueryProcessingState::ParameterDescribed, + Message::RowDescription(row_description), + ) => Self::form_final(client, Some(row_description)).await, + (QueryProcessingState::ParameterDescribed, Message::NoData) => { + Self::form_final(client, None).await + } + (_, Message::ErrorResponse(body)) => Err(Error::db(body)), + _ => Err(Error::unexpected_message()), + } + } + + async fn form_final( + client: &Arc, + row_description: Option, + ) -> Result { + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; + columns.push(column); + } + } + + Ok(Self::Final(columns)) + } +} + +pub async fn query_with_param_types<'a, P, I>( + client: &Arc, + query: &str, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip(); + + let params = params.into_iter(); + + let param_oids = param_types.iter().map(|t| t.oid()).collect::>(); + + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?; + + encode_bind_with_statement_name_and_param_types("", ¶m_types, params, "", buf)?; + + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + + frontend::execute("", 0, buf).map_err(Error::encode)?; + + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + let mut state = QueryProcessingState::Empty; + + loop { + let message = responses.next().await?; + + state = state.process_message(client, message).await?; + + if let QueryProcessingState::Final(columns) = state { + return Ok(RowStream { + statement: Statement::unnamed(vec![], columns), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + } +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, @@ -164,7 +280,27 @@ where I: IntoIterator, I::IntoIter: ExactSizeIterator, { - let param_types = statement.params(); + encode_bind_with_statement_name_and_param_types( + statement.name(), + statement.params(), + params, + portal, + buf, + ) +} + +fn encode_bind_with_statement_name_and_param_types( + statement_name: &str, + param_types: &[Type], + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ let params = params.into_iter(); if param_types.len() != params.len() { @@ -181,7 +317,7 @@ where let mut error_idx = 0; let r = frontend::bind( portal, - statement.name(), + statement_name, param_formats, params.zip(param_types).enumerate(), |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) { diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index c5d657738..2b88ecd3b 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -14,6 +14,10 @@ struct StatementInner { impl Drop for StatementInner { fn drop(&mut self) { + if self.name.is_empty() { + // Unnamed statements don't need to be closed + return; + } if let Some(client) = self.client.upgrade() { let buf = client.with_buf(|buf| { frontend::close(b'S', &self.name, buf).unwrap(); @@ -46,6 +50,15 @@ impl Statement { })) } + pub(crate) fn unnamed(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + pub(crate) fn name(&self) -> &str { &self.0.name } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..5a6094b56 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -227,6 +227,33 @@ impl<'a> Transaction<'a> { query::query_portal(self.client.inner(), portal, max_rows).await } + /// Like `Client::query_with_param_types`. + pub async fn query_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_raw_with_param_types(statement, params) + .await? + .try_collect() + .await + } + + /// Like `Client::query_raw_with_param_types`. + pub async fn query_raw_with_param_types( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result { + fn slice_iter<'a>( + s: &'a [(&'a (dyn ToSql + Sync), Type)], + ) -> impl ExactSizeIterator + 'a { + s.iter() + .map(|(param, param_type)| (*param as _, param_type.clone())) + } + query::query_with_param_types(self.client.inner(), statement, slice_iter(params)).await + } + /// Like `Client::copy_in`. pub async fn copy_in(&self, statement: &T) -> Result, Error> where diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 737f46631..925c99206 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -952,3 +952,109 @@ async fn deferred_constraint() { .await .unwrap_err(); } + +#[tokio::test] +async fn query_with_param_types_no_transaction() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT, + age INT + ); + INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40); + ", + ) + .await + .unwrap(); + + let rows: Vec = client + .query_with_param_types( + "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", + &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], + ) + .await + .unwrap(); + + assert_eq!(rows.len(), 2); + let first_row = &rows[0]; + assert_eq!(first_row.get::<_, &str>(0), "bob"); + assert_eq!(first_row.get::<_, i32>(1), 30); + assert_eq!(first_row.get::<_, &str>(2), "literal"); + assert_eq!(first_row.get::<_, i32>(3), 5); + + let second_row = &rows[1]; + assert_eq!(second_row.get::<_, &str>(0), "carol"); + assert_eq!(second_row.get::<_, i32>(1), 40); + assert_eq!(second_row.get::<_, &str>(2), "literal"); + assert_eq!(second_row.get::<_, i32>(3), 5); +} + +#[tokio::test] +async fn query_with_param_types_with_transaction() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT, + age INT + ); + ", + ) + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + + let rows: Vec = transaction + .query_with_param_types( + "INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age", + &[ + (&"alice", Type::TEXT), + (&20i32, Type::INT4), + (&"bob", Type::TEXT), + (&30i32, Type::INT4), + (&"carol", Type::TEXT), + (&40i32, Type::INT4), + ], + ) + .await + .unwrap(); + let inserted_values: Vec<(String, i32)> = rows + .iter() + .map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1))) + .collect(); + assert_eq!( + inserted_values, + [ + ("alice".to_string(), 20), + ("bob".to_string(), 30), + ("carol".to_string(), 40) + ] + ); + + let rows: Vec = transaction + .query_with_param_types( + "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", + &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], + ) + .await + .unwrap(); + + assert_eq!(rows.len(), 2); + let first_row = &rows[0]; + assert_eq!(first_row.get::<_, &str>(0), "bob"); + assert_eq!(first_row.get::<_, i32>(1), 30); + assert_eq!(first_row.get::<_, &str>(2), "literal"); + assert_eq!(first_row.get::<_, i32>(3), 5); + + let second_row = &rows[1]; + assert_eq!(second_row.get::<_, &str>(0), "carol"); + assert_eq!(second_row.get::<_, i32>(1), 40); + assert_eq!(second_row.get::<_, &str>(2), "literal"); + assert_eq!(second_row.get::<_, i32>(3), 5); +} From 84994dad1aa9c3ef5c813b95c86c80dbfa4b7f0d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 6 Jul 2024 11:23:26 -0400 Subject: [PATCH 090/126] Derive Clone for Row --- postgres-protocol/src/message/backend.rs | 2 +- tokio-postgres/src/row.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 1b5be1098..c4439b26a 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -524,7 +524,7 @@ impl CopyOutResponseBody { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, len: u16, diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index 3c79de603..767c26921 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -95,6 +95,7 @@ where } /// A row of data returned from the database by a query. +#[derive(Clone)] pub struct Row { statement: Statement, body: DataRowBody, From 2b1949dd2f8745fcfaefe4b5e228684c25997265 Mon Sep 17 00:00:00 2001 From: Sidney Cammeresi Date: Sat, 6 Jul 2024 11:00:41 -0700 Subject: [PATCH 091/126] impl Debug for Statement The lack of this common trait bound caused some unpleasantness. For example, the following didn't compile: let x = OnceLock::new(); let stmt = db.prepare(...)?; x.set(stmt).expect(...); // returns Result<(), T=Statement> where T: Debug --- tokio-postgres/src/statement.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index c5d657738..4955d3b41 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -61,6 +61,16 @@ impl Statement { } } +impl std::fmt::Debug for Statement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("Statement") + .field("name", &self.0.name) + .field("params", &self.0.params) + .field("columns", &self.0.columns) + .finish_non_exhaustive() + } +} + /// Information about a column of a query. #[derive(Debug)] pub struct Column { From 1f312194928c5a385d51d52e5d13ca59d3dc1b43 Mon Sep 17 00:00:00 2001 From: Sidney Cammeresi Date: Sat, 6 Jul 2024 12:29:09 -0700 Subject: [PATCH 092/126] Fix a few nits pointed out by clippy - ...::max_value() -> ..::MAX - delete explicit import of signed integer types --- postgres-protocol/src/lib.rs | 2 +- postgres-types/src/lib.rs | 2 +- postgres-types/src/special.rs | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 83d9bf55c..e0de3b6c6 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -60,7 +60,7 @@ macro_rules! from_usize { impl FromUsize for $t { #[inline] fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::max_value() as usize { + if x > <$t>::MAX as usize { Err(io::Error::new( io::ErrorKind::InvalidInput, "value too large to transmit", diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 2f02f6e5f..492039766 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1222,7 +1222,7 @@ impl ToSql for IpAddr { } fn downcast(len: usize) -> Result> { - if len > i32::max_value() as usize { + if len > i32::MAX as usize { Err("value too large to transmit".into()) } else { Ok(len as i32) diff --git a/postgres-types/src/special.rs b/postgres-types/src/special.rs index 1a865287e..d8541bf0e 100644 --- a/postgres-types/src/special.rs +++ b/postgres-types/src/special.rs @@ -1,7 +1,6 @@ use bytes::BytesMut; use postgres_protocol::types; use std::error::Error; -use std::{i32, i64}; use crate::{FromSql, IsNull, ToSql, Type}; From 263b0068af39072bc7be05b6500e47f263cbd43e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 6 Jul 2024 19:21:37 -0400 Subject: [PATCH 093/126] Handle non-UTF8 error fields --- postgres-protocol/src/message/backend.rs | 10 +++++-- tokio-postgres/src/error/mod.rs | 37 ++++++++++++------------ 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index c4439b26a..73b169288 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -633,7 +633,7 @@ impl<'a> FallibleIterator for ErrorFields<'a> { } let value_end = find_null(self.buf, 0)?; - let value = get_str(&self.buf[..value_end])?; + let value = &self.buf[..value_end]; self.buf = &self.buf[value_end + 1..]; Ok(Some(ErrorField { type_, value })) @@ -642,7 +642,7 @@ impl<'a> FallibleIterator for ErrorFields<'a> { pub struct ErrorField<'a> { type_: u8, - value: &'a str, + value: &'a [u8], } impl<'a> ErrorField<'a> { @@ -652,7 +652,13 @@ impl<'a> ErrorField<'a> { } #[inline] + #[deprecated(note = "use value_bytes instead", since = "0.6.7")] pub fn value(&self) -> &str { + str::from_utf8(self.value).expect("error field value contained non-UTF8 bytes") + } + + #[inline] + pub fn value_bytes(&self) -> &[u8] { self.value } } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f1e2644c6..75664d258 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -107,14 +107,15 @@ impl DbError { let mut routine = None; while let Some(field) = fields.next()? { + let value = String::from_utf8_lossy(field.value_bytes()); match field.type_() { - b'S' => severity = Some(field.value().to_owned()), - b'C' => code = Some(SqlState::from_code(field.value())), - b'M' => message = Some(field.value().to_owned()), - b'D' => detail = Some(field.value().to_owned()), - b'H' => hint = Some(field.value().to_owned()), + b'S' => severity = Some(value.into_owned()), + b'C' => code = Some(SqlState::from_code(&value)), + b'M' => message = Some(value.into_owned()), + b'D' => detail = Some(value.into_owned()), + b'H' => hint = Some(value.into_owned()), b'P' => { - normal_position = Some(field.value().parse::().map_err(|_| { + normal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`P` field did not contain an integer", @@ -122,32 +123,32 @@ impl DbError { })?); } b'p' => { - internal_position = Some(field.value().parse::().map_err(|_| { + internal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`p` field did not contain an integer", ) })?); } - b'q' => internal_query = Some(field.value().to_owned()), - b'W' => where_ = Some(field.value().to_owned()), - b's' => schema = Some(field.value().to_owned()), - b't' => table = Some(field.value().to_owned()), - b'c' => column = Some(field.value().to_owned()), - b'd' => datatype = Some(field.value().to_owned()), - b'n' => constraint = Some(field.value().to_owned()), - b'F' => file = Some(field.value().to_owned()), + b'q' => internal_query = Some(value.into_owned()), + b'W' => where_ = Some(value.into_owned()), + b's' => schema = Some(value.into_owned()), + b't' => table = Some(value.into_owned()), + b'c' => column = Some(value.into_owned()), + b'd' => datatype = Some(value.into_owned()), + b'n' => constraint = Some(value.into_owned()), + b'F' => file = Some(value.into_owned()), b'L' => { - line = Some(field.value().parse::().map_err(|_| { + line = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`L` field did not contain an integer", ) })?); } - b'R' => routine = Some(field.value().to_owned()), + b'R' => routine = Some(value.into_owned()), b'V' => { - parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { + parsed_severity = Some(Severity::from_str(&value).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "`V` field contained an invalid value", From cfd91632be877543a7e19e7a05816ed5d241b559 Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Sun, 7 Jul 2024 13:56:35 -0500 Subject: [PATCH 094/126] PR Fix: Only use single clone for RowDescription --- tokio-postgres/src/simple_query.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 86af8e739..b6500260e 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -101,9 +101,9 @@ impl Stream for SimpleQueryStream { .map_err(Error::parse)? .into(); - *this.columns = Some(columns.clone()); + *this.columns = Some(columns); Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription( - columns.clone(), + this.columns.as_ref().unwrap().clone(), )))) } Message::DataRow(body) => { From 3f8f5ded337a0122959f6e4a3dc9343bf6c6ee70 Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Sun, 7 Jul 2024 16:21:40 -0700 Subject: [PATCH 095/126] Replace the state machine to process messages with a direct match statements --- tokio-postgres/src/query.rs | 101 ++++++++++-------------------------- 1 file changed, 27 insertions(+), 74 deletions(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index b9cc66405..2bdfa14cc 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -9,7 +9,7 @@ use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; -use postgres_protocol::message::backend::{CommandCompleteBody, Message, RowDescriptionBody}; +use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; use postgres_types::Type; use std::fmt; @@ -61,66 +61,6 @@ where }) } -enum QueryProcessingState { - Empty, - ParseCompleted, - BindCompleted, - ParameterDescribed, - Final(Vec), -} - -/// State machine for processing messages for `query_with_param_types`. -impl QueryProcessingState { - pub async fn process_message( - self, - client: &Arc, - message: Message, - ) -> Result { - match (self, message) { - (QueryProcessingState::Empty, Message::ParseComplete) => { - Ok(QueryProcessingState::ParseCompleted) - } - (QueryProcessingState::ParseCompleted, Message::BindComplete) => { - Ok(QueryProcessingState::BindCompleted) - } - (QueryProcessingState::BindCompleted, Message::ParameterDescription(_)) => { - Ok(QueryProcessingState::ParameterDescribed) - } - ( - QueryProcessingState::ParameterDescribed, - Message::RowDescription(row_description), - ) => Self::form_final(client, Some(row_description)).await, - (QueryProcessingState::ParameterDescribed, Message::NoData) => { - Self::form_final(client, None).await - } - (_, Message::ErrorResponse(body)) => Err(Error::db(body)), - _ => Err(Error::unexpected_message()), - } - } - - async fn form_final( - client: &Arc, - row_description: Option, - ) -> Result { - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, field.type_oid()).await?; - let column = Column { - name: field.name().to_string(), - table_oid: Some(field.table_oid()).filter(|n| *n != 0), - column_id: Some(field.column_id()).filter(|n| *n != 0), - r#type: type_, - }; - columns.push(column); - } - } - - Ok(Self::Final(columns)) - } -} - pub async fn query_with_param_types<'a, P, I>( client: &Arc, query: &str, @@ -155,20 +95,33 @@ where let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - let mut state = QueryProcessingState::Empty; - loop { - let message = responses.next().await?; - - state = state.process_message(client, message).await?; - - if let QueryProcessingState::Final(columns) = state { - return Ok(RowStream { - statement: Statement::unnamed(vec![], columns), - responses, - rows_affected: None, - _p: PhantomPinned, - }); + match responses.next().await? { + Message::ParseComplete + | Message::BindComplete + | Message::ParameterDescription(_) + | Message::NoData => {} + Message::RowDescription(row_description) => { + let mut columns: Vec = vec![]; + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; + columns.push(column); + } + return Ok(RowStream { + statement: Statement::unnamed(vec![], columns), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + _ => return Err(Error::unexpected_message()), } } } From 74eb4dbf7399cb96500f2b60a2b838805471a26a Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Sun, 7 Jul 2024 16:43:41 -0700 Subject: [PATCH 096/126] Remove query_raw_with_param_types as per PR feedback --- tokio-postgres/src/client.rs | 56 ++++++---------------------- tokio-postgres/src/generic_client.rs | 23 ------------ tokio-postgres/src/transaction.rs | 20 +--------- 3 files changed, 12 insertions(+), 87 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 431bfa792..e420bcf2f 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -366,8 +366,13 @@ impl Client { /// Like `query`, but requires the types of query parameters to be explicitly specified. /// - /// Compared to `query`, this method allows performing queries without three round trips (for prepare, execute, and close). Thus, - /// this is suitable in environments where prepared statements aren't supported (such as Cloudflare Workers with Hyperdrive). + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. /// /// # Examples /// @@ -394,48 +399,6 @@ impl Client { statement: &str, params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error> { - self.query_raw_with_param_types(statement, params) - .await? - .try_collect() - .await - } - - /// The maximally flexible version of [`query_with_param_types`]. - /// - /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list - /// provided, 1-indexed. - /// - /// The parameters must specify value along with their Postgres type. This allows performing - /// queries without three round trips (for prepare, execute, and close). - /// - /// [`query_with_param_types`]: #method.query_with_param_types - /// - /// # Examples - /// - /// ```no_run - /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { - /// use tokio_postgres::types::ToSql; - /// use tokio_postgres::types::Type; - /// use futures_util::{pin_mut, TryStreamExt}; - /// - /// let mut it = client.query_raw_with_param_types( - /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", - /// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)], - /// ).await?; - /// - /// pin_mut!(it); - /// while let Some(row) = it.try_next().await? { - /// let foo: i32 = row.get("foo"); - /// println!("foo: {}", foo); - /// } - /// # Ok(()) - /// # } - /// ``` - pub async fn query_raw_with_param_types( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result { fn slice_iter<'a>( s: &'a [(&'a (dyn ToSql + Sync), Type)], ) -> impl ExactSizeIterator + 'a { @@ -443,7 +406,10 @@ impl Client { .map(|(param, param_type)| (*param as _, param_type.clone())) } - query::query_with_param_types(&self.inner, statement, slice_iter(params)).await + query::query_with_param_types(&self.inner, statement, slice_iter(params)) + .await? + .try_collect() + .await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 3a0b09233..b892015dc 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -63,13 +63,6 @@ pub trait GenericClient: private::Sealed { params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error>; - /// Like `Client::query_raw_with_param_types`. - async fn query_raw_with_param_types( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result; - /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -158,14 +151,6 @@ impl GenericClient for Client { self.query_with_param_types(statement, params).await } - async fn query_raw_with_param_types( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result { - self.query_raw_with_param_types(statement, params).await - } - async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -260,14 +245,6 @@ impl GenericClient for Transaction<'_> { self.query_with_param_types(statement, params).await } - async fn query_raw_with_param_types( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result { - self.query_raw_with_param_types(statement, params).await - } - async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 5a6094b56..8a0ad2224 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -233,25 +233,7 @@ impl<'a> Transaction<'a> { statement: &str, params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error> { - self.query_raw_with_param_types(statement, params) - .await? - .try_collect() - .await - } - - /// Like `Client::query_raw_with_param_types`. - pub async fn query_raw_with_param_types( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result { - fn slice_iter<'a>( - s: &'a [(&'a (dyn ToSql + Sync), Type)], - ) -> impl ExactSizeIterator + 'a { - s.iter() - .map(|(param, param_type)| (*param as _, param_type.clone())) - } - query::query_with_param_types(self.client.inner(), statement, slice_iter(params)).await + self.client.query_with_param_types(statement, params).await } /// Like `Client::copy_in`. From 2647024c660ca27701898325a8772b83bece4982 Mon Sep 17 00:00:00 2001 From: Dane Rigby Date: Sun, 7 Jul 2024 21:30:23 -0500 Subject: [PATCH 097/126] PR Fix: Clone first then move --- tokio-postgres/src/simple_query.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index b6500260e..24473b896 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -101,10 +101,8 @@ impl Stream for SimpleQueryStream { .map_err(Error::parse)? .into(); - *this.columns = Some(columns); - Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription( - this.columns.as_ref().unwrap().clone(), - )))) + *this.columns = Some(columns.clone()); + Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns)))) } Message::DataRow(body) => { let row = match &this.columns { From dbd4d02e2f3a367b949e356e9dda40c08272d954 Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Mon, 8 Jul 2024 17:21:32 -0700 Subject: [PATCH 098/126] Address review comment to rename query_with_param_types to query_typed --- tokio-postgres/src/client.rs | 6 +++--- tokio-postgres/src/generic_client.rs | 12 ++++++------ tokio-postgres/src/query.rs | 2 +- tokio-postgres/src/transaction.rs | 6 +++--- tokio-postgres/tests/test/main.rs | 10 +++++----- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e420bcf2f..2b29351a5 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -382,7 +382,7 @@ impl Client { /// use tokio_postgres::types::Type; /// use futures_util::{pin_mut, TryStreamExt}; /// - /// let rows = client.query_with_param_types( + /// let rows = client.query_typed( /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", /// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)], /// ).await?; @@ -394,7 +394,7 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn query_with_param_types( + pub async fn query_typed( &self, statement: &str, params: &[(&(dyn ToSql + Sync), Type)], @@ -406,7 +406,7 @@ impl Client { .map(|(param, param_type)| (*param as _, param_type.clone())) } - query::query_with_param_types(&self.inner, statement, slice_iter(params)) + query::query_typed(&self.inner, statement, slice_iter(params)) .await? .try_collect() .await diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index e43bddfea..b91d78064 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,8 +56,8 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; - /// Like [`Client::query_with_param_types`] - async fn query_with_param_types( + /// Like [`Client::query_typed`] + async fn query_typed( &self, statement: &str, params: &[(&(dyn ToSql + Sync), Type)], @@ -146,12 +146,12 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_with_param_types( + async fn query_typed( &self, statement: &str, params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error> { - self.query_with_param_types(statement, params).await + self.query_typed(statement, params).await } async fn prepare(&self, query: &str) -> Result { @@ -244,12 +244,12 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_with_param_types( + async fn query_typed( &self, statement: &str, params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error> { - self.query_with_param_types(statement, params).await + self.query_typed(statement, params).await } async fn prepare(&self, query: &str) -> Result { diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 2bdfa14cc..b54e095df 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -61,7 +61,7 @@ where }) } -pub async fn query_with_param_types<'a, P, I>( +pub async fn query_typed<'a, P, I>( client: &Arc, query: &str, params: I, diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 8a0ad2224..3e62b2ac7 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -227,13 +227,13 @@ impl<'a> Transaction<'a> { query::query_portal(self.client.inner(), portal, max_rows).await } - /// Like `Client::query_with_param_types`. - pub async fn query_with_param_types( + /// Like `Client::query_typed`. + pub async fn query_typed( &self, statement: &str, params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error> { - self.client.query_with_param_types(statement, params).await + self.client.query_typed(statement, params).await } /// Like `Client::copy_in`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 925c99206..7ddb7a36a 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -954,7 +954,7 @@ async fn deferred_constraint() { } #[tokio::test] -async fn query_with_param_types_no_transaction() { +async fn query_typed_no_transaction() { let client = connect("user=postgres").await; client @@ -971,7 +971,7 @@ async fn query_with_param_types_no_transaction() { .unwrap(); let rows: Vec = client - .query_with_param_types( + .query_typed( "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], ) @@ -993,7 +993,7 @@ async fn query_with_param_types_no_transaction() { } #[tokio::test] -async fn query_with_param_types_with_transaction() { +async fn query_typed_with_transaction() { let mut client = connect("user=postgres").await; client @@ -1011,7 +1011,7 @@ async fn query_with_param_types_with_transaction() { let transaction = client.transaction().await.unwrap(); let rows: Vec = transaction - .query_with_param_types( + .query_typed( "INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age", &[ (&"alice", Type::TEXT), @@ -1038,7 +1038,7 @@ async fn query_with_param_types_with_transaction() { ); let rows: Vec = transaction - .query_with_param_types( + .query_typed( "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], ) From 0fa32471ef2e20b7f2e554d6d97cde3a67f1d494 Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Tue, 9 Jul 2024 17:59:39 -0700 Subject: [PATCH 099/126] Fix a clippy warning --- tokio-postgres/src/query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index b54e095df..e304bbaea 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -54,7 +54,7 @@ where }; let responses = start(client, buf).await?; Ok(RowStream { - statement: statement, + statement, responses, rows_affected: None, _p: PhantomPinned, From 71c836b980799256a7f266195382fc8449fca5e4 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 13 Jul 2024 20:45:32 -0400 Subject: [PATCH 100/126] query_typed tweaks --- postgres/src/client.rs | 65 ++++++++++++++++++++++++++++ postgres/src/generic_client.rs | 45 +++++++++++++++++++ postgres/src/transaction.rs | 29 +++++++++++++ tokio-postgres/src/client.rs | 63 +++++++++++++++++---------- tokio-postgres/src/generic_client.rs | 22 ++++++++++ tokio-postgres/src/query.rs | 63 +++++++++++---------------- tokio-postgres/src/transaction.rs | 27 ++++++++---- 7 files changed, 243 insertions(+), 71 deletions(-) diff --git a/postgres/src/client.rs b/postgres/src/client.rs index c8e14cf81..42ce6dec9 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -257,6 +257,71 @@ impl Client { Ok(RowIter::new(self.connection.as_ref(), stream)) } + /// Like `query`, but requires the types of query parameters to be explicitly specified. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + pub fn query_typed( + &mut self, + query: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.connection + .block_on(self.client.query_typed(query, params)) + } + + /// The maximally flexible version of [`query_typed`]. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + /// + /// [`query_typed`]: #method.query_typed + /// + /// # Examples + /// ```no_run + /// # use postgres::{Client, NoTls}; + /// use postgres::types::{ToSql, Type}; + /// use fallible_iterator::FallibleIterator; + /// # fn main() -> Result<(), postgres::Error> { + /// # let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let params: Vec<(String, Type)> = vec![ + /// ("first param".into(), Type::TEXT), + /// ("second param".into(), Type::TEXT), + /// ]; + /// let mut it = client.query_typed_raw( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// params, + /// )?; + /// + /// while let Some(row) = it.next()? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn query_typed_raw(&mut self, query: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator, + { + let stream = self + .connection + .block_on(self.client.query_typed_raw(query, params))?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + /// Creates a new prepared statement. /// /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), diff --git a/postgres/src/generic_client.rs b/postgres/src/generic_client.rs index 12f07465d..7b534867c 100644 --- a/postgres/src/generic_client.rs +++ b/postgres/src/generic_client.rs @@ -44,6 +44,19 @@ pub trait GenericClient: private::Sealed { I: IntoIterator, I::IntoIter: ExactSizeIterator; + /// Like [`Client::query_typed`] + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error>; + + /// Like [`Client::query_typed_raw`] + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send; + /// Like `Client::prepare`. fn prepare(&mut self, query: &str) -> Result; @@ -115,6 +128,22 @@ impl GenericClient for Client { self.query_raw(query, params) } + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params) + } + + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params) + } + fn prepare(&mut self, query: &str) -> Result { self.prepare(query) } @@ -195,6 +224,22 @@ impl GenericClient for Transaction<'_> { self.query_raw(query, params) } + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params) + } + + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params) + } + fn prepare(&mut self, query: &str) -> Result { self.prepare(query) } diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index 17c49c406..5c8c15973 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -115,6 +115,35 @@ impl<'a> Transaction<'a> { Ok(RowIter::new(self.connection.as_ref(), stream)) } + /// Like `Client::query_typed`. + pub fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_typed(statement, params), + ) + } + + /// Like `Client::query_typed_raw`. + pub fn query_typed_raw(&mut self, query: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator, + { + let stream = self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_typed_raw(query, params), + )?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + /// Binds parameters to a statement, creating a "portal". /// /// Portals can be used with the `query_portal` method to page through the results of a query without being forced diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2b29351a5..b04f05f88 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -333,7 +333,6 @@ impl Client { /// /// ```no_run /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { - /// use tokio_postgres::types::ToSql; /// use futures_util::{pin_mut, TryStreamExt}; /// /// let params: Vec = vec![ @@ -373,43 +372,59 @@ impl Client { /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the /// parameter of the list provided, 1-indexed. + pub async fn query_typed( + &self, + query: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed_raw(query, params.iter().map(|(v, t)| (*v, t.clone()))) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query_typed`]. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + /// + /// [`query_typed`]: #method.query_typed /// /// # Examples /// /// ```no_run /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { - /// use tokio_postgres::types::ToSql; - /// use tokio_postgres::types::Type; /// use futures_util::{pin_mut, TryStreamExt}; + /// use tokio_postgres::types::Type; /// - /// let rows = client.query_typed( + /// let params: Vec<(String, Type)> = vec![ + /// ("first param".into(), Type::TEXT), + /// ("second param".into(), Type::TEXT), + /// ]; + /// let mut it = client.query_typed_raw( /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", - /// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)], + /// params, /// ).await?; /// - /// for row in rows { - /// let foo: i32 = row.get("foo"); - /// println!("foo: {}", foo); + /// pin_mut!(it); + /// while let Some(row) = it.try_next().await? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); /// } /// # Ok(()) /// # } /// ``` - pub async fn query_typed( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result, Error> { - fn slice_iter<'a>( - s: &'a [(&'a (dyn ToSql + Sync), Type)], - ) -> impl ExactSizeIterator + 'a { - s.iter() - .map(|(param, param_type)| (*param as _, param_type.clone())) - } - - query::query_typed(&self.inner, statement, slice_iter(params)) - .await? - .try_collect() - .await + pub async fn query_typed_raw(&self, query: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator, + { + query::query_typed(&self.inner, query, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index b91d78064..6e7dffeb1 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -63,6 +63,12 @@ pub trait GenericClient: private::Sealed { params: &[(&(dyn ToSql + Sync), Type)], ) -> Result, Error>; + /// Like [`Client::query_typed_raw`] + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send; + /// Like [`Client::prepare`]. async fn prepare(&self, query: &str) -> Result; @@ -154,6 +160,14 @@ impl GenericClient for Client { self.query_typed(statement, params).await } + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -252,6 +266,14 @@ impl GenericClient for Transaction<'_> { self.query_typed(statement, params).await } + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e304bbaea..be42d66b6 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -69,29 +69,21 @@ pub async fn query_typed<'a, P, I>( where P: BorrowToSql, I: IntoIterator, - I::IntoIter: ExactSizeIterator, { - let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip(); - - let params = params.into_iter(); - - let param_oids = param_types.iter().map(|t| t.oid()).collect::>(); - - let params = params.into_iter(); - - let buf = client.with_buf(|buf| { - frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?; - - encode_bind_with_statement_name_and_param_types("", ¶m_types, params, "", buf)?; - - frontend::describe(b'S', "", buf).map_err(Error::encode)?; - - frontend::execute("", 0, buf).map_err(Error::encode)?; + let buf = { + let params = params.into_iter().collect::>(); + let param_oids = params.iter().map(|(_, t)| t.oid()).collect::>(); - frontend::sync(buf); + client.with_buf(|buf| { + frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?; + encode_bind_raw("", params, "", buf)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); - Ok(buf.split().freeze()) - })?; + Ok(buf.split().freeze()) + })? + }; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; @@ -233,47 +225,42 @@ where I: IntoIterator, I::IntoIter: ExactSizeIterator, { - encode_bind_with_statement_name_and_param_types( + let params = params.into_iter(); + if params.len() != statement.params().len() { + return Err(Error::parameters(params.len(), statement.params().len())); + } + + encode_bind_raw( statement.name(), - statement.params(), - params, + params.zip(statement.params().iter().cloned()), portal, buf, ) } -fn encode_bind_with_statement_name_and_param_types( +fn encode_bind_raw( statement_name: &str, - param_types: &[Type], params: I, portal: &str, buf: &mut BytesMut, ) -> Result<(), Error> where P: BorrowToSql, - I: IntoIterator, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { - let params = params.into_iter(); - - if param_types.len() != params.len() { - return Err(Error::parameters(params.len(), param_types.len())); - } - let (param_formats, params): (Vec<_>, Vec<_>) = params - .zip(param_types.iter()) - .map(|(p, ty)| (p.borrow_to_sql().encode_format(ty) as i16, p)) + .into_iter() + .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty))) .unzip(); - let params = params.into_iter(); - let mut error_idx = 0; let r = frontend::bind( portal, statement_name, param_formats, - params.zip(param_types).enumerate(), - |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) { + params.into_iter().enumerate(), + |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Err(e) => { diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 3e62b2ac7..17a50b60f 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,24 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_typed`. + pub async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.client.query_typed(statement, params).await + } + + /// Like `Client::query_typed_raw`. + pub async fn query_typed_raw(&self, query: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator, + { + self.client.query_typed_raw(query, params).await + } + /// Like `Client::execute`. pub async fn execute( &self, @@ -227,15 +245,6 @@ impl<'a> Transaction<'a> { query::query_portal(self.client.inner(), portal, max_rows).await } - /// Like `Client::query_typed`. - pub async fn query_typed( - &self, - statement: &str, - params: &[(&(dyn ToSql + Sync), Type)], - ) -> Result, Error> { - self.client.query_typed(statement, params).await - } - /// Like `Client::copy_in`. pub async fn copy_in(&self, statement: &T) -> Result, Error> where From a0b2d701ebee8fd5c5b3d6ee5cf0cde5d7f36a65 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 21 Jul 2024 20:04:35 -0400 Subject: [PATCH 101/126] Fix cancellation of TransactionBuilder::start --- tokio-postgres/src/client.rs | 42 ++--------------------- tokio-postgres/src/transaction_builder.rs | 40 +++++++++++++++++++-- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index b04f05f88..92eabde36 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,4 +1,4 @@ -use crate::codec::{BackendMessages, FrontendMessage}; +use crate::codec::BackendMessages; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; @@ -21,7 +21,7 @@ use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use postgres_protocol::message::{backend::Message, frontend}; +use postgres_protocol::message::backend::Message; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; @@ -532,43 +532,7 @@ impl Client { /// /// The transaction will roll back by default - use the `commit` method to commit it. pub async fn transaction(&mut self) -> Result, Error> { - struct RollbackIfNotDone<'me> { - client: &'me Client, - done: bool, - } - - impl<'a> Drop for RollbackIfNotDone<'a> { - fn drop(&mut self) { - if self.done { - return; - } - - let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); - buf.split().freeze() - }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); - } - } - - // This is done, as `Future` created by this method can be dropped after - // `RequestMessages` is synchronously send to the `Connection` by - // `batch_execute()`, but before `Responses` is asynchronously polled to - // completion. In that case `Transaction` won't be created and thus - // won't be rolled back. - { - let mut cleaner = RollbackIfNotDone { - client: self, - done: false, - }; - self.batch_execute("BEGIN").await?; - cleaner.done = true; - } - - Ok(Transaction::new(self)) + self.build_transaction().start().await } /// Returns a builder for a transaction with custom settings. diff --git a/tokio-postgres/src/transaction_builder.rs b/tokio-postgres/src/transaction_builder.rs index 9718ac588..93e9e9801 100644 --- a/tokio-postgres/src/transaction_builder.rs +++ b/tokio-postgres/src/transaction_builder.rs @@ -1,4 +1,6 @@ -use crate::{Client, Error, Transaction}; +use postgres_protocol::message::frontend; + +use crate::{codec::FrontendMessage, connection::RequestMessages, Client, Error, Transaction}; /// The isolation level of a database transaction. #[derive(Debug, Copy, Clone)] @@ -106,7 +108,41 @@ impl<'a> TransactionBuilder<'a> { query.push_str(s); } - self.client.batch_execute(&query).await?; + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl<'a> Drop for RollbackIfNotDone<'a> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously send to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. + { + let mut cleaner = RollbackIfNotDone { + client: self.client, + done: false, + }; + self.client.batch_execute(&query).await?; + cleaner.done = true; + } Ok(Transaction::new(self.client)) } From c3580774fcdc4597dac81e1128ef8bef1e6ff3a7 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 21 Jul 2024 20:23:50 -0400 Subject: [PATCH 102/126] Release postgres-protocol v0.6.7 --- postgres-protocol/CHANGELOG.md | 17 ++++++++++++++++- postgres-protocol/Cargo.toml | 2 +- postgres-types/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md index 1c371675c..54dce91b0 100644 --- a/postgres-protocol/CHANGELOG.md +++ b/postgres-protocol/CHANGELOG.md @@ -1,6 +1,21 @@ # Change Log -## v0.6.6 -2023-08-19 +## v0.6.7 - 2024-07-21 + +### Deprecated + +* Deprecated `ErrorField::value`. + +### Added + +* Added a `Clone` implementation for `DataRowBody`. +* Added `ErrorField::value_bytes`. + +### Changed + +* Upgraded `base64`. + +## v0.6.6 - 2023-08-19 ### Added diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index a8a130495..49cf2d59c 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.6" +version = "0.6.7" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 33296db2c..984fd186f 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -31,7 +31,7 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 2e080cfb2..92f4ee696 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -54,7 +54,7 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } From 6b4566b132ca4a81c06eaf35eb63318a69360f48 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 21 Jul 2024 20:28:22 -0400 Subject: [PATCH 103/126] Release postgres-types v0.2.7 --- postgres-types/CHANGELOG.md | 8 ++++++++ postgres-types/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md index 157a2cc7d..1e5cd31d8 100644 --- a/postgres-types/CHANGELOG.md +++ b/postgres-types/CHANGELOG.md @@ -2,9 +2,17 @@ ## Unreleased +## v0.2.7 - 2024-07-21 + +### Added + +* Added `Default` implementation for `Json`. +* Added a `js` feature for WASM compatibility. + ### Changed * `FromStr` implementation for `PgLsn` no longer allocates a `Vec` when splitting an lsn string on it's `/`. +* The `eui48-1` feature no longer enables default features of the `eui48` library. ## v0.2.6 - 2023-08-19 diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 984fd186f..e2d21b358 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.6" +version = "0.2.7" authors = ["Steven Fackler "] edition = "2018" license = "MIT OR Apache-2.0" diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 92f4ee696..f762b1184 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -55,7 +55,7 @@ percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } -postgres-types = { version = "0.2.5", path = "../postgres-types" } +postgres-types = { version = "0.2.7", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" From 92266188e8fd081be8e29d425b9fd334d2039196 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 21 Jul 2024 20:36:18 -0400 Subject: [PATCH 104/126] Release tokio-postgres v0.7.11 --- postgres-native-tls/Cargo.toml | 2 +- postgres-openssl/Cargo.toml | 2 +- postgres/Cargo.toml | 2 +- tokio-postgres/CHANGELOG.md | 24 +++++++++++++++++++++--- tokio-postgres/Cargo.toml | 2 +- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index 936eeeaa4..6c17d0889 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -19,7 +19,7 @@ runtime = ["tokio-postgres/runtime"] native-tls = "0.2" tokio = "1.0" tokio-native-tls = "0.3" -tokio-postgres = { version = "0.7.0", path = "../tokio-postgres", default-features = false } +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false } [dev-dependencies] futures-util = "0.3" diff --git a/postgres-openssl/Cargo.toml b/postgres-openssl/Cargo.toml index b7ebd3385..7c19070bf 100644 --- a/postgres-openssl/Cargo.toml +++ b/postgres-openssl/Cargo.toml @@ -19,7 +19,7 @@ runtime = ["tokio-postgres/runtime"] openssl = "0.10" tokio = "1.0" tokio-openssl = "0.6" -tokio-postgres = { version = "0.7.0", path = "../tokio-postgres", default-features = false } +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false } [dev-dependencies] futures-util = "0.3" diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 2ff3c875e..f1dc3c685 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -40,7 +40,7 @@ bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } log = "0.4" -tokio-postgres = { version = "0.7.10", path = "../tokio-postgres" } +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres" } tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 775c22e34..e0be26296 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -2,10 +2,28 @@ ## Unreleased +## v0.7.11 - 2024-07-21 + +### Fixed + +* Fixed handling of non-UTF8 error fields which can be sent after failed handshakes. +* Fixed cancellation handling of `TransactionBuilder::start` futures. + +### Added + +* Added `table_oid` and `field_id` fields to `Columns` struct of prepared statements. +* Added `GenericClient::simple_query`. +* Added `#[track_caller]` to `Row::get` and `SimpleQueryRow::get`. +* Added `TargetSessionAttrs::ReadOnly`. +* Added `Debug` implementation for `Statement`. +* Added `Clone` implementation for `Row`. +* Added `SimpleQueryMessage::RowDescription`. +* Added `{Client, Transaction, GenericClient}::query_typed`. + +### Changed + * Disable `rustc-serialize` compatibility of `eui48-1` dependency -* Remove tests for `eui48-04` -* Add `table_oid` and `field_id` fields to `Columns` struct of prepared statements. -* Add `GenericClient::simple_query`. +* Config setters now take `impl Into`. ## v0.7.10 - 2023-08-25 diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index f762b1184..c2f80dc7e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.10" +version = "0.7.11" authors = ["Steven Fackler "] edition = "2018" license = "MIT OR Apache-2.0" From 9f196e7f5ba6067efe55f758d743cdfd9b606cff Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 21 Jul 2024 20:38:52 -0400 Subject: [PATCH 105/126] Release postgres v0.19.8 --- postgres-native-tls/Cargo.toml | 2 +- postgres-openssl/Cargo.toml | 2 +- postgres/CHANGELOG.md | 6 ++++++ postgres/Cargo.toml | 2 +- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index 6c17d0889..02259b3dc 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -24,4 +24,4 @@ tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-featu [dev-dependencies] futures-util = "0.3" tokio = { version = "1.0", features = ["macros", "net", "rt"] } -postgres = { version = "0.19.0", path = "../postgres" } +postgres = { version = "0.19.8", path = "../postgres" } diff --git a/postgres-openssl/Cargo.toml b/postgres-openssl/Cargo.toml index 7c19070bf..9013384a2 100644 --- a/postgres-openssl/Cargo.toml +++ b/postgres-openssl/Cargo.toml @@ -24,4 +24,4 @@ tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-featu [dev-dependencies] futures-util = "0.3" tokio = { version = "1.0", features = ["macros", "net", "rt"] } -postgres = { version = "0.19.0", path = "../postgres" } +postgres = { version = "0.19.8", path = "../postgres" } diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index 7f856b5ac..258cdb518 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.19.8 - 2024-07-21 + +### Added + +* Added `{Client, Transaction, GenericClient}::query_typed`. + ## v0.19.7 - 2023-08-25 ## Fixed diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index f1dc3c685..ff95c4f14 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.7" +version = "0.19.8" authors = ["Steven Fackler "] edition = "2018" license = "MIT OR Apache-2.0" From 0fc4005ed31e3705a04cb7e58eb220d89b922dd0 Mon Sep 17 00:00:00 2001 From: Ramnivas Laddad Date: Mon, 22 Jul 2024 15:07:44 -0700 Subject: [PATCH 106/126] For `query_typed`, deal with the no-data case. If a query returns no data, we receive `Message::NoData`, which signals the completion of the query. However, we treated it as a no-op, leading to processing other messages and eventual failure. This PR fixes the issue and updates the `query_typed` tests to cover this scenario. --- tokio-postgres/src/query.rs | 13 +++++++++---- tokio-postgres/tests/test/main.rs | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index be42d66b6..3ab002871 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -89,10 +89,15 @@ where loop { match responses.next().await? { - Message::ParseComplete - | Message::BindComplete - | Message::ParameterDescription(_) - | Message::NoData => {} + Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {} + Message::NoData => { + return Ok(RowStream { + statement: Statement::unnamed(vec![], vec![]), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } Message::RowDescription(row_description) => { let mut columns: Vec = vec![]; let mut it = row_description.fields(); diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 84c46d101..9a6aa26fe 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -997,6 +997,13 @@ async fn query_typed_no_transaction() { assert_eq!(second_row.get::<_, i32>(1), 40); assert_eq!(second_row.get::<_, &str>(2), "literal"); assert_eq!(second_row.get::<_, i32>(3), 5); + + // Test for UPDATE that returns no data + let updated_rows = client + .query_typed("UPDATE foo set age = 33", &[]) + .await + .unwrap(); + assert_eq!(updated_rows.len(), 0); } #[tokio::test] @@ -1064,4 +1071,11 @@ async fn query_typed_with_transaction() { assert_eq!(second_row.get::<_, i32>(1), 40); assert_eq!(second_row.get::<_, &str>(2), "literal"); assert_eq!(second_row.get::<_, i32>(3), 5); + + // Test for UPDATE that returns no data + let updated_rows = transaction + .query_typed("UPDATE foo set age = 33", &[]) + .await + .unwrap(); + assert_eq!(updated_rows.len(), 0); } From aa10f0d75cb23757c9a87fe58363e4e26ae19d1e Mon Sep 17 00:00:00 2001 From: Qiu Chaofan Date: Tue, 23 Jul 2024 13:36:51 +0800 Subject: [PATCH 107/126] Support AIX keepalive --- tokio-postgres/src/keepalive.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index c409eb0ea..7bdd76341 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -12,12 +12,18 @@ impl From<&KeepaliveConfig> for TcpKeepalive { fn from(keepalive_config: &KeepaliveConfig) -> Self { let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "openbsd")))] + #[cfg(not(any( + target_os = "aix", + target_os = "redox", + target_os = "solaris", + target_os = "openbsd" + )))] if let Some(interval) = keepalive_config.interval { tcp_keepalive = tcp_keepalive.with_interval(interval); } #[cfg(not(any( + target_os = "aix", target_os = "redox", target_os = "solaris", target_os = "windows", From 6e68c100ca6ecf65fff03b0013ffb7850e66a86e Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Wed, 21 Aug 2024 16:12:27 +0300 Subject: [PATCH 108/126] revert fork patches --- .github/workflows/ci.yml | 4 +- README.md | 64 +- docker/sql_setup.sh | 2 - postgres-protocol/src/message/backend.rs | 765 +------------------ postgres-protocol/src/types/test.rs | 12 +- postgres-types/src/chrono_04.rs | 4 +- postgres/src/config.rs | 78 +- tokio-postgres/Cargo.toml | 3 +- tokio-postgres/src/cancel_query.rs | 16 +- tokio-postgres/src/cancel_query_raw.rs | 3 +- tokio-postgres/src/cancel_token.rs | 1 - tokio-postgres/src/client.rs | 50 +- tokio-postgres/src/config.rs | 279 +------ tokio-postgres/src/connect.rs | 126 +-- tokio-postgres/src/connect_raw.rs | 11 +- tokio-postgres/src/connect_socket.rs | 65 +- tokio-postgres/src/connect_tls.rs | 16 +- tokio-postgres/src/connection.rs | 20 - tokio-postgres/src/copy_both.rs | 248 ------ tokio-postgres/src/copy_in.rs | 38 +- tokio-postgres/src/copy_out.rs | 25 +- tokio-postgres/src/lib.rs | 3 - tokio-postgres/src/replication.rs | 173 ----- tokio-postgres/src/simple_query.rs | 2 +- tokio-postgres/tests/test/main.rs | 6 +- tokio-postgres/tests/test/replication.rs | 149 ---- tokio-postgres/tests/test/runtime.rs | 52 -- tokio-postgres/tests/test/types/chrono_04.rs | 12 +- 28 files changed, 166 insertions(+), 2061 deletions(-) delete mode 100644 tokio-postgres/src/copy_both.rs delete mode 100644 tokio-postgres/src/replication.rs delete mode 100644 tokio-postgres/tests/test/replication.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0ae551dc..8044b2f47 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master - + clippy: name: clippy runs-on: ubuntu-latest @@ -55,7 +55,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.77.0 + version: 1.64.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 diff --git a/README.md b/README.md index 7c2cc04f8..b81a6716f 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,46 @@ -# Materialize fork of Rust-Postgres +# Rust-Postgres -This repo serves as a staging area for Materialize patches to the -[rust-postgres] client before they are accepted upstream. +PostgreSQL support for Rust. -There are no releases from this fork. The [MaterializeInc/materialize] -repository simply pins a recent commit from the `master` branch. Other projects -are welcome to do the same. The `master` branch is never force pushed. Upstream -changes are periodically into `master` via `git merge`. +## postgres [![Latest Version](https://img.shields.io/crates/v/postgres.svg)](https://crates.io/crates/postgres) -## Adding a new patch +[Documentation](https://docs.rs/postgres) -Develop your patch against the master branch of the upstream [rust-postgres] -project. Open a PR with your changes. If your PR is not merged quickly, open the -same PR against this repository and request a review from a Materialize -engineer. +A native, synchronous PostgreSQL client. -The long-term goal is to get every patch merged upstream. +## tokio-postgres [![Latest Version](https://img.shields.io/crates/v/tokio-postgres.svg)](https://crates.io/crates/tokio-postgres) -## Integrating upstream changes +[Documentation](https://docs.rs/tokio-postgres) -```shell -git clone https://github.com/MaterializeInc/rust-postgres.git -git remote add upstream https://github.com/sfackler/rust-postgres.git -git checkout master -git pull -git checkout -b integrate-upstream -git fetch upstream -git merge upstream/master -# Resolve any conflicts, then open a PR against this repository with the merge commit. -``` +A native, asynchronous PostgreSQL client. -[rust-postgres]: https://github.com/sfackler/rust-postgres -[MaterializeInc/materialize]: https://github.com/MaterializeInc/materialize +## postgres-types [![Latest Version](https://img.shields.io/crates/v/postgres-types.svg)](https://crates.io/crates/postgres-types) + +[Documentation](https://docs.rs/postgres-types) + +Conversions between Rust and Postgres types. + +## postgres-native-tls [![Latest Version](https://img.shields.io/crates/v/postgres-native-tls.svg)](https://crates.io/crates/postgres-native-tls) + +[Documentation](https://docs.rs/postgres-native-tls) + +TLS support for postgres and tokio-postgres via native-tls. + +## postgres-openssl [![Latest Version](https://img.shields.io/crates/v/postgres-openssl.svg)](https://crates.io/crates/postgres-openssl) + +[Documentation](https://docs.rs/postgres-openssl) + +TLS support for postgres and tokio-postgres via openssl. + +# Running test suite + +The test suite requires postgres to be running in the correct configuration. The easiest way to do this is with docker: + +1. Install `docker` and `docker-compose`. + 1. On ubuntu: `sudo apt install docker.io docker-compose`. +1. Make sure your user has permissions for docker. + 1. On ubuntu: ``sudo usermod -aG docker $USER`` +1. Change to top-level directory of `rust-postgres` repo. +1. Run `docker-compose up -d`. +1. Run `cargo test`. +1. Run `docker-compose stop`. diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 051a12000..0315ac805 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,7 +64,6 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' -wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -83,7 +82,6 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust -host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all postgres ::0/0 trust # Unix socket connections: diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index fcfcfd260..1b5be1098 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -9,9 +9,8 @@ use std::io::{self, Read}; use std::ops::Range; use std::str; -use crate::{Lsn, Oid}; +use crate::Oid; -// top-level message tags pub const PARSE_COMPLETE_TAG: u8 = b'1'; pub const BIND_COMPLETE_TAG: u8 = b'2'; pub const CLOSE_COMPLETE_TAG: u8 = b'3'; @@ -23,7 +22,6 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; -pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -35,33 +33,6 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; pub const ROW_DESCRIPTION_TAG: u8 = b'T'; pub const READY_FOR_QUERY_TAG: u8 = b'Z'; -// replication message tags -pub const XLOG_DATA_TAG: u8 = b'w'; -pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; - -// logical replication message tags -const BEGIN_TAG: u8 = b'B'; -const COMMIT_TAG: u8 = b'C'; -const ORIGIN_TAG: u8 = b'O'; -const RELATION_TAG: u8 = b'R'; -const TYPE_TAG: u8 = b'Y'; -const INSERT_TAG: u8 = b'I'; -const UPDATE_TAG: u8 = b'U'; -const DELETE_TAG: u8 = b'D'; -const TRUNCATE_TAG: u8 = b'T'; -const TUPLE_NEW_TAG: u8 = b'N'; -const TUPLE_KEY_TAG: u8 = b'K'; -const TUPLE_OLD_TAG: u8 = b'O'; -const TUPLE_DATA_NULL_TAG: u8 = b'n'; -const TUPLE_DATA_TOAST_TAG: u8 = b'u'; -const TUPLE_DATA_TEXT_TAG: u8 = b't'; - -// replica identity tags -const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; -const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; -const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; -const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; - #[derive(Debug, Copy, Clone)] pub struct Header { tag: u8, @@ -122,7 +93,6 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), - CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -220,16 +190,6 @@ impl Message { storage, }) } - COPY_BOTH_RESPONSE_TAG => { - let format = buf.read_u8()?; - let len = buf.read_u16::()?; - let storage = buf.read_all(); - Message::CopyBothResponse(CopyBothResponseBody { - format, - len, - storage, - }) - } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -318,59 +278,6 @@ impl Message { } } -/// An enum representing Postgres backend replication messages. -#[non_exhaustive] -#[derive(Debug)] -pub enum ReplicationMessage { - XLogData(XLogDataBody), - PrimaryKeepAlive(PrimaryKeepAliveBody), -} - -impl ReplicationMessage { - #[inline] - pub fn parse(buf: &Bytes) -> io::Result { - let mut buf = Buffer { - bytes: buf.clone(), - idx: 0, - }; - - let tag = buf.read_u8()?; - - let replication_message = match tag { - XLOG_DATA_TAG => { - let wal_start = buf.read_u64::()?; - let wal_end = buf.read_u64::()?; - let timestamp = buf.read_i64::()?; - let data = buf.read_all(); - ReplicationMessage::XLogData(XLogDataBody { - wal_start, - wal_end, - timestamp, - data, - }) - } - PRIMARY_KEEPALIVE_TAG => { - let wal_end = buf.read_u64::()?; - let timestamp = buf.read_i64::()?; - let reply = buf.read_u8()?; - ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody { - wal_end, - timestamp, - reply, - }) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(replication_message) - } -} - struct Buffer { bytes: Bytes, idx: usize, @@ -617,27 +524,6 @@ impl CopyOutResponseBody { } } -pub struct CopyBothResponseBody { - storage: Bytes, - len: u16, - format: u8, -} - -impl CopyBothResponseBody { - #[inline] - pub fn format(&self) -> u8 { - self.format - } - - #[inline] - pub fn column_formats(&self) -> ColumnFormats<'_> { - ColumnFormats { - remaining: self.len, - buf: &self.storage, - } - } -} - #[derive(Debug)] pub struct DataRowBody { storage: Bytes, @@ -896,655 +782,6 @@ impl RowDescriptionBody { } } -#[derive(Debug)] -pub struct XLogDataBody { - wal_start: u64, - wal_end: u64, - timestamp: i64, - data: D, -} - -impl XLogDataBody { - #[inline] - pub fn wal_start(&self) -> u64 { - self.wal_start - } - - #[inline] - pub fn wal_end(&self) -> u64 { - self.wal_end - } - - #[inline] - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - pub fn data(&self) -> &D { - &self.data - } - - #[inline] - pub fn into_data(self) -> D { - self.data - } - - pub fn map_data(self, f: F) -> Result, E> - where - F: Fn(D) -> Result, - { - let data = f(self.data)?; - Ok(XLogDataBody { - wal_start: self.wal_start, - wal_end: self.wal_end, - timestamp: self.timestamp, - data, - }) - } -} - -#[derive(Debug)] -pub struct PrimaryKeepAliveBody { - wal_end: u64, - timestamp: i64, - reply: u8, -} - -impl PrimaryKeepAliveBody { - #[inline] - pub fn wal_end(&self) -> u64 { - self.wal_end - } - - #[inline] - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - pub fn reply(&self) -> u8 { - self.reply - } -} - -#[non_exhaustive] -/// A message of the logical replication stream -#[derive(Debug)] -pub enum LogicalReplicationMessage { - /// A BEGIN statement - Begin(BeginBody), - /// A BEGIN statement - Commit(CommitBody), - /// An Origin replication message - /// Note that there can be multiple Origin messages inside a single transaction. - Origin(OriginBody), - /// A Relation replication message - Relation(RelationBody), - /// A Type replication message - Type(TypeBody), - /// An INSERT statement - Insert(InsertBody), - /// An UPDATE statement - Update(UpdateBody), - /// A DELETE statement - Delete(DeleteBody), - /// A TRUNCATE statement - Truncate(TruncateBody), -} - -impl LogicalReplicationMessage { - pub fn parse(buf: &Bytes) -> io::Result { - let mut buf = Buffer { - bytes: buf.clone(), - idx: 0, - }; - - let tag = buf.read_u8()?; - - let logical_replication_message = match tag { - BEGIN_TAG => Self::Begin(BeginBody { - final_lsn: buf.read_u64::()?, - timestamp: buf.read_i64::()?, - xid: buf.read_u32::()?, - }), - COMMIT_TAG => Self::Commit(CommitBody { - flags: buf.read_i8()?, - commit_lsn: buf.read_u64::()?, - end_lsn: buf.read_u64::()?, - timestamp: buf.read_i64::()?, - }), - ORIGIN_TAG => Self::Origin(OriginBody { - commit_lsn: buf.read_u64::()?, - name: buf.read_cstr()?, - }), - RELATION_TAG => { - let rel_id = buf.read_u32::()?; - let namespace = buf.read_cstr()?; - let name = buf.read_cstr()?; - let replica_identity = match buf.read_u8()? { - REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, - REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, - REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, - REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replica identity tag `{}`", tag), - )); - } - }; - let column_len = buf.read_i16::()?; - - let mut columns = Vec::with_capacity(column_len as usize); - for _ in 0..column_len { - columns.push(Column::parse(&mut buf)?); - } - - Self::Relation(RelationBody { - rel_id, - namespace, - name, - replica_identity, - columns, - }) - } - TYPE_TAG => Self::Type(TypeBody { - id: buf.read_u32::()?, - namespace: buf.read_cstr()?, - name: buf.read_cstr()?, - }), - INSERT_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let tuple = match tag { - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unexpected tuple tag `{}`", tag), - )); - } - }; - - Self::Insert(InsertBody { rel_id, tuple }) - } - UPDATE_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let mut key_tuple = None; - let mut old_tuple = None; - - let new_tuple = match tag { - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - TUPLE_OLD_TAG | TUPLE_KEY_TAG => { - if tag == TUPLE_OLD_TAG { - old_tuple = Some(Tuple::parse(&mut buf)?); - } else { - key_tuple = Some(Tuple::parse(&mut buf)?); - } - - match buf.read_u8()? { - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unexpected tuple tag `{}`", tag), - )); - } - } - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown tuple tag `{}`", tag), - )); - } - }; - - Self::Update(UpdateBody { - rel_id, - key_tuple, - old_tuple, - new_tuple, - }) - } - DELETE_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let mut key_tuple = None; - let mut old_tuple = None; - - match tag { - TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), - TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown tuple tag `{}`", tag), - )); - } - } - - Self::Delete(DeleteBody { - rel_id, - key_tuple, - old_tuple, - }) - } - TRUNCATE_TAG => { - let relation_len = buf.read_i32::()?; - let options = buf.read_i8()?; - - let mut rel_ids = Vec::with_capacity(relation_len as usize); - for _ in 0..relation_len { - rel_ids.push(buf.read_u32::()?); - } - - Self::Truncate(TruncateBody { options, rel_ids }) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(logical_replication_message) - } -} - -/// A row as it appears in the replication stream -#[derive(Debug)] -pub struct Tuple(Vec); - -impl Tuple { - #[inline] - /// The tuple data of this tuple - pub fn tuple_data(&self) -> &[TupleData] { - &self.0 - } -} - -impl Tuple { - fn parse(buf: &mut Buffer) -> io::Result { - let col_len = buf.read_i16::()?; - let mut tuple = Vec::with_capacity(col_len as usize); - for _ in 0..col_len { - tuple.push(TupleData::parse(buf)?); - } - - Ok(Tuple(tuple)) - } -} - -/// A column as it appears in the replication stream -#[derive(Debug)] -pub struct Column { - flags: i8, - name: Bytes, - type_id: i32, - type_modifier: i32, -} - -impl Column { - #[inline] - /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as - /// part of the key. - pub fn flags(&self) -> i8 { - self.flags - } - - #[inline] - /// Name of the column. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } - - #[inline] - /// ID of the column's data type. - pub fn type_id(&self) -> i32 { - self.type_id - } - - #[inline] - /// Type modifier of the column (`atttypmod`). - pub fn type_modifier(&self) -> i32 { - self.type_modifier - } -} - -impl Column { - fn parse(buf: &mut Buffer) -> io::Result { - Ok(Self { - flags: buf.read_i8()?, - name: buf.read_cstr()?, - type_id: buf.read_i32::()?, - type_modifier: buf.read_i32::()?, - }) - } -} - -/// The data of an individual column as it appears in the replication stream -#[derive(Debug)] -pub enum TupleData { - /// Represents a NULL value - Null, - /// Represents an unchanged TOASTed value (the actual value is not sent). - UnchangedToast, - /// Column data as text formatted value. - Text(Bytes), -} - -impl TupleData { - fn parse(buf: &mut Buffer) -> io::Result { - let type_tag = buf.read_u8()?; - - let tuple = match type_tag { - TUPLE_DATA_NULL_TAG => TupleData::Null, - TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, - TUPLE_DATA_TEXT_TAG => { - let len = buf.read_i32::()?; - let mut data = vec![0; len as usize]; - buf.read_exact(&mut data)?; - TupleData::Text(data.into()) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(tuple) - } -} - -/// A BEGIN statement -#[derive(Debug)] -pub struct BeginBody { - final_lsn: u64, - timestamp: i64, - xid: u32, -} - -impl BeginBody { - #[inline] - /// Gets the final lsn of the transaction - pub fn final_lsn(&self) -> Lsn { - self.final_lsn - } - - #[inline] - /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - /// Xid of the transaction. - pub fn xid(&self) -> u32 { - self.xid - } -} - -/// A COMMIT statement -#[derive(Debug)] -pub struct CommitBody { - flags: i8, - commit_lsn: u64, - end_lsn: u64, - timestamp: i64, -} - -impl CommitBody { - #[inline] - /// The LSN of the commit. - pub fn commit_lsn(&self) -> Lsn { - self.commit_lsn - } - - #[inline] - /// The end LSN of the transaction. - pub fn end_lsn(&self) -> Lsn { - self.end_lsn - } - - #[inline] - /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - /// Flags; currently unused (will be 0). - pub fn flags(&self) -> i8 { - self.flags - } -} - -/// An Origin replication message -/// -/// Note that there can be multiple Origin messages inside a single transaction. -#[derive(Debug)] -pub struct OriginBody { - commit_lsn: u64, - name: Bytes, -} - -impl OriginBody { - #[inline] - /// The LSN of the commit on the origin server. - pub fn commit_lsn(&self) -> Lsn { - self.commit_lsn - } - - #[inline] - /// Name of the origin. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } -} - -/// Describes the REPLICA IDENTITY setting of a table -#[derive(Debug)] -pub enum ReplicaIdentity { - /// default selection for replica identity (primary key or nothing) - Default, - /// no replica identity is logged for this relation - Nothing, - /// all columns are logged as replica identity - Full, - /// An explicitly chosen candidate key's columns are used as replica identity. - /// Note this will still be set if the index has been dropped; in that case it - /// has the same meaning as 'd'. - Index, -} - -/// A Relation replication message -#[derive(Debug)] -pub struct RelationBody { - rel_id: u32, - namespace: Bytes, - name: Bytes, - replica_identity: ReplicaIdentity, - columns: Vec, -} - -impl RelationBody { - #[inline] - /// ID of the relation. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// Namespace (empty string for pg_catalog). - pub fn namespace(&self) -> io::Result<&str> { - get_str(&self.namespace) - } - - #[inline] - /// Relation name. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } - - #[inline] - /// Replica identity setting for the relation - pub fn replica_identity(&self) -> &ReplicaIdentity { - &self.replica_identity - } - - #[inline] - /// The column definitions of this relation - pub fn columns(&self) -> &[Column] { - &self.columns - } -} - -/// A Type replication message -#[derive(Debug)] -pub struct TypeBody { - id: u32, - namespace: Bytes, - name: Bytes, -} - -impl TypeBody { - #[inline] - /// ID of the data type. - pub fn id(&self) -> Oid { - self.id - } - - #[inline] - /// Namespace (empty string for pg_catalog). - pub fn namespace(&self) -> io::Result<&str> { - get_str(&self.namespace) - } - - #[inline] - /// Name of the data type. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } -} - -/// An INSERT statement -#[derive(Debug)] -pub struct InsertBody { - rel_id: u32, - tuple: Tuple, -} - -impl InsertBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// The inserted tuple - pub fn tuple(&self) -> &Tuple { - &self.tuple - } -} - -/// An UPDATE statement -#[derive(Debug)] -pub struct UpdateBody { - rel_id: u32, - old_tuple: Option, - key_tuple: Option, - new_tuple: Tuple, -} - -impl UpdateBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// This field is optional and is only present if the update changed data in any of the - /// column(s) that are part of the REPLICA IDENTITY index. - pub fn key_tuple(&self) -> Option<&Tuple> { - self.key_tuple.as_ref() - } - - #[inline] - /// This field is optional and is only present if table in which the update happened has - /// REPLICA IDENTITY set to FULL. - pub fn old_tuple(&self) -> Option<&Tuple> { - self.old_tuple.as_ref() - } - - #[inline] - /// The new tuple - pub fn new_tuple(&self) -> &Tuple { - &self.new_tuple - } -} - -/// A DELETE statement -#[derive(Debug)] -pub struct DeleteBody { - rel_id: u32, - old_tuple: Option, - key_tuple: Option, -} - -impl DeleteBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// This field is present if the table in which the delete has happened uses an index as - /// REPLICA IDENTITY. - pub fn key_tuple(&self) -> Option<&Tuple> { - self.key_tuple.as_ref() - } - - #[inline] - /// This field is present if the table in which the delete has happened has REPLICA IDENTITY - /// set to FULL. - pub fn old_tuple(&self) -> Option<&Tuple> { - self.old_tuple.as_ref() - } -} - -/// A TRUNCATE statement -#[derive(Debug)] -pub struct TruncateBody { - options: i8, - rel_ids: Vec, -} - -impl TruncateBody { - #[inline] - /// The IDs of the relations corresponding to the ID in the relation messages - pub fn rel_ids(&self) -> &[u32] { - &self.rel_ids - } - - #[inline] - /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY - pub fn options(&self) -> i8 { - self.options - } -} - pub struct Fields<'a> { buf: &'a [u8], remaining: u16, diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 3e33b08f0..6f1851fc2 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -174,7 +174,7 @@ fn ltree_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(ltree_from_sql(query.as_slice()).is_ok()) + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) } #[test] @@ -182,7 +182,7 @@ fn ltree_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(ltree_from_sql(query.as_slice()).is_err()) + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) } #[test] @@ -202,7 +202,7 @@ fn lquery_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(lquery_from_sql(query.as_slice()).is_ok()) + assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) } #[test] @@ -210,7 +210,7 @@ fn lquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(lquery_from_sql(query.as_slice()).is_err()) + assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) } #[test] @@ -230,7 +230,7 @@ fn ltxtquery_str() { let mut query = vec![1u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(ltree_from_sql(query.as_slice()).is_ok()) + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) } #[test] @@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(ltree_from_sql(query.as_slice()).is_err()) + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) } diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index b7f4f9a03..0ec92437d 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime { impl<'a> FromSql<'a> for DateTime { fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(DateTime::from_naive_utc_and_offset(naive, Utc)) + Ok(DateTime::from_utc(naive, Utc)) } accepts!(TIMESTAMPTZ); @@ -111,7 +111,7 @@ impl<'a> FromSql<'a> for NaiveDate { let jd = types::date_from_sql(raw)?; base() .date() - .checked_add_signed(Duration::try_days(i64::from(jd)).unwrap()) + .checked_add_signed(Duration::days(i64::from(jd))) .ok_or_else(|| "value too large to decode".into()) } diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 2705e3593..95c5ea417 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -6,7 +6,6 @@ use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; -use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -34,29 +33,12 @@ use tokio_postgres::{Error, Socket}; /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. -/// * `sslcert` - Location of the client SSL certificate file. -/// * `sslkey` - Location for the secret key file used for the client certificate. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used -/// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to -/// be used. Defaults to `prefer`. -/// * `sslrootcert` - Location of SSL certificate authority (CA) certificate. +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. -/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, -/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. -/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. -/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. -/// Specifically: -/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. -/// The connection attempt will fail if the authentication method requires a host name; -/// * If `host` is specified without `hostaddr`, a host name lookup occurs; -/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. -/// The value for `host` is ignored unless the authentication method requires it, -/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -88,10 +70,6 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write -/// ``` -/// -/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -212,32 +190,6 @@ impl Config { self.config.get_application_name() } - /// Sets the client SSL certificate in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config { - self.config.ssl_cert(ssl_cert); - self - } - - /// Gets the location of the client SSL certificate in PEM format. - pub fn get_ssl_cert(&self) -> Option<&[u8]> { - self.config.get_ssl_cert() - } - - /// Sets the client SSL key in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config { - self.config.ssl_key(ssl_key); - self - } - - /// Gets the client SSL key in PEM format. - pub fn get_ssl_key(&self) -> Option<&[u8]> { - self.config.get_ssl_key() - } - /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -251,24 +203,10 @@ impl Config { self.config.get_ssl_mode() } - /// Sets the SSL certificate authority (CA) certificate in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config { - self.config.ssl_root_cert(ssl_root_cert); - self - } - - /// Gets the SSL certificate authority (CA) certificate in PEM format. - pub fn get_ssl_root_cert(&self) -> Option<&[u8]> { - self.config.get_ssl_root_cert() - } - /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. - /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -279,11 +217,6 @@ impl Config { self.config.get_hosts() } - /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[IpAddr] { - self.config.get_hostaddrs() - } - /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -296,15 +229,6 @@ impl Config { self } - /// Adds a hostaddr to the configuration. - /// - /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. - /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { - self.config.hostaddr(hostaddr); - self - } - /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 762caa9b0..e5451e2a2 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -55,11 +55,9 @@ pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } -serde = { version = "1.0", optional = true } socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } -rand = "0.8.5" [dev-dependencies] futures-executor = "0.3" @@ -79,6 +77,7 @@ eui48-04 = { version = "0.4", package = "eui48" } eui48-1 = { version = "1.0", package = "eui48" } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } +serde-1 = { version = "1.0", package = "serde" } serde_json-1 = { version = "1.0", package = "serde_json" } smol_str-01 = { version = "0.1", package = "smol_str" } uuid-08 = { version = "0.8", package = "uuid" } diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 078d4b8b6..d869b5824 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::SslMode; +use crate::config::{Host, SslMode}; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -24,13 +24,18 @@ where } }; + let hostname = match &config.host { + Host::Tcp(host) => &**host, + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter + #[cfg(unix)] + Host::Unix(_) => "", + }; let tls = tls - .make_tls_connect(config.hostname.as_deref().unwrap_or("")) + .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( - &config.addr, + &config.host, config.port, config.connect_timeout, config.tcp_user_timeout, @@ -38,6 +43,5 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) - .await + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index 41aafe7d9..c89dc581f 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -9,7 +9,6 @@ pub async fn cancel_query_raw( stream: S, mode: SslMode, tls: T, - has_hostname: bool, process_id: i32, secret_key: i32, ) -> Result<(), Error> @@ -17,7 +16,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index c925ce0ca..d048a3c82 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -55,7 +55,6 @@ impl CancelToken { stream, self.ssl_mode, tls, - true, self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2cc4256c4..8b7df4e87 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,7 +1,8 @@ use crate::codec::{BackendMessages, FrontendMessage}; +#[cfg(feature = "runtime")] +use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; -use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -14,9 +15,8 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, - CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, - TransactionBuilder, + copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, + Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -27,10 +27,6 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; -#[cfg(feature = "runtime")] -use std::net::IpAddr; -#[cfg(feature = "runtime")] -use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -157,22 +153,13 @@ impl InnerClient { #[cfg(feature = "runtime")] #[derive(Clone)] pub(crate) struct SocketConfig { - pub addr: Addr, - pub hostname: Option, + pub host: Host, pub port: u16, pub connect_timeout: Option, pub tcp_user_timeout: Option, pub keepalive: Option, } -#[cfg(feature = "runtime")] -#[derive(Clone)] -pub(crate) enum Addr { - Tcp(IpAddr), - #[cfg(unix)] - Unix(PathBuf), -} - /// An asynchronous PostgreSQL client. /// /// The client is one half of what is returned when a connection is established. Users interact with the database @@ -425,14 +412,6 @@ impl Client { copy_in::copy_in(self.inner(), statement).await } - /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data. - pub async fn copy_in_simple(&self, query: &str) -> Result, Error> - where - U: Buf + 'static + Send, - { - copy_in::copy_in_simple(self.inner(), query).await - } - /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. @@ -444,20 +423,6 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } - /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data. - pub async fn copy_out_simple(&self, query: &str) -> Result { - copy_out::copy_out_simple(self.inner(), query).await - } - - /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy - /// data. - pub async fn copy_both_simple(&self, query: &str) -> Result, Error> - where - T: Buf + 'static + Send, - { - copy_both::copy_both_simple(self.inner(), query).await - } - /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that @@ -544,11 +509,6 @@ impl Client { TransactionBuilder::new(self) } - /// Returns the server's process ID for the connection. - pub fn backend_pid(&self) -> i32 { - self.process_id - } - /// Constructs a cancellation token that can later be used to request cancellation of a query running on the /// connection associated with this client. pub fn cancel_token(&self) -> CancelToken { diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index e40ed3e07..a8aa7a9f5 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -13,13 +13,10 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; -use std::net::IpAddr; -use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] -use std::path::Path; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::str; use std::str::FromStr; use std::time::Duration; @@ -37,8 +34,7 @@ pub enum TargetSessionAttrs { } /// TLS configuration. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[non_exhaustive] pub enum SslMode { /// Do not use TLS. @@ -47,10 +43,6 @@ pub enum SslMode { Prefer, /// Require the use of TLS. Require, - /// Require the use of TLS. - VerifyCa, - /// Require the use of TLS. - VerifyFull, } /// Channel binding configuration. @@ -65,26 +57,6 @@ pub enum ChannelBinding { Require, } -/// Replication mode configuration. -#[derive(Debug, Copy, Clone, PartialEq)] -#[non_exhaustive] -pub enum ReplicationMode { - /// Physical replication. - Physical, - /// Logical replication. - Logical, -} - -/// Load balancing configuration. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub enum LoadBalanceHosts { - /// Make connection attempts to hosts in the order provided. - Disable, - /// Make connection attempts to hosts in a random order. - Random, -} - /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -113,32 +85,12 @@ pub enum Host { /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. -/// * `sslcert` - Location of the client SSL certificate file. -/// * `sslcert_inline` - The contents of the client SSL certificate. -/// * `sslkey` - Location for the secret key file used for the client certificate. -/// * `sslkey_inline` - The contents of the client SSL key. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used -/// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to -/// be used. Defaults to `prefer`. -/// * `sslrootcert` - Location of SSL certificate authority (CA) certificate. -/// * `sslrootcert_inline` - The contents of the SSL certificate authority. +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. -/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, -/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. -/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// or if host specifies an IP address, that value will be used directly. -/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for TLS certificate verification. -/// Specifically: -/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. -/// The connection attempt will fail if the authentication method requires a host name; -/// * If `host` is specified without `hostaddr`, a host name lookup occurs; -/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. -/// The value for `host` is ignored unless the authentication method requires it, -/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -161,12 +113,6 @@ pub enum Host { /// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel /// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. /// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. -/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and -/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter -/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to -/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried -/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults -/// to `disable`. /// /// ## Examples /// @@ -179,10 +125,6 @@ pub enum Host { /// ``` /// /// ```not_rust -/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write -/// ``` -/// -/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -210,19 +152,15 @@ pub enum Host { /// ```not_rust /// postgresql:///mydb?user=user&host=/var/lib/postgresql /// ``` -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Eq)] pub struct Config { pub(crate) user: Option, pub(crate) password: Option>, pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, - pub(crate) ssl_cert: Option>, - pub(crate) ssl_key: Option>, pub(crate) ssl_mode: SslMode, - pub(crate) ssl_root_cert: Option>, pub(crate) host: Vec, - pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) tcp_user_timeout: Option, @@ -230,8 +168,6 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, - pub(crate) replication_mode: Option, - pub(crate) load_balance_hosts: LoadBalanceHosts, } impl Default for Config { @@ -254,12 +190,8 @@ impl Config { dbname: None, options: None, application_name: None, - ssl_cert: None, - ssl_key: None, ssl_mode: SslMode::Prefer, - ssl_root_cert: None, host: vec![], - hostaddr: vec![], port: vec![], connect_timeout: None, tcp_user_timeout: None, @@ -267,8 +199,6 @@ impl Config { keepalive_config, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, - replication_mode: None, - load_balance_hosts: LoadBalanceHosts::Disable, } } @@ -339,32 +269,6 @@ impl Config { self.application_name.as_deref() } - /// Sets the client SSL certificate in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config { - self.ssl_cert = Some(ssl_cert.into()); - self - } - - /// Gets the location of the client SSL certificate in PEM format. - pub fn get_ssl_cert(&self) -> Option<&[u8]> { - self.ssl_cert.as_deref() - } - - /// Sets the client SSL key in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config { - self.ssl_key = Some(ssl_key.into()); - self - } - - /// Gets the client SSL key in PEM format. - pub fn get_ssl_key(&self) -> Option<&[u8]> { - self.ssl_key.as_deref() - } - /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -378,24 +282,10 @@ impl Config { self.ssl_mode } - /// Sets the SSL certificate authority (CA) certificate in PEM format. - /// - /// Defaults to `None`. - pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config { - self.ssl_root_cert = Some(ssl_root_cert.into()); - self - } - - /// Gets the SSL certificate authority (CA) certificate in PEM format. - pub fn get_ssl_root_cert(&self) -> Option<&[u8]> { - self.ssl_root_cert.as_deref() - } - /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. - /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { @@ -413,23 +303,6 @@ impl Config { &self.host } - /// Gets a mutable view of the hosts that have been added to the - /// configuration with `host`. - pub fn get_hosts_mut(&mut self) -> &mut [Host] { - &mut self.host - } - - /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. - pub fn get_hostaddrs(&self) -> &[IpAddr] { - self.hostaddr.deref() - } - - /// Gets a mutable view of the hostaddrs that have been added to the - /// configuration with `hostaddr`. - pub fn get_hostaddrs_mut(&mut self) -> &mut [IpAddr] { - &mut self.hostaddr - } - /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -442,15 +315,6 @@ impl Config { self } - /// Adds a hostaddr to the configuration. - /// - /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. - /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. - pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { - self.hostaddr.push(hostaddr); - self - } - /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -581,30 +445,6 @@ impl Config { self.channel_binding } - /// Set replication mode. - pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { - self.replication_mode = Some(replication_mode); - self - } - - /// Get replication mode. - pub fn get_replication_mode(&self) -> Option { - self.replication_mode - } - - /// Sets the host load balancing behavior. - /// - /// Defaults to `disable`. - pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { - self.load_balance_hosts = load_balance_hosts; - self - } - - /// Gets the host load balancing behavior. - pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { - self.load_balance_hosts - } - fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -622,63 +462,20 @@ impl Config { "application_name" => { self.application_name(value); } - "sslcert" => match std::fs::read(value) { - Ok(contents) => { - self.ssl_cert(&contents); - } - Err(_) => { - return Err(Error::config_parse(Box::new(InvalidValue("sslcert")))); - } - }, - "sslcert_inline" => { - self.ssl_cert(value.as_bytes()); - } - "sslkey" => match std::fs::read(value) { - Ok(contents) => { - self.ssl_key(&contents); - } - Err(_) => { - return Err(Error::config_parse(Box::new(InvalidValue("sslkey")))); - } - }, - "sslkey_inline" => { - self.ssl_key(value.as_bytes()); - } "sslmode" => { let mode = match value { "disable" => SslMode::Disable, "prefer" => SslMode::Prefer, "require" => SslMode::Require, - "verify-ca" => SslMode::VerifyCa, - "verify-full" => SslMode::VerifyFull, _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), }; self.ssl_mode(mode); } - "sslrootcert" => match std::fs::read(value) { - Ok(contents) => { - self.ssl_root_cert(&contents); - } - Err(_) => { - return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert")))); - } - }, - "sslrootcert_inline" => { - self.ssl_root_cert(value.as_bytes()); - } "host" => { for host in value.split(',') { self.host(host); } } - "hostaddr" => { - for hostaddr in value.split(',') { - let addr = hostaddr - .parse() - .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; - self.hostaddr(addr); - } - } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -759,29 +556,6 @@ impl Config { }; self.channel_binding(channel_binding); } - "replication" => { - let mode = match value { - "off" => None, - "true" => Some(ReplicationMode::Physical), - "database" => Some(ReplicationMode::Logical), - _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), - }; - if let Some(mode) = mode { - self.replication_mode(mode); - } - } - "load_balance_hosts" => { - let load_balance_hosts = match value { - "disable" => LoadBalanceHosts::Disable, - "random" => LoadBalanceHosts::Random, - _ => { - return Err(Error::config_parse(Box::new(InvalidValue( - "load_balance_hosts", - )))) - } - }; - self.load_balance_hosts(load_balance_hosts); - } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -815,7 +589,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, tls, true, self).await + connect_raw(stream, tls, self).await } } @@ -846,12 +620,8 @@ impl fmt::Debug for Config { .field("dbname", &self.dbname) .field("options", &self.options) .field("application_name", &self.application_name) - .field("ssl_cert", &self.ssl_cert) - .field("ssl_key", &self.ssl_key) .field("ssl_mode", &self.ssl_mode) - .field("ssl_root_cert", &self.ssl_root_cert) .field("host", &self.host) - .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) @@ -861,7 +631,6 @@ impl fmt::Debug for Config { .field("keepalives_retries", &self.keepalive_config.retries) .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) - .field("replication", &self.replication_mode) .finish() } } @@ -1236,41 +1005,3 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } - -#[cfg(test)] -mod tests { - use std::net::IpAddr; - - use crate::{config::Host, Config}; - - #[test] - fn test_simple_parsing() { - let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; - let config = s.parse::().unwrap(); - assert_eq!(Some("pass_user"), config.get_user()); - assert_eq!(Some("postgres"), config.get_dbname()); - assert_eq!( - [ - Host::Tcp("host1".to_string()), - Host::Tcp("host2".to_string()) - ], - config.get_hosts(), - ); - - assert_eq!( - [ - "127.0.0.1".parse::().unwrap(), - "127.0.0.2".parse::().unwrap() - ], - config.get_hostaddrs(), - ); - - assert_eq!(1, 1); - } - - #[test] - fn test_invalid_hostaddr_parsing() { - let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; - s.parse::().err().unwrap(); - } -} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..ed7ecac66 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1,14 +1,12 @@ -use crate::client::{Addr, SocketConfig}; -use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; +use crate::client::SocketConfig; +use crate::config::{Host, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; -use crate::tls::MakeTlsConnect; +use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use rand::seq::SliceRandom; +use std::io; use std::task::Poll; -use std::{cmp, io}; -use tokio::net; pub async fn connect( mut tls: T, @@ -17,40 +15,16 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() && config.hostaddr.is_empty() { - return Err(Error::config("both host and hostaddr are missing".into())); + if config.host.is_empty() { + return Err(Error::config("host missing".into())); } - if !config.host.is_empty() - && !config.hostaddr.is_empty() - && config.host.len() != config.hostaddr.len() - { - let msg = format!( - "number of hosts ({}) is different from number of hostaddrs ({})", - config.host.len(), - config.hostaddr.len(), - ); - return Err(Error::config(msg.into())); - } - - // At this point, either one of the following two scenarios could happen: - // (1) either config.host or config.hostaddr must be empty; - // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. - let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); - - if config.port.len() > 1 && config.port.len() != num_hosts { + if config.port.len() > 1 && config.port.len() != config.host.len() { return Err(Error::config("invalid number of ports".into())); } - let mut indices = (0..num_hosts).collect::>(); - if config.load_balance_hosts == LoadBalanceHosts::Random { - indices.shuffle(&mut rand::thread_rng()); - } - let mut error = None; - for i in indices { - let host = config.host.get(i); - let hostaddr = config.hostaddr.get(i); + for (i, host) in config.host.iter().enumerate() { let port = config .port .get(i) @@ -58,23 +32,18 @@ where .copied() .unwrap_or(5432); - // The value of host is used as the hostname for TLS validation, let hostname = match host { - Some(Host::Tcp(host)) => Some(host.clone()), + Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] - Some(Host::Unix(_)) => None, - None => None, + Host::Unix(_) => "", }; - // Try to use the value of hostaddr to establish the TCP connection, - // fallback to host if hostaddr is not present. - let addr = match hostaddr { - Some(ipaddr) => Host::Tcp(ipaddr.to_string()), - None => host.cloned().unwrap(), - }; + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; - match connect_host(addr, hostname, port, &mut tls, config).await { + match connect_once(host, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -83,66 +52,17 @@ where Err(error.unwrap()) } -async fn connect_host( - host: Host, - hostname: Option, - port: u16, - tls: &mut T, - config: &Config, -) -> Result<(Client, Connection), Error> -where - T: MakeTlsConnect, -{ - match host { - Host::Tcp(host) => { - let mut addrs = net::lookup_host((&*host, port)) - .await - .map_err(Error::connect)? - .collect::>(); - - if config.load_balance_hosts == LoadBalanceHosts::Random { - addrs.shuffle(&mut rand::thread_rng()); - } - - let mut last_err = None; - for addr in addrs { - match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) - .await - { - Ok(stream) => return Ok(stream), - Err(e) => { - last_err = Some(e); - continue; - } - }; - } - - Err(last_err.unwrap_or_else(|| { - Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })) - } - #[cfg(unix)] - Host::Unix(path) => { - connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await - } - } -} - async fn connect_once( - addr: Addr, - hostname: Option<&str>, + host: &Host, port: u16, - tls: &mut T, + tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where - T: MakeTlsConnect, + T: TlsConnect, { let socket = connect_socket( - &addr, + host, port, config.connect_timeout, config.tcp_user_timeout, @@ -153,12 +73,7 @@ where }, ) .await?; - - let tls = tls - .make_tls_connect(hostname.unwrap_or("")) - .map_err(|e| Error::tls(e.into()))?; - let has_hostname = hostname.is_some(); - let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; + let (mut client, mut connection) = connect_raw(socket, tls, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); @@ -201,8 +116,7 @@ where } client.set_socket_config(SocketConfig { - addr, - hostname: hostname.map(|s| s.to_string()), + host: host.clone(), port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 1348828ba..d97636221 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config, ReplicationMode}; +use crate::config::{self, Config}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -81,14 +81,13 @@ where pub async fn connect_raw( stream: S, tls: T, - has_hostname: bool, config: &Config, ) -> Result<(Client, Connection), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; + let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), @@ -125,12 +124,6 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } - if let Some(replication_mode) = &config.replication_mode { - match replication_mode { - ReplicationMode::Physical => params.push(("replication", "true")), - ReplicationMode::Logical => params.push(("replication", "database")), - } - } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 67add04ea..9b3d31d72 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,48 +1,69 @@ -use crate::client::Addr; +use crate::config::Host; use crate::keepalive::KeepaliveConfig; use crate::{Error, Socket}; use socket2::{SockRef, TcpKeepalive}; use std::future::Future; use std::io; use std::time::Duration; -use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +use tokio::net::{self, TcpStream}; use tokio::time; pub(crate) async fn connect_socket( - addr: &Addr, + host: &Host, port: u16, connect_timeout: Option, tcp_user_timeout: Option, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { - match addr { - Addr::Tcp(ip) => { - let stream = - connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; + match host { + Host::Tcp(host) => { + let addrs = net::lookup_host((&**host, port)) + .await + .map_err(Error::connect)?; - stream.set_nodelay(true).map_err(Error::connect)?; + let mut last_err = None; - let sock_ref = SockRef::from(&stream); - #[cfg(target_os = "linux")] - { - sock_ref - .set_tcp_user_timeout(tcp_user_timeout) - .map_err(Error::connect)?; - } + for addr in addrs { + let stream = + match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { + Ok(stream) => stream, + Err(e) => { + last_err = Some(e); + continue; + } + }; + + stream.set_nodelay(true).map_err(Error::connect)?; + + let sock_ref = SockRef::from(&stream); + #[cfg(target_os = "linux")] + { + sock_ref + .set_tcp_user_timeout(tcp_user_timeout) + .map_err(Error::connect)?; + } + + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; + } - if let Some(keepalive_config) = keepalive_config { - sock_ref - .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) - .map_err(Error::connect)?; + return Ok(Socket::new_tcp(stream)); } - Ok(Socket::new_tcp(stream)) + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) } #[cfg(unix)] - Addr::Unix(dir) => { - let path = dir.join(format!(".s.PGSQL.{}", port)); + Host::Unix(path) => { + let path = path.join(format!(".s.PGSQL.{}", port)); let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; Ok(Socket::new_unix(socket)) } diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 41b319c2b..5ef21ac5c 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -11,7 +11,6 @@ pub async fn connect_tls( mut stream: S, mode: SslMode, tls: T, - has_hostname: bool, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -22,7 +21,7 @@ where SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { return Ok(MaybeTlsStream::Raw(stream)) } - SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {} + SslMode::Prefer | SslMode::Require => {} } let mut buf = BytesMut::new(); @@ -33,18 +32,13 @@ where stream.read_exact(&mut buf).await.map_err(Error::io)?; if buf[0] != b'S' { - match mode { - SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => { - return Err(Error::tls("server does not support TLS".into())) - } - SslMode::Disable | SslMode::Prefer => return Ok(MaybeTlsStream::Raw(stream)), + if SslMode::Require == mode { + return Err(Error::tls("server does not support TLS".into())); + } else { + return Ok(MaybeTlsStream::Raw(stream)); } } - if !has_hostname { - return Err(Error::tls("no hostname provided for TLS handshake".into())); - } - let stream = tls .connect(stream) .await diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index a3449f88b..414335955 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,5 +1,4 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -21,7 +20,6 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), - CopyBoth(CopyBothReceiver), } pub struct Request { @@ -260,24 +258,6 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } - RequestMessages::CopyBoth(mut receiver) => { - let message = match receiver.poll_next_unpin(cx) { - Poll::Ready(Some(message)) => message, - Poll::Ready(None) => { - trace!("poll_write: finished copy_both request"); - continue; - } - Poll::Pending => { - trace!("poll_write: waiting on copy_both stream"); - self.pending_request = Some(RequestMessages::CopyBoth(receiver)); - return Ok(true); - } - }; - Pin::new(&mut self.stream) - .start_send(message) - .map_err(Error::io)?; - self.pending_request = Some(RequestMessages::CopyBoth(receiver)); - } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs deleted file mode 100644 index 79a7be34a..000000000 --- a/tokio-postgres/src/copy_both.rs +++ /dev/null @@ -1,248 +0,0 @@ -use crate::client::{InnerClient, Responses}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; -use crate::{simple_query, Error}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures_channel::mpsc; -use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; -use log::debug; -use pin_project_lite::pin_project; -use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; -use postgres_protocol::message::frontend::CopyData; -use std::marker::{PhantomData, PhantomPinned}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pub(crate) enum CopyBothMessage { - Message(FrontendMessage), - Done, -} - -pub struct CopyBothReceiver { - receiver: mpsc::Receiver, - done: bool, -} - -impl CopyBothReceiver { - pub(crate) fn new(receiver: mpsc::Receiver) -> CopyBothReceiver { - CopyBothReceiver { - receiver, - done: false, - } - } -} - -impl Stream for CopyBothReceiver { - type Item = FrontendMessage; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.done { - return Poll::Ready(None); - } - - match ready!(self.receiver.poll_next_unpin(cx)) { - Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)), - Some(CopyBothMessage::Done) => { - self.done = true; - let mut buf = BytesMut::new(); - frontend::copy_done(&mut buf); - frontend::sync(&mut buf); - Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) - } - None => { - self.done = true; - let mut buf = BytesMut::new(); - frontend::copy_fail("", &mut buf).unwrap(); - frontend::sync(&mut buf); - Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) - } - } - } -} - -enum SinkState { - Active, - Closing, - Reading, -} - -pin_project! { - /// A sink for `COPY ... FROM STDIN` query data. - /// - /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is - /// not, the copy will be aborted. - pub struct CopyBothDuplex { - #[pin] - sender: mpsc::Sender, - responses: Responses, - buf: BytesMut, - state: SinkState, - #[pin] - _p: PhantomPinned, - _p2: PhantomData, - } -} - -impl CopyBothDuplex -where - T: Buf + 'static + Send, -{ - pub(crate) fn new(sender: mpsc::Sender, responses: Responses) -> Self { - Self { - sender, - responses, - buf: BytesMut::new(), - state: SinkState::Active, - _p: PhantomPinned, - _p2: PhantomData, - } - } - - /// A poll-based version of `finish`. - pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match self.state { - SinkState::Active => { - ready!(self.as_mut().poll_flush(cx))?; - let mut this = self.as_mut().project(); - ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; - this.sender - .start_send(CopyBothMessage::Done) - .map_err(|_| Error::closed())?; - *this.state = SinkState::Closing; - } - SinkState::Closing => { - let this = self.as_mut().project(); - ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; - *this.state = SinkState::Reading; - } - SinkState::Reading => { - let this = self.as_mut().project(); - match ready!(this.responses.poll_next(cx))? { - Message::CommandComplete(body) => { - let rows = body - .tag() - .map_err(Error::parse)? - .rsplit(' ') - .next() - .unwrap() - .parse() - .unwrap_or(0); - return Poll::Ready(Ok(rows)); - } - _ => return Poll::Ready(Err(Error::unexpected_message())), - } - } - } - } - } - - /// Completes the copy, returning the number of rows inserted. - /// - /// The `Sink::close` method is equivalent to `finish`, except that it does not return the - /// number of rows. - pub async fn finish(mut self: Pin<&mut Self>) -> Result { - future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await - } -} - -impl Stream for CopyBothDuplex { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match ready!(this.responses.poll_next(cx)?) { - Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), - Message::CopyDone => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), - } - } -} - -impl Sink for CopyBothDuplex -where - T: Buf + 'static + Send, -{ - type Error = Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .sender - .poll_ready(cx) - .map_err(|_| Error::closed()) - } - - fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { - let this = self.project(); - - let data: Box = if item.remaining() > 4096 { - if this.buf.is_empty() { - Box::new(item) - } else { - Box::new(this.buf.split().freeze().chain(item)) - } - } else { - this.buf.put(item); - if this.buf.len() > 4096 { - Box::new(this.buf.split().freeze()) - } else { - return Ok(()); - } - }; - - let data = CopyData::new(data).map_err(Error::encode)?; - this.sender - .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) - .map_err(|_| Error::closed()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if !this.buf.is_empty() { - ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; - let data: Box = Box::new(this.buf.split().freeze()); - let data = CopyData::new(data).map_err(Error::encode)?; - this.sender - .as_mut() - .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) - .map_err(|_| Error::closed())?; - } - - this.sender.poll_flush(cx).map_err(|_| Error::closed()) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_finish(cx).map_ok(|_| ()) - } -} - -pub async fn copy_both_simple( - client: &InnerClient, - query: &str, -) -> Result, Error> -where - T: Buf + 'static + Send, -{ - debug!("executing copy both query {}", query); - - let buf = simple_query::encode(client, query)?; - - let (mut sender, receiver) = mpsc::channel(1); - let receiver = CopyBothReceiver::new(receiver); - let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; - - sender - .send(CopyBothMessage::Message(FrontendMessage::Raw(buf))) - .await - .map_err(|_| Error::closed())?; - - match responses.next().await? { - Message::CopyBothResponse(_) => {} - _ => return Err(Error::unexpected_message()), - } - - Ok(CopyBothDuplex::new(sender, responses)) -} diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index b3fdba84a..59e31fea6 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -2,8 +2,8 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::query::extract_row_affected; -use crate::{query, simple_query, slice_iter, Error, Statement}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use crate::{query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; @@ -188,10 +188,14 @@ where } } -async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result, Error> +pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> where T: Buf + 'static + Send, { + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); let mut responses = client.send(RequestMessages::CopyIn(receiver))?; @@ -201,11 +205,9 @@ where .await .map_err(|_| Error::closed())?; - if !simple { - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), } match responses.next().await? { @@ -222,23 +224,3 @@ where _p2: PhantomData, }) } - -pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> -where - T: Buf + 'static + Send, -{ - debug!("executing copy in statement {}", statement.name()); - - let buf = query::encode(client, &statement, slice_iter(&[]))?; - start(client, buf, false).await -} - -pub async fn copy_in_simple(client: &InnerClient, query: &str) -> Result, Error> -where - T: Buf + 'static + Send, -{ - debug!("executing copy in query {}", query); - - let buf = simple_query::encode(client, query)?; - start(client, buf, true).await -} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 981f9365e..1e6949252 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, simple_query, slice_iter, Error, Statement}; +use crate::{query, slice_iter, Error, Statement}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -11,36 +11,23 @@ use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result { - debug!("executing copy out query {}", query); - - let buf = simple_query::encode(client, query)?; - let responses = start(client, buf, true).await?; - Ok(CopyOutStream { - responses, - _p: PhantomPinned, - }) -} - pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { debug!("executing copy out statement {}", statement.name()); let buf = query::encode(client, &statement, slice_iter(&[]))?; - let responses = start(client, buf, false).await?; + let responses = start(client, buf).await?; Ok(CopyOutStream { responses, _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result { +async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - if !simple { - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), } match responses.next().await? { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 275978cb2..a9ecba4f1 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -123,7 +123,6 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; -pub use crate::copy_both::CopyBothDuplex; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; use crate::error::DbError; @@ -160,7 +159,6 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; -mod copy_both; mod copy_in; mod copy_out; pub mod error; @@ -170,7 +168,6 @@ mod maybe_tls_stream; mod portal; mod prepare; mod query; -pub mod replication; pub mod row; mod simple_query; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs deleted file mode 100644 index 1b49dcc42..000000000 --- a/tokio-postgres/src/replication.rs +++ /dev/null @@ -1,173 +0,0 @@ -//! Utilities for working with the PostgreSQL replication copy both format. - -use crate::copy_both::CopyBothDuplex; -use crate::Error; -use bytes::{BufMut, Bytes, BytesMut}; -use futures_util::{ready, SinkExt, Stream}; -use pin_project_lite::pin_project; -use postgres_protocol::message::backend::LogicalReplicationMessage; -use postgres_protocol::message::backend::ReplicationMessage; -use postgres_types::PgLsn; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; -const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; - -pin_project! { - /// A type which deserializes the postgres replication protocol. This type can be used with - /// both physical and logical replication to get access to the byte content of each replication - /// message. - /// - /// The replication *must* be explicitly completed via the `finish` method. - pub struct ReplicationStream { - #[pin] - stream: CopyBothDuplex, - } -} - -impl ReplicationStream { - /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream - pub fn new(stream: CopyBothDuplex) -> Self { - Self { stream } - } - - /// Send standby update to server. - pub async fn standby_status_update( - self: Pin<&mut Self>, - write_lsn: PgLsn, - flush_lsn: PgLsn, - apply_lsn: PgLsn, - ts: i64, - reply: u8, - ) -> Result<(), Error> { - let mut this = self.project(); - - let mut buf = BytesMut::new(); - buf.put_u8(STANDBY_STATUS_UPDATE_TAG); - buf.put_u64(write_lsn.into()); - buf.put_u64(flush_lsn.into()); - buf.put_u64(apply_lsn.into()); - buf.put_i64(ts); - buf.put_u8(reply); - - this.stream.send(buf.freeze()).await - } - - /// Send hot standby feedback message to server. - pub async fn hot_standby_feedback( - self: Pin<&mut Self>, - timestamp: i64, - global_xmin: u32, - global_xmin_epoch: u32, - catalog_xmin: u32, - catalog_xmin_epoch: u32, - ) -> Result<(), Error> { - let mut this = self.project(); - - let mut buf = BytesMut::new(); - buf.put_u8(HOT_STANDBY_FEEDBACK_TAG); - buf.put_i64(timestamp); - buf.put_u32(global_xmin); - buf.put_u32(global_xmin_epoch); - buf.put_u32(catalog_xmin); - buf.put_u32(catalog_xmin_epoch); - - this.stream.send(buf.freeze()).await - } -} - -impl Stream for ReplicationStream { - type Item = Result, Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match ready!(this.stream.poll_next(cx)) { - Some(Ok(buf)) => { - Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse))) - } - Some(Err(err)) => Poll::Ready(Some(Err(err))), - None => Poll::Ready(None), - } - } -} - -pin_project! { - /// A type which deserializes the postgres logical replication protocol. This type gives access - /// to a high level representation of the changes in transaction commit order. - /// - /// The replication *must* be explicitly completed via the `finish` method. - pub struct LogicalReplicationStream { - #[pin] - stream: ReplicationStream, - } -} - -impl LogicalReplicationStream { - /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream - pub fn new(stream: CopyBothDuplex) -> Self { - Self { - stream: ReplicationStream::new(stream), - } - } - - /// Send standby update to server. - pub async fn standby_status_update( - self: Pin<&mut Self>, - write_lsn: PgLsn, - flush_lsn: PgLsn, - apply_lsn: PgLsn, - ts: i64, - reply: u8, - ) -> Result<(), Error> { - let this = self.project(); - this.stream - .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, reply) - .await - } - - /// Send hot standby feedback message to server. - pub async fn hot_standby_feedback( - self: Pin<&mut Self>, - timestamp: i64, - global_xmin: u32, - global_xmin_epoch: u32, - catalog_xmin: u32, - catalog_xmin_epoch: u32, - ) -> Result<(), Error> { - let this = self.project(); - this.stream - .hot_standby_feedback( - timestamp, - global_xmin, - global_xmin_epoch, - catalog_xmin, - catalog_xmin_epoch, - ) - .await - } -} - -impl Stream for LogicalReplicationStream { - type Item = Result, Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match ready!(this.stream.poll_next(cx)) { - Some(Ok(ReplicationMessage::XLogData(body))) => { - let body = body - .map_data(|buf| LogicalReplicationMessage::parse(&buf)) - .map_err(Error::parse)?; - Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body)))) - } - Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { - Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) - } - Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))), - Some(Err(err)) => Poll::Ready(Some(Err(err))), - None => Poll::Ready(None), - } - } -} diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index a97ee126c..bcc6d928b 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { +fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index cab185ae6..0ab4a7bab 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -22,8 +22,6 @@ use tokio_postgres::{ mod binary_copy; mod parse; #[cfg(feature = "runtime")] -mod replication; -#[cfg(feature = "runtime")] mod runtime; mod types; @@ -330,7 +328,7 @@ async fn simple_query() { } match &messages[2] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); + assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("1")); assert_eq!(row.get(1), Some("steven")); @@ -339,7 +337,7 @@ async fn simple_query() { } match &messages[3] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); + assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("2")); assert_eq!(row.get(1), Some("joe")); diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs deleted file mode 100644 index 44aae3f22..000000000 --- a/tokio-postgres/tests/test/replication.rs +++ /dev/null @@ -1,149 +0,0 @@ -use futures_util::StreamExt; -use std::time::{Duration, UNIX_EPOCH}; - -use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert}; -use postgres_protocol::message::backend::ReplicationMessage::*; -use postgres_protocol::message::backend::TupleData; -use postgres_types::PgLsn; -use tokio_postgres::replication::LogicalReplicationStream; -use tokio_postgres::NoTls; -use tokio_postgres::SimpleQueryMessage::Row; - -#[tokio::test] -async fn test_replication() { - // form SQL connection - let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; - let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap(); - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {}", e); - } - }); - - client - .simple_query("DROP TABLE IF EXISTS test_logical_replication") - .await - .unwrap(); - client - .simple_query("CREATE TABLE test_logical_replication(i int)") - .await - .unwrap(); - let res = client - .simple_query("SELECT 'test_logical_replication'::regclass::oid") - .await - .unwrap(); - let rel_id: u32 = if let Row(row) = &res[0] { - row.get("oid").unwrap().parse().unwrap() - } else { - panic!("unexpeced query message"); - }; - - client - .simple_query("DROP PUBLICATION IF EXISTS test_pub") - .await - .unwrap(); - client - .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES") - .await - .unwrap(); - - let slot = "test_logical_slot"; - - let query = format!( - r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#, - slot - ); - let slot_query = client.simple_query(&query).await.unwrap(); - let lsn = if let Row(row) = &slot_query[0] { - row.get("consistent_point").unwrap() - } else { - panic!("unexpeced query message"); - }; - - // issue a query that will appear in the slot's stream since it happened after its creation - client - .simple_query("INSERT INTO test_logical_replication VALUES (42)") - .await - .unwrap(); - - let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#; - let query = format!( - r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#, - slot, lsn, options - ); - let copy_stream = client - .copy_both_simple::(&query) - .await - .unwrap(); - - let stream = LogicalReplicationStream::new(copy_stream); - tokio::pin!(stream); - - // verify that we can observe the transaction in the replication stream - let begin = loop { - match stream.next().await { - Some(Ok(XLogData(body))) => { - if let Begin(begin) = body.into_data() { - break begin; - } - } - Some(Ok(_)) => (), - Some(Err(_)) => panic!("unexpected replication stream error"), - None => panic!("unexpected replication stream end"), - } - }; - - let insert = loop { - match stream.next().await { - Some(Ok(XLogData(body))) => { - if let Insert(insert) = body.into_data() { - break insert; - } - } - Some(Ok(_)) => (), - Some(Err(_)) => panic!("unexpected replication stream error"), - None => panic!("unexpected replication stream end"), - } - }; - - let commit = loop { - match stream.next().await { - Some(Ok(XLogData(body))) => { - if let Commit(commit) = body.into_data() { - break commit; - } - } - Some(Ok(_)) => (), - Some(Err(_)) => panic!("unexpected replication stream error"), - None => panic!("unexpected replication stream end"), - } - }; - - assert_eq!(begin.final_lsn(), commit.commit_lsn()); - assert_eq!(insert.rel_id(), rel_id); - - let tuple_data = insert.tuple().tuple_data(); - assert_eq!(tuple_data.len(), 1); - assert!(matches!(tuple_data[0], TupleData::Text(_))); - if let TupleData::Text(data) = &tuple_data[0] { - assert_eq!(data, &b"42"[..]); - } - - // Send a standby status update and require a keep alive response - let lsn: PgLsn = lsn.parse().unwrap(); - let epoch = UNIX_EPOCH + Duration::from_secs(946_684_800); - let ts = epoch.elapsed().unwrap().as_micros() as i64; - stream - .as_mut() - .standby_status_update(lsn, lsn, lsn, ts, 1) - .await - .unwrap(); - loop { - match stream.next().await { - Some(Ok(PrimaryKeepAlive(_))) => break, - Some(Ok(_)) => (), - Some(Err(e)) => panic!("unexpected replication stream error: {}", e), - None => panic!("unexpected replication stream end"), - } - } -} diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 86c1f0701..67b4ead8a 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,58 +66,6 @@ async fn target_session_attrs_err() { .unwrap(); } -#[tokio::test] -async fn host_only_ok() { - let _ = tokio_postgres::connect( - "host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_only_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_and_host_ok() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_mismatch() { - let _ = tokio_postgres::connect( - "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - -#[tokio::test] -async fn hostaddr_host_both_missing() { - let _ = tokio_postgres::connect( - "port=5433 user=pass_user dbname=postgres password=password", - NoTls, - ) - .await - .err() - .unwrap(); -} - #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index c325917aa..a8e9e5afa 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -1,4 +1,4 @@ -use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use std::fmt; use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; use tokio_postgres::Client; @@ -54,9 +54,8 @@ async fn test_date_time_params() { fn make_check(time: &str) -> (Option>, &str) { ( Some( - NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap() - .and_utc(), + Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") + .unwrap(), ), time, ) @@ -78,9 +77,8 @@ async fn test_with_special_date_time_params() { fn make_check(time: &str) -> (Timestamp>, &str) { ( Timestamp::Value( - NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap() - .and_utc(), + Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") + .unwrap(), ), time, ) From 82bb077b93059653a2bf8bc9a6ba1249d491f134 Mon Sep 17 00:00:00 2001 From: Jeff Davis Date: Mon, 14 Dec 2020 11:54:01 -0800 Subject: [PATCH 109/126] Make simple_query::encode() pub(crate). --- tokio-postgres/src/simple_query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 24473b896..a26e43e6e 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) From cddba069b3773c9fdc4f75189f77adbf06ed9cb9 Mon Sep 17 00:00:00 2001 From: Jeff Davis Date: Mon, 14 Dec 2020 11:58:59 -0800 Subject: [PATCH 110/126] Connection string config for replication. Co-authored-by: Petros Angelatos --- tokio-postgres/src/config.rs | 45 +++++++++++++++++++++++++++++++ tokio-postgres/src/connect_raw.rs | 8 +++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 62b45f793..8bc2a42df 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -72,6 +72,21 @@ pub enum LoadBalanceHosts { Random, } +/// Replication mode configuration. +/// +/// It is recommended that you use a PostgreSQL server patch version +/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or +/// 9.5.25. Earlier patch levels have a bug that doesn't properly +/// handle pipelined requests after streaming has stopped. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -209,6 +224,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -242,6 +258,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, } } @@ -524,6 +541,22 @@ impl Config { self.load_balance_hosts } + /// Set replication mode. + /// + /// It is recommended that you use a PostgreSQL server patch version + /// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or + /// 9.5.25. Earlier patch levels have a bug that doesn't properly + /// handle pipelined requests after streaming has stopped. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -660,6 +693,17 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -744,6 +788,7 @@ impl fmt::Debug for Config { config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..8edf45937 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config}; +use crate::config::{self, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -133,6 +133,12 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; From 5f2e1e88a50eb9f5d7652e6b0c3963e8b05c6114 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Fri, 28 May 2021 00:18:23 +0200 Subject: [PATCH 111/126] implement Stream for Responses Signed-off-by: Petros Angelatos --- tokio-postgres/src/client.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 92eabde36..93f7c2f7b 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -19,7 +19,7 @@ use crate::{ use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; -use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; +use futures_util::{future, pin_mut, ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; use postgres_protocol::message::backend::Message; use postgres_types::BorrowToSql; @@ -29,6 +29,7 @@ use std::fmt; use std::net::IpAddr; #[cfg(feature = "runtime")] use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -61,6 +62,17 @@ impl Responses { } } +impl Stream for Responses { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match ready!((*self).poll_next(cx)) { + Err(err) if err.is_closed() => Poll::Ready(None), + msg => Poll::Ready(Some(msg)), + } + } +} + /// A cache of type info and prepared statements for fetching type info /// (corresponding to the queries in the [prepare](prepare) module). #[derive(Default)] From 2016ec345ea6d2be9b56bd47a2af472d471a6cea Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 1 Apr 2021 15:13:06 +0200 Subject: [PATCH 112/126] add copy_both_simple method Signed-off-by: Petros Angelatos --- postgres-protocol/src/message/backend.rs | 34 +++ tokio-postgres/src/client.rs | 48 ++- tokio-postgres/src/connection.rs | 20 ++ tokio-postgres/src/copy_both.rs | 358 +++++++++++++++++++++++ tokio-postgres/src/lib.rs | 2 + tokio-postgres/tests/test/copy_both.rs | 125 ++++++++ tokio-postgres/tests/test/main.rs | 1 + 7 files changed, 585 insertions(+), 3 deletions(-) create mode 100644 tokio-postgres/src/copy_both.rs create mode 100644 tokio-postgres/tests/test/copy_both.rs diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 73b169288..fdc83fedb 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -93,6 +94,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +192,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -524,6 +536,28 @@ impl CopyOutResponseBody { } } +#[derive(Debug, Clone)] +pub struct CopyBothResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + #[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 93f7c2f7b..f463d8402 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ -use crate::codec::BackendMessages; +use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::{CopyBothDuplex, CopyBothReceiver}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -13,8 +14,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -41,6 +43,11 @@ pub struct Responses { cur: BackendMessages, } +pub struct CopyBothHandles { + pub(crate) stream_receiver: mpsc::Receiver>, + pub(crate) sink_sender: mpsc::Sender, +} + impl Responses { pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { loop { @@ -115,6 +122,32 @@ impl InnerClient { }) } + pub fn start_copy_both(&self) -> Result { + let (sender, receiver) = mpsc::channel(1); + let (stream_sender, stream_receiver) = mpsc::channel(0); + let (sink_sender, sink_receiver) = mpsc::channel(0); + + let responses = Responses { + receiver, + cur: BackendMessages::empty(), + }; + let messages = RequestMessages::CopyBoth(CopyBothReceiver::new( + responses, + sink_receiver, + stream_sender, + )); + + let request = Request { messages, sender }; + self.sender + .unbounded_send(request) + .map_err(|_| Error::closed())?; + + Ok(CopyBothHandles { + stream_receiver, + sink_sender, + }) + } + pub fn typeinfo(&self) -> Option { self.cached_typeinfo.lock().typeinfo.clone() } @@ -505,6 +538,15 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..a3449f88b 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -20,6 +21,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -258,6 +260,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..d3b46eab7 --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,358 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// The state machine of CopyBothReceiver +/// +/// ```ignore +/// CopyBoth +/// / \ +/// v v +/// CopyOut CopyIn +/// \ / +/// v v +/// CopyNone +/// | +/// v +/// CopyComplete +/// | +/// v +/// CommandComplete +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CopyBothState { + /// The state before having entered the CopyBoth mode. + Setup, + /// Initial state where CopyData messages can go in both directions + CopyBoth, + /// The server->client stream is closed and we're in CopyIn mode + CopyIn, + /// The client->server stream is closed and we're in CopyOut mode + CopyOut, + /// Both directions are closed, we waiting for CommandComplete messages + CopyNone, + /// We have received the first CommandComplete message for the copy + CopyComplete, + /// We have received the final CommandComplete message for the statement + CommandComplete, +} + +/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no +/// matter what the users do with their CopyBothDuplex handle we're always going to send the +/// correct messages to the backend in order to restore the connection into a usable state. +/// +/// ```ignore +/// | +/// | +/// | +/// pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex +/// | ^ \ +/// | / v +/// | Sink Stream +/// ``` +pub struct CopyBothReceiver { + /// Receiver of backend messages from the underlying [Connection](crate::Connection) + responses: Responses, + /// Receiver of frontend messages sent by the user using + sink_receiver: mpsc::Receiver, + /// Sender of CopyData contents to be consumed by the user using + stream_sender: mpsc::Sender>, + /// The current state of the subprotocol + state: CopyBothState, + /// Holds a buffered message until we are ready to send it to the user's stream + buffered_message: Option>, +} + +impl CopyBothReceiver { + pub(crate) fn new( + responses: Responses, + sink_receiver: mpsc::Receiver, + stream_sender: mpsc::Sender>, + ) -> CopyBothReceiver { + CopyBothReceiver { + responses, + sink_receiver, + stream_sender, + state: CopyBothState::Setup, + buffered_message: None, + } + } + + /// Convenience method to set the subprotocol into an unexpected message state + fn unexpected_message(&mut self) { + self.sink_receiver.close(); + self.buffered_message = Some(Err(Error::unexpected_message())); + self.state = CopyBothState::CommandComplete; + } + + /// Processes messages from the backend, it will resolve once all backend messages have been + /// processed + fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> { + use CopyBothState::*; + + loop { + // Deliver the buffered message (if any) to the user to ensure we can potentially + // buffer a new one in response to a server message + if let Some(message) = self.buffered_message.take() { + match self.stream_sender.poll_ready(cx) { + Poll::Ready(_) => { + // If the receiver has hung up we'll just drop the message + let _ = self.stream_sender.start_send(message); + } + Poll::Pending => { + // Stash the message and try again later + self.buffered_message = Some(message); + return Poll::Pending; + } + } + } + + match ready!(self.responses.poll_next_unpin(cx)) { + Some(Ok(Message::CopyBothResponse(body))) => match self.state { + Setup => { + self.buffered_message = Some(Ok(Message::CopyBothResponse(body))); + self.state = CopyBoth; + } + _ => self.unexpected_message(), + }, + Some(Ok(Message::CopyData(body))) => match self.state { + CopyBoth | CopyOut => { + self.buffered_message = Some(Ok(Message::CopyData(body))); + } + _ => self.unexpected_message(), + }, + // The server->client stream is done + Some(Ok(Message::CopyDone)) => { + match self.state { + CopyBoth => self.state = CopyIn, + CopyOut => self.state = CopyNone, + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::CommandComplete(_))) => { + match self.state { + CopyNone => self.state = CopyComplete, + CopyComplete => { + self.stream_sender.close_channel(); + self.sink_receiver.close(); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + // The server indicated an error, terminate our side if we haven't already + Some(Err(err)) => { + match self.state { + Setup | CopyBoth | CopyOut | CopyIn => { + self.sink_receiver.close(); + self.buffered_message = Some(Err(err)); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::ReadyForQuery(_))) => match self.state { + CommandComplete => { + self.sink_receiver.close(); + self.stream_sender.close_channel(); + } + _ => self.unexpected_message(), + }, + Some(Ok(_)) => self.unexpected_message(), + None => return Poll::Ready(()), + } + } + } +} + +/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This +/// is the mechanism that drives the CopyBoth subprotocol forward +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use CopyBothState::*; + + match self.poll_backend(cx) { + Poll::Ready(()) => Poll::Ready(None), + Poll::Pending => match self.state { + Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) { + Some(msg) => Poll::Ready(Some(msg)), + None => { + self.state = match self.state { + CopyBoth => CopyOut, + CopyIn => CopyNone, + _ => unreachable!(), + }; + + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + }, + _ => Poll::Pending, + }, + } + } +} + +pin_project! { + /// A duplex stream for consuming streaming replication data. + /// + /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new + /// query. This will ensure that the connection returns into normal processing mode. + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ⚠️ INCORRECT ⚠️ + /// client.query("SELECT 1", &[]).await; // hangs forever + /// + /// // duplex_stream drop-ed here + /// } + /// ``` + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ✅ CORRECT ✅ + /// drop(duplex_stream); + /// + /// client.query("SELECT 1", &[]).await; + /// } + /// ``` + pub struct CopyBothDuplex { + #[pin] + sink_sender: mpsc::Sender, + #[pin] + stream_receiver: mpsc::Receiver>, + buf: BytesMut, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) { + Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())), + Some(Ok(_)) => Some(Err(Error::unexpected_message())), + Some(Err(err)) => Some(Err(err)), + None => None, + }) + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sink_sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .as_mut() + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed())?; + } + + this.sink_sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + this.sink_sender.disconnect(); + Poll::Ready(Ok(())) + } +} + +pub async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let mut handles = client.start_copy_both()?; + + handles + .sink_sender + .send(FrontendMessage::Raw(buf)) + .await + .map_err(|_| Error::closed())?; + + match handles.stream_receiver.next().await.transpose()? { + Some(Message::CopyBothResponse(_)) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex { + stream_receiver: handles.stream_receiver, + sink_sender: handles.sink_sender, + buf: BytesMut::new(), + _p: PhantomPinned, + _p2: PhantomData, + }) +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a603158fb..cde9df841 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -122,6 +122,7 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; +pub use crate::copy_both::CopyBothDuplex; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; use crate::error::DbError; @@ -159,6 +160,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs new file mode 100644 index 000000000..2723928ac --- /dev/null +++ b/tokio-postgres/tests/test/copy_both.rs @@ -0,0 +1,125 @@ +use futures_util::{future, StreamExt, TryStreamExt}; +use tokio_postgres::{error::SqlState, Client, SimpleQueryMessage, SimpleQueryRow}; + +async fn q(client: &Client, query: &str) -> Vec { + let msgs = client.simple_query(query).await.unwrap(); + + msgs.into_iter() + .filter_map(|msg| match msg { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) + .collect() +} + +#[tokio::test] +async fn copy_both_error() { + let client = crate::connect("user=postgres replication=database").await; + + let err = client + .copy_both_simple::("START_REPLICATION SLOT undefined LOGICAL 0000/0000") + .await + .err() + .unwrap(); + + assert_eq!(err.code(), Some(&SqlState::UNDEFINED_OBJECT)); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both_stream_error() { + let client = crate::connect("user=postgres replication=true").await; + + q(&client, "CREATE_REPLICATION_SLOT err2 PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err2 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + let mut msgs: Vec<_> = duplex_stream.collect().await; + let result = msgs.pop().unwrap(); + assert_eq!(msgs.len(), 0); + assert!(result.unwrap_err().as_db_error().is_some()); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "DROP_REPLICATION_SLOT err2").await.len(), 0); +} + +#[tokio::test] +async fn copy_both_stream_error_sync() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "CREATE_REPLICATION_SLOT err1 TEMPORARY PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err1 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + // Immediately close our sink to send a CopyDone before receiving the ErrorResponse + drop(duplex_stream); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "DROP TABLE IF EXISTS replication").await; + q(&client, "CREATE TABLE replication (i text)").await; + + let slot_query = "CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\""; + let lsn = q(&client, slot_query).await[0] + .get("consistent_point") + .unwrap() + .to_owned(); + + // We will attempt to read this from the other end + q(&client, "BEGIN").await; + let xid = q(&client, "SELECT txid_current()").await[0] + .get("txid_current") + .unwrap() + .to_owned(); + q(&client, "INSERT INTO replication VALUES ('processed')").await; + q(&client, "COMMIT").await; + + // Insert a second row to generate unprocessed messages in the stream + q(&client, "INSERT INTO replication VALUES ('ignored')").await; + + let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn); + let duplex_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let expected = vec![ + format!("BEGIN {}", xid), + "table public.replication: INSERT: i[text]:'processed'".to_string(), + format!("COMMIT {}", xid), + ]; + + let actual: Vec<_> = duplex_stream + // Process only XLogData messages + .try_filter(|buf| future::ready(buf[0] == b'w')) + // Playback the stream until the first expected message + .try_skip_while(|buf| future::ready(Ok(!buf.ends_with(expected[0].as_ref())))) + // Take only the expected number of messsage, the rest will be discarded by tokio_postgres + .take(expected.len()) + .try_collect() + .await + .unwrap(); + + for (msg, ending) in actual.into_iter().zip(expected.into_iter()) { + assert!(msg.ends_with(ending.as_ref())); + } + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 9a6aa26fe..778ddaf05 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -20,6 +20,7 @@ use tokio_postgres::{ }; mod binary_copy; +mod copy_both; mod parse; #[cfg(feature = "runtime")] mod runtime; From 41f5bacfe0eae63f619731c4f09434dcb4abd9bb Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Tue, 23 Nov 2021 15:36:00 +0100 Subject: [PATCH 113/126] ci: enable logical replication in the test image Signed-off-by: Petros Angelatos --- docker/sql_setup.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..051a12000 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,6 +64,7 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' +wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust +host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all postgres ::0/0 trust # Unix socket connections: From acc2ce350914dacab3609690b0d92a64a35b92c3 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 1 Apr 2021 15:09:42 +0200 Subject: [PATCH 114/126] add simple query versions of copy operations Signed-off-by: Petros Angelatos --- tokio-postgres/src/client.rs | 13 ++++++++++++ tokio-postgres/src/copy_in.rs | 38 +++++++++++++++++++++++++--------- tokio-postgres/src/copy_out.rs | 25 ++++++++++++++++------ 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index f463d8402..e47ef2b1f 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -527,6 +527,14 @@ impl Client { copy_in::copy_in(self.inner(), statement).await } + /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data. + pub async fn copy_in_simple(&self, query: &str) -> Result, Error> + where + U: Buf + 'static + Send, + { + copy_in::copy_in_simple(self.inner(), query).await + } + /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. @@ -538,6 +546,11 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data. + pub async fn copy_out_simple(&self, query: &str) -> Result { + copy_out::copy_out_simple(self.inner(), query).await + } + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy /// data. pub async fn copy_both_simple(&self, query: &str) -> Result, Error> diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index 59e31fea6..b3fdba84a 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -2,8 +2,8 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::query::extract_row_affected; -use crate::{query, slice_iter, Error, Statement}; -use bytes::{Buf, BufMut, BytesMut}; +use crate::{query, simple_query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; @@ -188,14 +188,10 @@ where } } -pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result, Error> where T: Buf + 'static + Send, { - debug!("executing copy in statement {}", statement.name()); - - let buf = query::encode(client, &statement, slice_iter(&[]))?; - let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); let mut responses = client.send(RequestMessages::CopyIn(receiver))?; @@ -205,9 +201,11 @@ where .await .map_err(|_| Error::closed())?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { @@ -224,3 +222,23 @@ where _p2: PhantomData, }) } + +pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + start(client, buf, false).await +} + +pub async fn copy_in_simple(client: &InnerClient, query: &str) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in query {}", query); + + let buf = simple_query::encode(client, query)?; + start(client, buf, true).await +} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..981f9365e 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, simple_query, slice_iter, Error, Statement}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -11,23 +11,36 @@ use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; +pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result { + debug!("executing copy out query {}", query); + + let buf = simple_query::encode(client, query)?; + let responses = start(client, buf, true).await?; + Ok(CopyOutStream { + responses, + _p: PhantomPinned, + }) +} + pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { debug!("executing copy out statement {}", statement.name()); let buf = query::encode(client, &statement, slice_iter(&[]))?; - let responses = start(client, buf).await?; + let responses = start(client, buf, false).await?; Ok(CopyOutStream { responses, _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { From 2e7372dd8a9e32cdc77dab431feaeb88a51f5cf1 Mon Sep 17 00:00:00 2001 From: Ufuk Celebi <1756620+uce@users.noreply.github.com> Date: Mon, 17 May 2021 22:46:38 +0200 Subject: [PATCH 115/126] config: add sslmode `verify-ca` and `verify-full` When a connection is established, the added modes are treated in the same way as the existing `require` mode as they both require a TLS connection. --- postgres/src/config.rs | 3 ++- tokio-postgres/src/config.rs | 9 ++++++++- tokio-postgres/src/connect_tls.rs | 11 ++++++----- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a32ddc78e..6be68992a 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -35,7 +35,8 @@ use tokio_postgres::{Error, Socket}; /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used -/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to +/// be used. Defaults to `prefer`. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 8bc2a42df..83925a2f8 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -48,6 +48,10 @@ pub enum SslMode { Prefer, /// Require the use of TLS. Require, + /// Require the use of TLS. + VerifyCa, + /// Require the use of TLS. + VerifyFull, } /// Channel binding configuration. @@ -116,7 +120,8 @@ pub enum Host { /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used -/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to +/// be used. Defaults to `prefer`. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting @@ -579,6 +584,8 @@ impl Config { "disable" => SslMode::Disable, "prefer" => SslMode::Prefer, "require" => SslMode::Require, + "verify-ca" => SslMode::VerifyCa, + "verify-full" => SslMode::VerifyFull, _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), }; self.ssl_mode(mode); diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 2b1229125..41b319c2b 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -22,7 +22,7 @@ where SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { return Ok(MaybeTlsStream::Raw(stream)) } - SslMode::Prefer | SslMode::Require => {} + SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {} } let mut buf = BytesMut::new(); @@ -33,10 +33,11 @@ where stream.read_exact(&mut buf).await.map_err(Error::io)?; if buf[0] != b'S' { - if SslMode::Require == mode { - return Err(Error::tls("server does not support TLS".into())); - } else { - return Ok(MaybeTlsStream::Raw(stream)); + match mode { + SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => { + return Err(Error::tls("server does not support TLS".into())) + } + SslMode::Disable | SslMode::Prefer => return Ok(MaybeTlsStream::Raw(stream)), } } From c8a054cdb83dbbe520edea358c30b9273487613c Mon Sep 17 00:00:00 2001 From: Ufuk Celebi <1756620+uce@users.noreply.github.com> Date: Wed, 19 May 2021 22:21:19 +0200 Subject: [PATCH 116/126] config: add ssl config params Adds additional SSL config params: - sslcert - sslkey - sslrootcert More details at https://www.postgresql.org/docs/9.5/libpq-connect.html#LIBPQ-CONNSTRING. --- postgres/src/config.rs | 44 ++++++++++++++++++++++- tokio-postgres/src/config.rs | 69 ++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 6be68992a..74653a88d 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -3,12 +3,12 @@ use crate::connection::Connection; use crate::Client; use log::info; -use std::fmt; use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use std::{fmt, path::PathBuf}; use tokio::runtime; #[doc(inline)] pub use tokio_postgres::config::{ @@ -34,9 +34,12 @@ use tokio_postgres::{Error, Socket}; /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslcert` - Location of the client SSL certificate file. +/// * `sslkey` - Location for the secret key file used for the client certificate. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used /// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to /// be used. Defaults to `prefer`. +/// * `sslrootcert` - Location of SSL certificate authority (CA) certificate. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting @@ -218,6 +221,32 @@ impl Config { self.config.get_application_name() } + /// Sets the location of the client SSL certificate file. + /// + /// Defaults to `None`. + pub fn ssl_cert(&mut self, ssl_cert: &str) -> &mut Config { + self.config.ssl_cert(ssl_cert); + self + } + + /// Gets the location of the client SSL certificate file. + pub fn get_ssl_cert(&self) -> Option { + self.config.get_ssl_cert() + } + + /// Sets the location of the secret key file used for the client certificate. + /// + /// Defaults to `None`. + pub fn ssl_key(&mut self, ssl_key: &str) -> &mut Config { + self.config.ssl_key(ssl_key); + self + } + + /// Gets the location of the secret key file used for the client certificate. + pub fn get_ssl_key(&self) -> Option { + self.config.get_ssl_key() + } + /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -231,6 +260,19 @@ impl Config { self.config.get_ssl_mode() } + /// Sets the location of SSL certificate authority (CA) certificate. + /// + /// Defaults to `None`. + pub fn ssl_root_cert(&mut self, ssl_root_cert: &str) -> &mut Config { + self.config.ssl_root_cert(ssl_root_cert); + self + } + + /// Gets the location of SSL certificate authority (CA) certificate. + pub fn get_ssl_root_cert(&self) -> Option { + self.config.get_ssl_root_cert() + } + /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 83925a2f8..c750852e3 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -119,9 +119,12 @@ pub enum Host { /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslcert` - Location of the client SSL certificate file. +/// * `sslkey` - Location for the secret key file used for the client certificate. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used /// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to /// be used. Defaults to `prefer`. +/// * `sslrootcert` - Location of SSL certificate authority (CA) certificate. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting @@ -217,7 +220,10 @@ pub struct Config { pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, + pub(crate) ssl_cert: Option, + pub(crate) ssl_key: Option, pub(crate) ssl_mode: SslMode, + pub(crate) ssl_root_cert: Option, pub(crate) host: Vec, pub(crate) hostaddr: Vec, pub(crate) port: Vec, @@ -247,7 +253,10 @@ impl Config { dbname: None, options: None, application_name: None, + ssl_cert: None, + ssl_key: None, ssl_mode: SslMode::Prefer, + ssl_root_cert: None, host: vec![], hostaddr: vec![], port: vec![], @@ -334,6 +343,32 @@ impl Config { self.application_name.as_deref() } + /// Sets the location of the client SSL certificate file. + /// + /// Defaults to `None`. + pub fn ssl_cert(&mut self, ssl_cert: &str) -> &mut Config { + self.ssl_cert = Some(PathBuf::from(ssl_cert)); + self + } + + /// Gets the location of the client SSL certificate file. + pub fn get_ssl_cert(&self) -> Option { + self.ssl_cert.clone() + } + + /// Sets the location of the secret key file used for the client certificate. + /// + /// Defaults to `None`. + pub fn ssl_key(&mut self, ssl_key: &str) -> &mut Config { + self.ssl_key = Some(PathBuf::from(ssl_key)); + self + } + + /// Gets the location of the secret key file used for the client certificate. + pub fn get_ssl_key(&self) -> Option { + self.ssl_key.clone() + } + /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -347,6 +382,19 @@ impl Config { self.ssl_mode } + /// Sets the location of SSL certificate authority (CA) certificate. + /// + /// Defaults to `None`. + pub fn ssl_root_cert(&mut self, ssl_root_cert: &str) -> &mut Config { + self.ssl_root_cert = Some(PathBuf::from(ssl_root_cert)); + self + } + + /// Gets the location of SSL certificate authority (CA) certificate. + pub fn get_ssl_root_cert(&self) -> Option { + self.ssl_root_cert.clone() + } + /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix @@ -579,6 +627,18 @@ impl Config { "application_name" => { self.application_name(value); } + "sslcert" => { + if std::fs::metadata(value).is_err() { + return Err(Error::config_parse(Box::new(InvalidValue("sslcert")))); + } + self.ssl_cert(value); + } + "sslkey" => { + if std::fs::metadata(value).is_err() { + return Err(Error::config_parse(Box::new(InvalidValue("sslkey")))); + } + self.ssl_key(value); + } "sslmode" => { let mode = match value { "disable" => SslMode::Disable, @@ -590,6 +650,12 @@ impl Config { }; self.ssl_mode(mode); } + "sslrootcert" => { + if std::fs::metadata(value).is_err() { + return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert")))); + } + self.ssl_root_cert(value); + } "host" => { for host in value.split(',') { self.host(host); @@ -776,7 +842,10 @@ impl fmt::Debug for Config { .field("dbname", &self.dbname) .field("options", &self.options) .field("application_name", &self.application_name) + .field("ssl_cert", &self.ssl_cert) + .field("ssl_key", &self.ssl_key) .field("ssl_mode", &self.ssl_mode) + .field("ssl_root_cert", &self.ssl_root_cert) .field("host", &self.host) .field("hostaddr", &self.hostaddr) .field("port", &self.port) From 11467f85393ef6a06313b813d1ce9bd468270877 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Mon, 24 May 2021 19:25:03 +0200 Subject: [PATCH 117/126] describe branching strategy Signed-off-by: Petros Angelatos --- README.md | 58 +++++++++++++++++-------------------------------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index b81a6716f..a6716dcdd 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,24 @@ -# Rust-Postgres +# Materialize fork of Rust-Postgres -PostgreSQL support for Rust. +This repo serves as a staging area in order to develop and use features of the +rust postgtres client before they are accepted upstream. -## postgres [![Latest Version](https://img.shields.io/crates/v/postgres.svg)](https://crates.io/crates/postgres) +Since development on this repo and the upstream one can happen in parallel this +repo adops a branching strategy that keeps both in sync and keeps a tidy +history. Importantly, the release branches are **never** forced-pushed so that +older versions of materialize are always buildable. -[Documentation](https://docs.rs/postgres) +## Branching strategy -A native, synchronous PostgreSQL client. +For every upstream release a local `mz-{version}` branch is created. The latest +such branch should be made the default branch of this repo in the Github +settings. -## tokio-postgres [![Latest Version](https://img.shields.io/crates/v/tokio-postgres.svg)](https://crates.io/crates/tokio-postgres) +Whenever a PR is opened it should targed the current release branch (it should +be picked automatically if its set as default on Github). -[Documentation](https://docs.rs/tokio-postgres) - -A native, asynchronous PostgreSQL client. - -## postgres-types [![Latest Version](https://img.shields.io/crates/v/postgres-types.svg)](https://crates.io/crates/postgres-types) - -[Documentation](https://docs.rs/postgres-types) - -Conversions between Rust and Postgres types. - -## postgres-native-tls [![Latest Version](https://img.shields.io/crates/v/postgres-native-tls.svg)](https://crates.io/crates/postgres-native-tls) - -[Documentation](https://docs.rs/postgres-native-tls) - -TLS support for postgres and tokio-postgres via native-tls. - -## postgres-openssl [![Latest Version](https://img.shields.io/crates/v/postgres-openssl.svg)](https://crates.io/crates/postgres-openssl) - -[Documentation](https://docs.rs/postgres-openssl) - -TLS support for postgres and tokio-postgres via openssl. - -# Running test suite - -The test suite requires postgres to be running in the correct configuration. The easiest way to do this is with docker: - -1. Install `docker` and `docker-compose`. - 1. On ubuntu: `sudo apt install docker.io docker-compose`. -1. Make sure your user has permissions for docker. - 1. On ubuntu: ``sudo usermod -aG docker $USER`` -1. Change to top-level directory of `rust-postgres` repo. -1. Run `docker-compose up -d`. -1. Run `cargo test`. -1. Run `docker-compose stop`. +Whenever a new version is created upstream a new `mz-{version}` branch on this +repo is created, initially pointing at the release commit of the upstream repo. +Then, all the fork-specific work is rebased on top of it. This process gives +the opportunity to prune PRs that have successfully made it to the upstream +repo and keep a clean per-version history. From 443f6c685c3aff9d158924060eb957b2a71724fb Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sun, 16 Jan 2022 02:37:22 -0500 Subject: [PATCH 118/126] Fix conditional imports --- tokio-postgres/src/config.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index c750852e3..e0102b6ce 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -19,7 +19,9 @@ use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] -use std::path::{Path, PathBuf}; +use std::path::Path; +#[cfg(unix)] +use std::path::PathBuf; use std::str; use std::str::FromStr; use std::time::Duration; From 43932918dc585ab5dff75d5f81bd40f4109872fd Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Mon, 21 Feb 2022 23:13:24 -0500 Subject: [PATCH 119/126] README: update development instructions --- README.md | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index a6716dcdd..7a353dfe8 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,33 @@ # Materialize fork of Rust-Postgres -This repo serves as a staging area in order to develop and use features of the -rust postgtres client before they are accepted upstream. +This repo serves as a staging area for Materialize patches to the +[rust-postgres] client before they are accepted upstream. -Since development on this repo and the upstream one can happen in parallel this -repo adops a branching strategy that keeps both in sync and keeps a tidy -history. Importantly, the release branches are **never** forced-pushed so that -older versions of materialize are always buildable. +There are no releases from this fork. The [MaterializeInc/materialize] +repository simply pins a recent commit from the `master` branch. Other projects +are welcome to do the same. The `master` branch is never force pushed. Upstream +changes are periodically into `master` via `git merge`. -## Branching strategy +## Adding a new patch -For every upstream release a local `mz-{version}` branch is created. The latest -such branch should be made the default branch of this repo in the Github -settings. +Develop your patch against the master branch of the upstream [rust-postgres] +project. Open a PR with your changes. If your PR is not merged quickly, open the +same PR against this repository and request a review from a Materialize +engineer. -Whenever a PR is opened it should targed the current release branch (it should -be picked automatically if its set as default on Github). +The long-term goal is to get every patch merged upstream. -Whenever a new version is created upstream a new `mz-{version}` branch on this -repo is created, initially pointing at the release commit of the upstream repo. -Then, all the fork-specific work is rebased on top of it. This process gives -the opportunity to prune PRs that have successfully made it to the upstream -repo and keep a clean per-version history. +## Integrating upstream changes + +```shell +git clone https://github.com/MaterializeInc/rust-postgres.git +git remote add upstream https://github.com/sfackler/rust-postgres.git +git checkout master +git pull +git checkout -b integrate-upstream +git fetch upstream +git merge upstream/master +# Resolve any conflicts, then open a PR against this repository with the merge commit. +``` + +[rust-postgres]: https://github.com/sfackler/rust-postgres.git From 95a1ef77d634e35a6abe5034cfa04ae5e64a0236 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Mon, 21 Feb 2022 23:16:14 -0500 Subject: [PATCH 120/126] README: fix links --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a353dfe8..7c2cc04f8 100644 --- a/README.md +++ b/README.md @@ -30,4 +30,5 @@ git merge upstream/master # Resolve any conflicts, then open a PR against this repository with the merge commit. ``` -[rust-postgres]: https://github.com/sfackler/rust-postgres.git +[rust-postgres]: https://github.com/sfackler/rust-postgres +[MaterializeInc/materialize]: https://github.com/MaterializeInc/materialize From 6d8777de347374a2e7975978b065caceea9b2664 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sun, 10 Jul 2022 18:12:13 -0400 Subject: [PATCH 121/126] Optionally enable deriving Serialize/Deserialize --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/config.rs | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index c2f80dc7e..23a9cab3e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -56,6 +56,7 @@ pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } postgres-types = { version = "0.2.7", path = "../postgres-types" } +serde = { version = "1.0", optional = true } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index e0102b6ce..1546bac12 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -41,7 +41,8 @@ pub enum TargetSessionAttrs { } /// TLS configuration. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub enum SslMode { /// Do not use TLS. From cc77e6eb88d000f7349ad6869b21b504f1b650fe Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sun, 10 Jul 2022 23:32:02 -0400 Subject: [PATCH 122/126] Change SSL configuration to PEM bytes rather than files --- postgres/src/config.rs | 26 ++++++------- tokio-postgres/src/config.rs | 72 +++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 46 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 74653a88d..91ad3c904 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -3,12 +3,12 @@ use crate::connection::Connection; use crate::Client; use log::info; +use std::fmt; use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use std::{fmt, path::PathBuf}; use tokio::runtime; #[doc(inline)] pub use tokio_postgres::config::{ @@ -221,29 +221,29 @@ impl Config { self.config.get_application_name() } - /// Sets the location of the client SSL certificate file. + /// Sets the client SSL certificate in PEM format. /// /// Defaults to `None`. - pub fn ssl_cert(&mut self, ssl_cert: &str) -> &mut Config { + pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config { self.config.ssl_cert(ssl_cert); self } - /// Gets the location of the client SSL certificate file. - pub fn get_ssl_cert(&self) -> Option { + /// Gets the location of the client SSL certificate in PEM format. + pub fn get_ssl_cert(&self) -> Option<&[u8]> { self.config.get_ssl_cert() } - /// Sets the location of the secret key file used for the client certificate. + /// Sets the client SSL key in PEM format. /// /// Defaults to `None`. - pub fn ssl_key(&mut self, ssl_key: &str) -> &mut Config { + pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config { self.config.ssl_key(ssl_key); self } - /// Gets the location of the secret key file used for the client certificate. - pub fn get_ssl_key(&self) -> Option { + /// Gets the client SSL key in PEM format. + pub fn get_ssl_key(&self) -> Option<&[u8]> { self.config.get_ssl_key() } @@ -260,16 +260,16 @@ impl Config { self.config.get_ssl_mode() } - /// Sets the location of SSL certificate authority (CA) certificate. + /// Sets the SSL certificate authority (CA) certificate in PEM format. /// /// Defaults to `None`. - pub fn ssl_root_cert(&mut self, ssl_root_cert: &str) -> &mut Config { + pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config { self.config.ssl_root_cert(ssl_root_cert); self } - /// Gets the location of SSL certificate authority (CA) certificate. - pub fn get_ssl_root_cert(&self) -> Option { + /// Gets the SSL certificate authority (CA) certificate in PEM format. + pub fn get_ssl_root_cert(&self) -> Option<&[u8]> { self.config.get_ssl_root_cert() } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 1546bac12..156c28d1d 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -223,10 +223,10 @@ pub struct Config { pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, - pub(crate) ssl_cert: Option, - pub(crate) ssl_key: Option, + pub(crate) ssl_cert: Option>, + pub(crate) ssl_key: Option>, pub(crate) ssl_mode: SslMode, - pub(crate) ssl_root_cert: Option, + pub(crate) ssl_root_cert: Option>, pub(crate) host: Vec, pub(crate) hostaddr: Vec, pub(crate) port: Vec, @@ -346,30 +346,30 @@ impl Config { self.application_name.as_deref() } - /// Sets the location of the client SSL certificate file. + /// Sets the client SSL certificate in PEM format. /// /// Defaults to `None`. - pub fn ssl_cert(&mut self, ssl_cert: &str) -> &mut Config { - self.ssl_cert = Some(PathBuf::from(ssl_cert)); + pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config { + self.ssl_cert = Some(ssl_cert.into()); self } - /// Gets the location of the client SSL certificate file. - pub fn get_ssl_cert(&self) -> Option { - self.ssl_cert.clone() + /// Gets the location of the client SSL certificate in PEM format. + pub fn get_ssl_cert(&self) -> Option<&[u8]> { + self.ssl_cert.as_deref() } - /// Sets the location of the secret key file used for the client certificate. + /// Sets the client SSL key in PEM format. /// /// Defaults to `None`. - pub fn ssl_key(&mut self, ssl_key: &str) -> &mut Config { - self.ssl_key = Some(PathBuf::from(ssl_key)); + pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config { + self.ssl_key = Some(ssl_key.into()); self } - /// Gets the location of the secret key file used for the client certificate. - pub fn get_ssl_key(&self) -> Option { - self.ssl_key.clone() + /// Gets the client SSL key in PEM format. + pub fn get_ssl_key(&self) -> Option<&[u8]> { + self.ssl_key.as_deref() } /// Sets the SSL configuration. @@ -385,17 +385,17 @@ impl Config { self.ssl_mode } - /// Sets the location of SSL certificate authority (CA) certificate. + /// Sets the SSL certificate authority (CA) certificate in PEM format. /// /// Defaults to `None`. - pub fn ssl_root_cert(&mut self, ssl_root_cert: &str) -> &mut Config { - self.ssl_root_cert = Some(PathBuf::from(ssl_root_cert)); + pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config { + self.ssl_root_cert = Some(ssl_root_cert.into()); self } - /// Gets the location of SSL certificate authority (CA) certificate. - pub fn get_ssl_root_cert(&self) -> Option { - self.ssl_root_cert.clone() + /// Gets the SSL certificate authority (CA) certificate in PEM format. + pub fn get_ssl_root_cert(&self) -> Option<&[u8]> { + self.ssl_root_cert.as_deref() } /// Adds a host to the configuration. @@ -630,18 +630,22 @@ impl Config { "application_name" => { self.application_name(value); } - "sslcert" => { - if std::fs::metadata(value).is_err() { + "sslcert" => match std::fs::read(value) { + Ok(contents) => { + self.ssl_cert(&contents); + } + Err(_) => { return Err(Error::config_parse(Box::new(InvalidValue("sslcert")))); } - self.ssl_cert(value); - } - "sslkey" => { - if std::fs::metadata(value).is_err() { + }, + "sslkey" => match std::fs::read(value) { + Ok(contents) => { + self.ssl_key(&contents); + } + Err(_) => { return Err(Error::config_parse(Box::new(InvalidValue("sslkey")))); } - self.ssl_key(value); - } + }, "sslmode" => { let mode = match value { "disable" => SslMode::Disable, @@ -653,12 +657,14 @@ impl Config { }; self.ssl_mode(mode); } - "sslrootcert" => { - if std::fs::metadata(value).is_err() { + "sslrootcert" => match std::fs::read(value) { + Ok(contents) => { + self.ssl_root_cert(&contents); + } + Err(_) => { return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert")))); } - self.ssl_root_cert(value); - } + }, "host" => { for host in value.split(',') { self.host(host); From 46c667dbe933b00893c39ba74e0601eb98053238 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 27 May 2023 22:56:44 -0400 Subject: [PATCH 123/126] Add options for specifying SSL certificates inline This is often quite a bit more convenient than mucking with files on disk, as it allows the URL to be fully self contained. --- tokio-postgres/src/config.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 156c28d1d..e94eac459 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -123,11 +123,14 @@ pub enum Host { /// * `options` - Command line options used to configure the server. /// * `application_name` - Sets the `application_name` parameter on the server. /// * `sslcert` - Location of the client SSL certificate file. +/// * `sslcert_inline` - The contents of the client SSL certificate. /// * `sslkey` - Location for the secret key file used for the client certificate. +/// * `sslkey_inline` - The contents of the client SSL key. /// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used /// if available, but not used otherwise. If set to `require`, `verify-ca`, or `verify-full`, TLS will be forced to /// be used. Defaults to `prefer`. /// * `sslrootcert` - Location of SSL certificate authority (CA) certificate. +/// * `sslrootcert_inline` - The contents of the SSL certificate authority. /// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting @@ -638,6 +641,9 @@ impl Config { return Err(Error::config_parse(Box::new(InvalidValue("sslcert")))); } }, + "sslcert_inline" => { + self.ssl_cert(value.as_bytes()); + } "sslkey" => match std::fs::read(value) { Ok(contents) => { self.ssl_key(&contents); @@ -646,6 +652,9 @@ impl Config { return Err(Error::config_parse(Box::new(InvalidValue("sslkey")))); } }, + "sslkey_inline" => { + self.ssl_key(value.as_bytes()); + } "sslmode" => { let mode = match value { "disable" => SslMode::Disable, @@ -665,6 +674,9 @@ impl Config { return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert")))); } }, + "sslrootcert_inline" => { + self.ssl_root_cert(value.as_bytes()); + } "host" => { for host in value.split(',') { self.host(host); From e807e3b0615373999c55a0304a5bf240196dabdd Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Mon, 25 Mar 2024 10:43:21 -0400 Subject: [PATCH 124/126] Upgrade to Rust 1.77 --- .github/workflows/ci.yml | 2 +- postgres-types/src/chrono_04.rs | 2 +- tokio-postgres/Cargo.toml | 1 - tokio-postgres/tests/test/main.rs | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 641a42722..a950428ec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.74.0 + version: 1.77.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index d599bde02..6b6406232 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -113,7 +113,7 @@ impl<'a> FromSql<'a> for NaiveDate { let jd = types::date_from_sql(raw)?; base() .date() - .checked_add_signed(Duration::days(i64::from(jd))) + .checked_add_signed(Duration::try_days(i64::from(jd)).unwrap()) .ok_or_else(|| "value too large to decode".into()) } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 23a9cab3e..f0e7fdb3e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -82,7 +82,6 @@ chrono-04 = { version = "0.4", package = "chrono", default-features = false } eui48-1 = { version = "1.0", package = "eui48", default-features = false } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } -serde-1 = { version = "1.0", package = "serde" } serde_json-1 = { version = "1.0", package = "serde_json" } smol_str-01 = { version = "0.1", package = "smol_str" } uuid-08 = { version = "0.8", package = "uuid" } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 778ddaf05..3debf4eba 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -337,7 +337,7 @@ async fn simple_query() { } match &messages[3] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("1")); assert_eq!(row.get(1), Some("steven")); @@ -346,7 +346,7 @@ async fn simple_query() { } match &messages[4] { SimpleQueryMessage::Row(row) => { - assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().first().map(|c| c.name()), Some("id")); assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); assert_eq!(row.get(0), Some("2")); assert_eq!(row.get(1), Some("joe")); From 3fe1a997b8066f3c0804b0f4d0bbed53eaf37e41 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Fri, 14 Jun 2024 11:33:49 -0400 Subject: [PATCH 125/126] Expose the backend PID on the client So that clients can observe the ID of their connection without a network roundtrip. --- tokio-postgres/src/client.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e47ef2b1f..e1d784607 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -610,6 +610,11 @@ impl Client { TransactionBuilder::new(self) } + /// Returns the server's process ID for the connection. + pub fn backend_pid(&self) -> i32 { + self.process_id + } + /// Constructs a cancellation token that can later be used to request cancellation of a query running on the /// connection associated with this client. pub fn cancel_token(&self) -> CancelToken { From 900ed50f3edc8c48cb0503124516da19a3a6afe7 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Wed, 21 Aug 2024 17:54:00 +0300 Subject: [PATCH 126/126] add independent replication crate --- Cargo.toml | 1 + postgres-replication/Cargo.toml | 31 + postgres-replication/LICENSE-APACHE | 201 ++++++ postgres-replication/LICENSE-MIT | 22 + postgres-replication/src/lib.rs | 175 +++++ postgres-replication/src/protocol.rs | 791 ++++++++++++++++++++++ postgres-replication/tests/replication.rs | 150 ++++ tokio-postgres/src/error/mod.rs | 9 +- 8 files changed, 1377 insertions(+), 3 deletions(-) create mode 100644 postgres-replication/Cargo.toml create mode 100644 postgres-replication/LICENSE-APACHE create mode 100644 postgres-replication/LICENSE-MIT create mode 100644 postgres-replication/src/lib.rs create mode 100644 postgres-replication/src/protocol.rs create mode 100644 postgres-replication/tests/replication.rs diff --git a/Cargo.toml b/Cargo.toml index 16e3739dd..f31993b07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "postgres-native-tls", "postgres-openssl", "postgres-protocol", + "postgres-replication", "postgres-types", "tokio-postgres", ] diff --git a/postgres-replication/Cargo.toml b/postgres-replication/Cargo.toml new file mode 100644 index 000000000..f24cd6ccf --- /dev/null +++ b/postgres-replication/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "postgres-replication" +version = "0.6.7" +authors = ["Petros Angelatos "] +edition = "2018" +description = "Protocol definitions for the Postgres logical replication protocol" +license = "MIT OR Apache-2.0" +repository = "https://github.com/sfackler/rust-postgres" +readme = "../README.md" + +[features] +default = [] + +[dependencies] +bytes = "1.0" +memchr = "2.0" +byteorder = "1.0" +postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } +postgres-types = { version = "0.2.7", path = "../postgres-types" } +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", features = ["runtime"] } +futures-util = { version = "0.3", features = ["sink"] } +pin-project-lite = "0.2" + +[dev-dependencies] +tokio = { version = "1.0", features = [ + "macros", + "net", + "rt", + "rt-multi-thread", + "time", +] } diff --git a/postgres-replication/LICENSE-APACHE b/postgres-replication/LICENSE-APACHE new file mode 100644 index 000000000..16fe87b06 --- /dev/null +++ b/postgres-replication/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/postgres-replication/LICENSE-MIT b/postgres-replication/LICENSE-MIT new file mode 100644 index 000000000..71803aea1 --- /dev/null +++ b/postgres-replication/LICENSE-MIT @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Steven Fackler + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/postgres-replication/src/lib.rs b/postgres-replication/src/lib.rs new file mode 100644 index 000000000..08d17d4b8 --- /dev/null +++ b/postgres-replication/src/lib.rs @@ -0,0 +1,175 @@ +//! Utilities for working with the PostgreSQL replication copy both format. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures_util::{ready, SinkExt, Stream}; +use pin_project_lite::pin_project; +use postgres_types::PgLsn; +use tokio_postgres::CopyBothDuplex; +use tokio_postgres::Error; + +pub mod protocol; + +use crate::protocol::{LogicalReplicationMessage, ReplicationMessage}; + +const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; +const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; + +pin_project! { + /// A type which deserializes the postgres replication protocol. This type can be used with + /// both physical and logical replication to get access to the byte content of each replication + /// message. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct ReplicationStream { + #[pin] + stream: CopyBothDuplex, + } +} + +impl ReplicationStream { + /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { stream } + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + ts: i64, + reply: u8, + ) -> Result<(), Error> { + let mut this = self.project(); + + let mut buf = BytesMut::new(); + buf.put_u8(STANDBY_STATUS_UPDATE_TAG); + buf.put_u64(write_lsn.into()); + buf.put_u64(flush_lsn.into()); + buf.put_u64(apply_lsn.into()); + buf.put_i64(ts); + buf.put_u8(reply); + + this.stream.send(buf.freeze()).await + } + + /// Send hot standby feedback message to server. + pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: i64, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let mut this = self.project(); + + let mut buf = BytesMut::new(); + buf.put_u8(HOT_STANDBY_FEEDBACK_TAG); + buf.put_i64(timestamp); + buf.put_u32(global_xmin); + buf.put_u32(global_xmin_epoch); + buf.put_u32(catalog_xmin); + buf.put_u32(catalog_xmin_epoch); + + this.stream.send(buf.freeze()).await + } +} + +impl Stream for ReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(buf)) => { + Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} + +pin_project! { + /// A type which deserializes the postgres logical replication protocol. This type gives access + /// to a high level representation of the changes in transaction commit order. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct LogicalReplicationStream { + #[pin] + stream: ReplicationStream, + } +} + +impl LogicalReplicationStream { + /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { + stream: ReplicationStream::new(stream), + } + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + ts: i64, + reply: u8, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, reply) + .await + } + + /// Send hot standby feedback message to server. + pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: i64, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .hot_standby_feedback( + timestamp, + global_xmin, + global_xmin_epoch, + catalog_xmin, + catalog_xmin_epoch, + ) + .await + } +} + +impl Stream for LogicalReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(ReplicationMessage::XLogData(body))) => { + let body = body + .map_data(|buf| LogicalReplicationMessage::parse(&buf)) + .map_err(Error::parse)?; + Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body)))) + } + Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { + Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} diff --git a/postgres-replication/src/protocol.rs b/postgres-replication/src/protocol.rs new file mode 100644 index 000000000..d94825014 --- /dev/null +++ b/postgres-replication/src/protocol.rs @@ -0,0 +1,791 @@ +use std::io::{self, Read}; +use std::{cmp, str}; + +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::Bytes; +use memchr::memchr; +use postgres_protocol::{Lsn, Oid}; + +// replication message tags +pub const XLOG_DATA_TAG: u8 = b'w'; +pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; + +// logical replication message tags +const BEGIN_TAG: u8 = b'B'; +const COMMIT_TAG: u8 = b'C'; +const ORIGIN_TAG: u8 = b'O'; +const RELATION_TAG: u8 = b'R'; +const TYPE_TAG: u8 = b'Y'; +const INSERT_TAG: u8 = b'I'; +const UPDATE_TAG: u8 = b'U'; +const DELETE_TAG: u8 = b'D'; +const TRUNCATE_TAG: u8 = b'T'; +const TUPLE_NEW_TAG: u8 = b'N'; +const TUPLE_KEY_TAG: u8 = b'K'; +const TUPLE_OLD_TAG: u8 = b'O'; +const TUPLE_DATA_NULL_TAG: u8 = b'n'; +const TUPLE_DATA_TOAST_TAG: u8 = b'u'; +const TUPLE_DATA_TEXT_TAG: u8 = b't'; + +// replica identity tags +const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; +const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; +const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; +const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; + +/// An enum representing Postgres backend replication messages. +#[non_exhaustive] +#[derive(Debug)] +pub enum ReplicationMessage { + XLogData(XLogDataBody), + PrimaryKeepAlive(PrimaryKeepAliveBody), +} + +impl ReplicationMessage { + #[inline] + pub fn parse(buf: &Bytes) -> io::Result { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let replication_message = match tag { + XLOG_DATA_TAG => { + let wal_start = buf.read_u64::()?; + let wal_end = buf.read_u64::()?; + let timestamp = buf.read_i64::()?; + let data = buf.read_all(); + ReplicationMessage::XLogData(XLogDataBody { + wal_start, + wal_end, + timestamp, + data, + }) + } + PRIMARY_KEEPALIVE_TAG => { + let wal_end = buf.read_u64::()?; + let timestamp = buf.read_i64::()?; + let reply = buf.read_u8()?; + ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody { + wal_end, + timestamp, + reply, + }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(replication_message) + } +} + +#[derive(Debug)] +pub struct XLogDataBody { + wal_start: u64, + wal_end: u64, + timestamp: i64, + data: D, +} + +impl XLogDataBody { + #[inline] + pub fn wal_start(&self) -> u64 { + self.wal_start + } + + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + pub fn data(&self) -> &D { + &self.data + } + + #[inline] + pub fn into_data(self) -> D { + self.data + } + + pub fn map_data(self, f: F) -> Result, E> + where + F: Fn(D) -> Result, + { + let data = f(self.data)?; + Ok(XLogDataBody { + wal_start: self.wal_start, + wal_end: self.wal_end, + timestamp: self.timestamp, + data, + }) + } +} + +#[derive(Debug)] +pub struct PrimaryKeepAliveBody { + wal_end: u64, + timestamp: i64, + reply: u8, +} + +impl PrimaryKeepAliveBody { + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + pub fn reply(&self) -> u8 { + self.reply + } +} + +#[non_exhaustive] +/// A message of the logical replication stream +#[derive(Debug)] +pub enum LogicalReplicationMessage { + /// A BEGIN statement + Begin(BeginBody), + /// A BEGIN statement + Commit(CommitBody), + /// An Origin replication message + /// Note that there can be multiple Origin messages inside a single transaction. + Origin(OriginBody), + /// A Relation replication message + Relation(RelationBody), + /// A Type replication message + Type(TypeBody), + /// An INSERT statement + Insert(InsertBody), + /// An UPDATE statement + Update(UpdateBody), + /// A DELETE statement + Delete(DeleteBody), + /// A TRUNCATE statement + Truncate(TruncateBody), +} + +impl LogicalReplicationMessage { + pub fn parse(buf: &Bytes) -> io::Result { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let logical_replication_message = match tag { + BEGIN_TAG => Self::Begin(BeginBody { + final_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + xid: buf.read_u32::()?, + }), + COMMIT_TAG => Self::Commit(CommitBody { + flags: buf.read_i8()?, + commit_lsn: buf.read_u64::()?, + end_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + }), + ORIGIN_TAG => Self::Origin(OriginBody { + commit_lsn: buf.read_u64::()?, + name: buf.read_cstr()?, + }), + RELATION_TAG => { + let rel_id = buf.read_u32::()?; + let namespace = buf.read_cstr()?; + let name = buf.read_cstr()?; + let replica_identity = match buf.read_u8()? { + REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, + REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, + REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, + REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replica identity tag `{}`", tag), + )); + } + }; + let column_len = buf.read_i16::()?; + + let mut columns = Vec::with_capacity(column_len as usize); + for _ in 0..column_len { + columns.push(Column::parse(&mut buf)?); + } + + Self::Relation(RelationBody { + rel_id, + namespace, + name, + replica_identity, + columns, + }) + } + TYPE_TAG => Self::Type(TypeBody { + id: buf.read_u32::()?, + namespace: buf.read_cstr()?, + name: buf.read_cstr()?, + }), + INSERT_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + }; + + Self::Insert(InsertBody { rel_id, tuple }) + } + UPDATE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + let new_tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + TUPLE_OLD_TAG | TUPLE_KEY_TAG => { + if tag == TUPLE_OLD_TAG { + old_tuple = Some(Tuple::parse(&mut buf)?); + } else { + key_tuple = Some(Tuple::parse(&mut buf)?); + } + + match buf.read_u8()? { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + } + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + }; + + Self::Update(UpdateBody { + rel_id, + key_tuple, + old_tuple, + new_tuple, + }) + } + DELETE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + match tag { + TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), + TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + } + + Self::Delete(DeleteBody { + rel_id, + key_tuple, + old_tuple, + }) + } + TRUNCATE_TAG => { + let relation_len = buf.read_i32::()?; + let options = buf.read_i8()?; + + let mut rel_ids = Vec::with_capacity(relation_len as usize); + for _ in 0..relation_len { + rel_ids.push(buf.read_u32::()?); + } + + Self::Truncate(TruncateBody { options, rel_ids }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(logical_replication_message) + } +} + +/// A row as it appears in the replication stream +#[derive(Debug)] +pub struct Tuple(Vec); + +impl Tuple { + #[inline] + /// The tuple data of this tuple + pub fn tuple_data(&self) -> &[TupleData] { + &self.0 + } +} + +impl Tuple { + fn parse(buf: &mut Buffer) -> io::Result { + let col_len = buf.read_i16::()?; + let mut tuple = Vec::with_capacity(col_len as usize); + for _ in 0..col_len { + tuple.push(TupleData::parse(buf)?); + } + + Ok(Tuple(tuple)) + } +} + +/// A column as it appears in the replication stream +#[derive(Debug)] +pub struct Column { + flags: i8, + name: Bytes, + type_id: i32, + type_modifier: i32, +} + +impl Column { + #[inline] + /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as + /// part of the key. + pub fn flags(&self) -> i8 { + self.flags + } + + #[inline] + /// Name of the column. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// ID of the column's data type. + pub fn type_id(&self) -> i32 { + self.type_id + } + + #[inline] + /// Type modifier of the column (`atttypmod`). + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl Column { + fn parse(buf: &mut Buffer) -> io::Result { + Ok(Self { + flags: buf.read_i8()?, + name: buf.read_cstr()?, + type_id: buf.read_i32::()?, + type_modifier: buf.read_i32::()?, + }) + } +} + +/// The data of an individual column as it appears in the replication stream +#[derive(Debug)] +pub enum TupleData { + /// Represents a NULL value + Null, + /// Represents an unchanged TOASTed value (the actual value is not sent). + UnchangedToast, + /// Column data as text formatted value. + Text(Bytes), +} + +impl TupleData { + fn parse(buf: &mut Buffer) -> io::Result { + let type_tag = buf.read_u8()?; + + let tuple = match type_tag { + TUPLE_DATA_NULL_TAG => TupleData::Null, + TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, + TUPLE_DATA_TEXT_TAG => { + let len = buf.read_i32::()?; + let mut data = vec![0; len as usize]; + buf.read_exact(&mut data)?; + TupleData::Text(data.into()) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(tuple) + } +} + +/// A BEGIN statement +#[derive(Debug)] +pub struct BeginBody { + final_lsn: u64, + timestamp: i64, + xid: u32, +} + +impl BeginBody { + #[inline] + /// Gets the final lsn of the transaction + pub fn final_lsn(&self) -> Lsn { + self.final_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Xid of the transaction. + pub fn xid(&self) -> u32 { + self.xid + } +} + +/// A COMMIT statement +#[derive(Debug)] +pub struct CommitBody { + flags: i8, + commit_lsn: u64, + end_lsn: u64, + timestamp: i64, +} + +impl CommitBody { + #[inline] + /// The LSN of the commit. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// The end LSN of the transaction. + pub fn end_lsn(&self) -> Lsn { + self.end_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Flags; currently unused (will be 0). + pub fn flags(&self) -> i8 { + self.flags + } +} + +/// An Origin replication message +/// +/// Note that there can be multiple Origin messages inside a single transaction. +#[derive(Debug)] +pub struct OriginBody { + commit_lsn: u64, + name: Bytes, +} + +impl OriginBody { + #[inline] + /// The LSN of the commit on the origin server. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// Name of the origin. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// Describes the REPLICA IDENTITY setting of a table +#[derive(Debug)] +pub enum ReplicaIdentity { + /// default selection for replica identity (primary key or nothing) + Default, + /// no replica identity is logged for this relation + Nothing, + /// all columns are logged as replica identity + Full, + /// An explicitly chosen candidate key's columns are used as replica identity. + /// Note this will still be set if the index has been dropped; in that case it + /// has the same meaning as 'd'. + Index, +} + +/// A Relation replication message +#[derive(Debug)] +pub struct RelationBody { + rel_id: u32, + namespace: Bytes, + name: Bytes, + replica_identity: ReplicaIdentity, + columns: Vec, +} + +impl RelationBody { + #[inline] + /// ID of the relation. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Relation name. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// Replica identity setting for the relation + pub fn replica_identity(&self) -> &ReplicaIdentity { + &self.replica_identity + } + + #[inline] + /// The column definitions of this relation + pub fn columns(&self) -> &[Column] { + &self.columns + } +} + +/// A Type replication message +#[derive(Debug)] +pub struct TypeBody { + id: u32, + namespace: Bytes, + name: Bytes, +} + +impl TypeBody { + #[inline] + /// ID of the data type. + pub fn id(&self) -> Oid { + self.id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Name of the data type. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// An INSERT statement +#[derive(Debug)] +pub struct InsertBody { + rel_id: u32, + tuple: Tuple, +} + +impl InsertBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// The inserted tuple + pub fn tuple(&self) -> &Tuple { + &self.tuple + } +} + +/// An UPDATE statement +#[derive(Debug)] +pub struct UpdateBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, + new_tuple: Tuple, +} + +impl UpdateBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is optional and is only present if the update changed data in any of the + /// column(s) that are part of the REPLICA IDENTITY index. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is optional and is only present if table in which the update happened has + /// REPLICA IDENTITY set to FULL. + pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } + + #[inline] + /// The new tuple + pub fn new_tuple(&self) -> &Tuple { + &self.new_tuple + } +} + +/// A DELETE statement +#[derive(Debug)] +pub struct DeleteBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, +} + +impl DeleteBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is present if the table in which the delete has happened uses an index as + /// REPLICA IDENTITY. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is present if the table in which the delete has happened has REPLICA IDENTITY + /// set to FULL. + pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } +} + +/// A TRUNCATE statement +#[derive(Debug)] +pub struct TruncateBody { + options: i8, + rel_ids: Vec, +} + +impl TruncateBody { + #[inline] + /// The IDs of the relations corresponding to the ID in the relation messages + pub fn rel_ids(&self) -> &[u32] { + &self.rel_ids + } + + #[inline] + /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY + pub fn options(&self) -> i8 { + self.options + } +} + +struct Buffer { + bytes: Bytes, + idx: usize, +} + +impl Buffer { + #[inline] + fn slice(&self) -> &[u8] { + &self.bytes[self.idx..] + } + + #[inline] + fn read_cstr(&mut self) -> io::Result { + match memchr(0, self.slice()) { + Some(pos) => { + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + #[inline] + fn read_all(&mut self) -> Bytes { + let buf = self.bytes.slice(self.idx..); + self.idx = self.bytes.len(); + buf + } +} + +impl Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = { + let slice = self.slice(); + let len = cmp::min(slice.len(), buf.len()); + buf[..len].copy_from_slice(&slice[..len]); + len + }; + self.idx += len; + Ok(len) + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/postgres-replication/tests/replication.rs b/postgres-replication/tests/replication.rs new file mode 100644 index 000000000..49700ef8c --- /dev/null +++ b/postgres-replication/tests/replication.rs @@ -0,0 +1,150 @@ +use std::time::{Duration, UNIX_EPOCH}; + +use futures_util::StreamExt; + +use postgres_replication::protocol::LogicalReplicationMessage::{Begin, Commit, Insert}; +use postgres_replication::protocol::ReplicationMessage::*; +use postgres_replication::protocol::TupleData; +use postgres_replication::LogicalReplicationStream; +use postgres_types::PgLsn; +use tokio_postgres::NoTls; +use tokio_postgres::SimpleQueryMessage::Row; + +#[tokio::test] +async fn test_replication() { + // form SQL connection + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + client + .simple_query("DROP TABLE IF EXISTS test_logical_replication") + .await + .unwrap(); + client + .simple_query("CREATE TABLE test_logical_replication(i int)") + .await + .unwrap(); + let res = client + .simple_query("SELECT 'test_logical_replication'::regclass::oid") + .await + .unwrap(); + let rel_id: u32 = if let Row(row) = &res[1] { + row.get("oid").unwrap().parse().unwrap() + } else { + panic!("unexpeced query message"); + }; + + client + .simple_query("DROP PUBLICATION IF EXISTS test_pub") + .await + .unwrap(); + client + .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES") + .await + .unwrap(); + + let slot = "test_logical_slot"; + + let query = format!( + r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#, + slot + ); + let slot_query = client.simple_query(&query).await.unwrap(); + let lsn = if let Row(row) = &slot_query[1] { + row.get("consistent_point").unwrap() + } else { + panic!("unexpeced query message"); + }; + + // issue a query that will appear in the slot's stream since it happened after its creation + client + .simple_query("INSERT INTO test_logical_replication VALUES (42)") + .await + .unwrap(); + + let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#; + let query = format!( + r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#, + slot, lsn, options + ); + let copy_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let stream = LogicalReplicationStream::new(copy_stream); + tokio::pin!(stream); + + // verify that we can observe the transaction in the replication stream + let begin = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Begin(begin) = body.into_data() { + break begin; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + let insert = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Insert(insert) = body.into_data() { + break insert; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + let commit = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Commit(commit) = body.into_data() { + break commit; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + assert_eq!(begin.final_lsn(), commit.commit_lsn()); + assert_eq!(insert.rel_id(), rel_id); + + let tuple_data = insert.tuple().tuple_data(); + assert_eq!(tuple_data.len(), 1); + assert!(matches!(tuple_data[0], TupleData::Text(_))); + if let TupleData::Text(data) = &tuple_data[0] { + assert_eq!(data, &b"42"[..]); + } + + // Send a standby status update and require a keep alive response + let lsn: PgLsn = lsn.parse().unwrap(); + let epoch = UNIX_EPOCH + Duration::from_secs(946_684_800); + let ts = epoch.elapsed().unwrap().as_micros() as i64; + stream + .as_mut() + .standby_status_update(lsn, lsn, lsn, ts, 1) + .await + .unwrap(); + loop { + match stream.next().await { + Some(Ok(PrimaryKeepAlive(_))) => break, + Some(Ok(_)) => (), + Some(Err(e)) => panic!("unexpected replication stream error: {}", e), + None => panic!("unexpected replication stream end"), + } + } +} diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 75664d258..e35d4d4e4 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -86,7 +86,8 @@ pub struct DbError { } impl DbError { - pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result { + /// Parses the error fields obtained from Postgres into a `DBError`. + pub fn parse(fields: &mut ErrorFields<'_>) -> io::Result { let mut severity = None; let mut parsed_severity = None; let mut code = None; @@ -446,7 +447,8 @@ impl Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { + /// Constructs an `UnexpectedMessage` error. + pub fn unexpected_message() -> Error { Error::new(Kind::UnexpectedMessage, None) } @@ -458,7 +460,8 @@ impl Error { } } - pub(crate) fn parse(e: io::Error) -> Error { + /// Constructs a `Parse` error wrapping the provided one. + pub fn parse(e: io::Error) -> Error { Error::new(Kind::Parse, Some(Box::new(e))) }