diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0ae551dc..a950428ec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,33 @@ jobs: key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - run: cargo clippy --all --all-targets + check-wasm32: + name: check-wasm32 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - run: rustup target add wasm32-unknown-unknown + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: check-wasm32-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js + test: name: test runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..f31993b07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,11 @@ members = [ "postgres-native-tls", "postgres-openssl", "postgres-protocol", + "postgres-replication", "postgres-types", "tokio-postgres", ] +resolver = "2" [profile.release] debug = 2 diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> 
src/compile-fail/invalid-allow-mismatch.rs:16:1
+   |
+16 | / #[postgres(allow_mismatch)]
+17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32);
+   | |_______________________________________________^
+
+error: #[postgres(allow_mismatch)] may only be applied to enums
+  --> src/compile-fail/invalid-allow-mismatch.rs:20:1
+   |
+20 | / #[postgres(allow_mismatch)]
+21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32);
+   | |_________________________________________________^
+
+error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]
+  --> src/compile-fail/invalid-allow-mismatch.rs:24:25
+   |
+24 | #[postgres(transparent, allow_mismatch)]
+   |                         ^^^^^^^^^^^^^^
+
+error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]
+  --> src/compile-fail/invalid-allow-mismatch.rs:28:28
+   |
+28 | #[postgres(allow_mismatch, transparent)]
+   |                            ^^^^^^^^^^^
diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs
index a1b76345f..50a22790d 100644
--- a/postgres-derive-test/src/composites.rs
+++ b/postgres-derive-test/src/composites.rs
@@ -89,6 +89,49 @@ fn name_overrides() {
     );
 }
 
+#[test]
+fn rename_all_overrides() {
+    #[derive(FromSql, ToSql, Debug, PartialEq)]
+    #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")]
+    struct InventoryItem {
+        name: String,
+        supplier_id: i32,
+        #[postgres(name = "Price")]
+        price: Option<f64>,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.batch_execute(
+        "CREATE TYPE pg_temp.inventory_item AS (
+        \"NAME\" TEXT,
+        \"SUPPLIER_ID\" INT,
+        \"Price\" DOUBLE PRECISION
+    );",
+    )
+    .unwrap();
+
+    let item = InventoryItem {
+        name: "foobar".to_owned(),
+        supplier_id: 100,
+        price: Some(15.50),
+    };
+
+    let item_null = InventoryItem {
+        name: "foobar".to_owned(),
+        supplier_id: 100,
+        price: None,
+    };
+
+    test_type(
+        &mut conn,
+        "inventory_item",
+        &[
+            (item, "ROW('foobar', 100, 15.50)"),
+            (item_null, "ROW('foobar', 100, NULL)"),
+        ],
+    );
+}
+
 #[test]
 fn wrong_name() {
     #[derive(FromSql, ToSql, Debug, PartialEq)]
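For reference, a minimal usage sketch (not part of this patch) of the container-level rename_all option on a composite, mirroring the test above; the inventory_item type and column names are taken directly from that test:

// Sketch only: assumes a composite created as
//   CREATE TYPE inventory_item AS ("NAME" TEXT, "SUPPLIER_ID" INT, "Price" DOUBLE PRECISION);
use postgres_types::{FromSql, ToSql};

#[derive(FromSql, ToSql, Debug)]
#[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")]
struct InventoryItem {
    name: String,     // mapped to "NAME" by rename_all
    supplier_id: i32, // mapped to "SUPPLIER_ID" by rename_all
    #[postgres(name = "Price")]
    price: Option<f64>, // field-level name wins over rename_all
}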
diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs
index a7039ca05..f3e6c488c 100644
--- a/postgres-derive-test/src/enums.rs
+++ b/postgres-derive-test/src/enums.rs
@@ -1,5 +1,5 @@
 use crate::test_type;
-use postgres::{Client, NoTls};
+use postgres::{error::DbError, Client, NoTls};
 use postgres_types::{FromSql, ToSql, WrongType};
 use std::error::Error;
 
@@ -53,6 +53,35 @@ fn name_overrides() {
     );
 }
 
+#[test]
+fn rename_all_overrides() {
+    #[derive(Debug, ToSql, FromSql, PartialEq)]
+    #[postgres(name = "mood", rename_all = "snake_case")]
+    enum Mood {
+        VerySad,
+        #[postgres(name = "okay")]
+        Ok,
+        VeryHappy,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.execute(
+        "CREATE TYPE pg_temp.mood AS ENUM ('very_sad', 'okay', 'very_happy')",
+        &[],
+    )
+    .unwrap();
+
+    test_type(
+        &mut conn,
+        "mood",
+        &[
+            (Mood::VerySad, "'very_sad'"),
+            (Mood::Ok, "'okay'"),
+            (Mood::VeryHappy, "'very_happy'"),
+        ],
+    );
+}
+
 #[test]
 fn wrong_name() {
     #[derive(Debug, ToSql, FromSql, PartialEq)]
@@ -102,3 +131,73 @@ fn missing_variant() {
     let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
     assert!(err.source().unwrap().is::<WrongType>());
 }
+
+#[test]
+fn allow_mismatch_enums() {
+    #[derive(Debug, ToSql, FromSql, PartialEq)]
+    #[postgres(allow_mismatch)]
+    enum Foo {
+        Bar,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
+        .unwrap();
+
+    let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap();
+    assert_eq!(row.get::<_, Foo>(0), Foo::Bar);
+}
+
+#[test]
+fn missing_enum_variant() {
+    #[derive(Debug, ToSql, FromSql, PartialEq)]
+    #[postgres(allow_mismatch)]
+    enum Foo {
+        Bar,
+        Buz,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
+        .unwrap();
+
+    let err = conn
+        .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz])
+        .unwrap_err();
+    assert!(err.source().unwrap().is::<DbError>());
+}
+
+#[test]
+fn allow_mismatch_and_renaming() {
+    #[derive(Debug, ToSql, FromSql, PartialEq)]
+    #[postgres(name = "foo", allow_mismatch)]
+    enum Foo {
+        #[postgres(name = "bar")]
+        Bar,
+        #[postgres(name = "buz")]
+        Buz,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[])
+        .unwrap();
+
+    let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap();
+    assert_eq!(row.get::<_, Foo>(0), Foo::Buz);
+}
+
+#[test]
+fn wrong_name_and_allow_mismatch() {
+    #[derive(Debug, ToSql, FromSql, PartialEq)]
+    #[postgres(allow_mismatch)]
+    enum Foo {
+        Bar,
+    }
+
+    let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
+    conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
+        .unwrap();
+
+    let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
+    assert!(err.source().unwrap().is::<WrongType>());
+}
diff --git a/postgres-derive/CHANGELOG.md b/postgres-derive/CHANGELOG.md
index 22714acc2..b0075fa8e 100644
--- a/postgres-derive/CHANGELOG.md
+++ b/postgres-derive/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Change Log
 
+## v0.4.5 - 2023-08-19
+
+### Added
+
+* Added a `rename_all` option for enum and struct derives.
+* Added an `allow_mismatch` option to disable strict enum variant checks against the Postgres type.
+
 ## v0.4.4 - 2023-03-27
 
 ### Changed
diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml
index 535a64315..cbae6c77b 100644
--- a/postgres-derive/Cargo.toml
+++ b/postgres-derive/Cargo.toml
@@ -1,8 +1,8 @@
 [package]
 name = "postgres-derive"
-version = "0.4.4"
+version = "0.4.5"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 edition = "2018"
 description = "An internal crate used by postgres-types"
 repository = "https://github.com/sfackler/rust-postgres"
@@ -15,3 +15,4 @@ test = false
 syn = "2.0"
 proc-macro2 = "1.0"
 quote = "1.0"
+heck = "0.5"
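The two new derive options can also be combined on a single enum; a hedged sketch of the resulting attribute surface (the "mood" type name and snake_case labels are borrowed from the tests above, and the combination itself follows from the code paths below rather than a dedicated test):

// Sketch only: rename_all derives the SQL labels, and allow_mismatch
// relaxes accepts() to a type-name check, so the Rust enum may cover
// only a subset of the Postgres enum's variants.
use postgres_types::{FromSql, ToSql};

#[derive(Debug, FromSql, ToSql)]
#[postgres(name = "mood", rename_all = "snake_case", allow_mismatch)]
enum Mood {
    VerySad,   // matches the 'very_sad' label
    VeryHappy, // matches the 'very_happy' label
}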
diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs
index 63473863a..a68538dcc 100644
--- a/postgres-derive/src/accepts.rs
+++ b/postgres-derive/src/accepts.rs
@@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
     }
 }
 
-pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
+pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream {
     let num_variants = variants.len();
     let variant_names = variants.iter().map(|v| &v.name);
 
-    quote! {
-        if type_.name() != #name {
-            return false;
+    if allow_mismatch {
+        quote! {
+            type_.name() == #name
         }
+    } else {
+        quote! {
+            if type_.name() != #name {
+                return false;
+            }
 
-        match *type_.kind() {
-            ::postgres_types::Kind::Enum(ref variants) => {
-                if variants.len() != #num_variants {
-                    return false;
-                }
-
-                variants.iter().all(|v| {
-                    match &**v {
-                        #(
-                            #variant_names => true,
-                        )*
-                        _ => false,
+            match *type_.kind() {
+                ::postgres_types::Kind::Enum(ref variants) => {
+                    if variants.len() != #num_variants {
+                        return false;
                     }
-                })
+
+                    variants.iter().all(|v| {
+                        match &**v {
+                            #(
+                                #variant_names => true,
+                            )*
+                            _ => false,
+                        }
+                    })
+                }
+                _ => false,
             }
-            _ => false,
         }
     }
 }
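Concretely, the two quote! branches above generate accepts bodies roughly equivalent to the following (an illustrative sketch for a hypothetical type named "mood" with two labels, not the literal macro output):

// Relaxed check generated under #[postgres(allow_mismatch)]:
// only the type name has to match.
fn accepts_mismatch(type_: &postgres_types::Type) -> bool {
    type_.name() == "mood"
}

// Strict check otherwise: the kind must be an enum whose variant
// set matches the Rust enum exactly.
fn accepts_strict(type_: &postgres_types::Type) -> bool {
    if type_.name() != "mood" {
        return false;
    }
    match *type_.kind() {
        postgres_types::Kind::Enum(ref variants) => {
            variants.len() == 2
                && variants
                    .iter()
                    .all(|v| matches!(&**v, "very_sad" | "very_happy"))
        }
        _ => false,
    }
}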
diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs
new file mode 100644
index 000000000..20ecc8eed
--- /dev/null
+++ b/postgres-derive/src/case.rs
@@ -0,0 +1,110 @@
+#[allow(deprecated, unused_imports)]
+use std::ascii::AsciiExt;
+
+use heck::{
+    ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTrainCase,
+    ToUpperCamelCase,
+};
+
+use self::RenameRule::*;
+
+/// The different possible ways to change case of fields in a struct, or variants in an enum.
+#[allow(clippy::enum_variant_names)]
+#[derive(Copy, Clone, PartialEq)]
+pub enum RenameRule {
+    /// Rename direct children to "lowercase" style.
+    LowerCase,
+    /// Rename direct children to "UPPERCASE" style.
+    UpperCase,
+    /// Rename direct children to "PascalCase" style, as typically used for
+    /// enum variants.
+    PascalCase,
+    /// Rename direct children to "camelCase" style.
+    CamelCase,
+    /// Rename direct children to "snake_case" style, as commonly used for
+    /// fields.
+    SnakeCase,
+    /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly
+    /// used for constants.
+    ScreamingSnakeCase,
+    /// Rename direct children to "kebab-case" style.
+    KebabCase,
+    /// Rename direct children to "SCREAMING-KEBAB-CASE" style.
+    ScreamingKebabCase,
+
+    /// Rename direct children to "Train-Case" style.
+    TrainCase,
+}
+
+pub const RENAME_RULES: &[&str] = &[
+    "lowercase",
+    "UPPERCASE",
+    "PascalCase",
+    "camelCase",
+    "snake_case",
+    "SCREAMING_SNAKE_CASE",
+    "kebab-case",
+    "SCREAMING-KEBAB-CASE",
+    "Train-Case",
+];
+
+impl RenameRule {
+    pub fn from_str(rule: &str) -> Option<RenameRule> {
+        match rule {
+            "lowercase" => Some(LowerCase),
+            "UPPERCASE" => Some(UpperCase),
+            "PascalCase" => Some(PascalCase),
+            "camelCase" => Some(CamelCase),
+            "snake_case" => Some(SnakeCase),
+            "SCREAMING_SNAKE_CASE" => Some(ScreamingSnakeCase),
+            "kebab-case" => Some(KebabCase),
+            "SCREAMING-KEBAB-CASE" => Some(ScreamingKebabCase),
+            "Train-Case" => Some(TrainCase),
+            _ => None,
+        }
+    }
+
+    /// Apply a renaming rule to an enum or struct field, returning the version expected in the source.
+    pub fn apply_to_field(&self, variant: &str) -> String {
+        match *self {
+            LowerCase => variant.to_lowercase(),
+            UpperCase => variant.to_uppercase(),
+            PascalCase => variant.to_upper_camel_case(),
+            CamelCase => variant.to_lower_camel_case(),
+            SnakeCase => variant.to_snake_case(),
+            ScreamingSnakeCase => variant.to_shouty_snake_case(),
+            KebabCase => variant.to_kebab_case(),
+            ScreamingKebabCase => variant.to_shouty_kebab_case(),
+            TrainCase => variant.to_train_case(),
+        }
+    }
+}
+
+#[test]
+fn rename_field() {
+    for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[
+        (
+            "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME",
+        ),
+        (
+            "VeryTasty",
+            "verytasty",
+            "VERYTASTY",
+            "veryTasty",
+            "very_tasty",
+            "VERY_TASTY",
+            "very-tasty",
+            "VERY-TASTY",
+        ),
+        ("A", "a", "A", "a", "a", "A", "a", "A"),
+        ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"),
+    ] {
+        assert_eq!(LowerCase.apply_to_field(original), lower);
+        assert_eq!(UpperCase.apply_to_field(original), upper);
+        assert_eq!(PascalCase.apply_to_field(original), original);
+        assert_eq!(CamelCase.apply_to_field(original), camel);
+        assert_eq!(SnakeCase.apply_to_field(original), snake);
+        assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming);
+        assert_eq!(KebabCase.apply_to_field(original), kebab);
+        assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab);
+    }
+}
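A quick sketch of this crate-internal API in action, using names from the test table above (the conversions themselves are delegated to the heck crate):

// Sketch only: RenameRule is internal to postgres-derive.
let rule = RenameRule::from_str("SCREAMING_SNAKE_CASE").expect("known rule");
assert_eq!(rule.apply_to_field("VeryTasty"), "VERY_TASTY");
assert_eq!(RenameRule::KebabCase.apply_to_field("VeryTasty"), "very-tasty");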
diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs
index 15bfabc13..b6aad8ab3 100644
--- a/postgres-derive/src/composites.rs
+++ b/postgres-derive/src/composites.rs
@@ -4,7 +4,7 @@ use syn::{
     TypeParamBound,
 };
 
-use crate::overrides::Overrides;
+use crate::{case::RenameRule, overrides::Overrides};
 
 pub struct Field {
     pub name: String,
@@ -13,18 +13,26 @@ pub struct Field {
 }
 
 impl Field {
-    pub fn parse(raw: &syn::Field) -> Result<Field, Error> {
-        let overrides = Overrides::extract(&raw.attrs)?;
-
+    pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<Field, Error> {
+        let overrides = Overrides::extract(&raw.attrs, false)?;
         let ident = raw.ident.as_ref().unwrap().clone();
-        Ok(Field {
-            name: overrides.name.unwrap_or_else(|| {
+
+        // field level name override takes precedence over container level rename_all override
+        let name = match overrides.name {
+            Some(n) => n,
+            None => {
                 let name = ident.to_string();
-                match name.strip_prefix("r#") {
-                    Some(name) => name.to_string(),
-                    None => name,
+                let stripped = name.strip_prefix("r#").map(String::from).unwrap_or(name);
+
+                match rename_all {
+                    Some(rule) => rule.apply_to_field(&stripped),
+                    None => stripped,
                 }
-            }),
+            }
+        };
+
+        Ok(Field {
+            name,
             ident,
             type_: raw.ty.clone(),
         })
diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs
index 3c6bc7113..9a6dfa926 100644
--- a/postgres-derive/src/enums.rs
+++ b/postgres-derive/src/enums.rs
@@ -1,6 +1,6 @@
 use syn::{Error, Fields, Ident};
 
-use crate::overrides::Overrides;
+use crate::{case::RenameRule, overrides::Overrides};
 
 pub struct Variant {
     pub ident: Ident,
@@ -8,7 +8,7 @@ pub struct Variant {
 }
 
 impl Variant {
-    pub fn parse(raw: &syn::Variant) -> Result<Variant, Error> {
+    pub fn parse(raw: &syn::Variant, rename_all: Option<RenameRule>) -> Result<Variant, Error> {
         match raw.fields {
             Fields::Unit => {}
             _ => {
@@ -18,11 +18,16 @@ impl Variant {
                 ))
             }
         }
+        let overrides = Overrides::extract(&raw.attrs, false)?;
 
-        let overrides = Overrides::extract(&raw.attrs)?;
+        // variant level name override takes precedence over container level rename_all override
+        let name = overrides.name.unwrap_or_else(|| match rename_all {
+            Some(rule) => rule.apply_to_field(&raw.ident.to_string()),
+            None => raw.ident.to_string(),
+        });
 
         Ok(Variant {
             ident: raw.ident.clone(),
-            name: overrides.name.unwrap_or_else(|| raw.ident.to_string()),
+            name,
         })
     }
 }
diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs
index bb87ded5f..d3ac47f4f 100644
--- a/postgres-derive/src/fromsql.rs
+++ b/postgres-derive/src/fromsql.rs
@@ -15,16 +15,19 @@ use crate::enums::Variant;
 use crate::overrides::Overrides;
 
 pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
-    let overrides = Overrides::extract(&input.attrs)?;
+    let overrides = Overrides::extract(&input.attrs, true)?;
 
-    if overrides.name.is_some() && overrides.transparent {
+    if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
         return Err(Error::new_spanned(
             &input,
-            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]",
+            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
         ));
     }
 
-    let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
+    let name = overrides
+        .name
+        .clone()
+        .unwrap_or_else(|| input.ident.to_string());
 
     let (accepts_body, to_sql_body) = if overrides.transparent {
         match input.data {
@@ -45,16 +48,36 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
                 ))
             }
         }
+    } else if overrides.allow_mismatch {
+        match input.data {
+            Data::Enum(ref data) => {
+                let variants = data
+                    .variants
+                    .iter()
+                    .map(|variant| Variant::parse(variant, overrides.rename_all))
+                    .collect::<Result<Vec<_>, _>>()?;
+                (
+                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
+                    enum_body(&input.ident, &variants),
+                )
+            }
+            _ => {
+                return Err(Error::new_spanned(
+                    input,
+                    "#[postgres(allow_mismatch)] may only be applied to enums",
+                ));
+            }
+        }
     } else {
         match input.data {
             Data::Enum(ref data) => {
                 let variants = data
                     .variants
                     .iter()
-                    .map(Variant::parse)
+                    .map(|variant| Variant::parse(variant, overrides.rename_all))
                     .collect::<Result<Vec<_>, _>>()?;
                 (
-                    accepts::enum_body(&name, &variants),
+                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
                     enum_body(&input.ident, &variants),
                 )
             }
@@ -75,7 +98,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
             let fields = fields
                 .named
                 .iter()
-                .map(Field::parse)
+                .map(|field| Field::parse(field, overrides.rename_all))
                 .collect::<Result<Vec<_>, _>>()?;
             (
                 accepts::composite_body(&name, "FromSql", &fields),
diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs
index 98e6add24..b849096c9 100644
--- a/postgres-derive/src/lib.rs
+++ b/postgres-derive/src/lib.rs
@@ -7,6 +7,7 @@ use proc_macro::TokenStream;
 use syn::parse_macro_input;
 
 mod accepts;
+mod case;
 mod composites;
 mod enums;
 mod fromsql;
diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs
index ddb37688b..d50550bee 100644
--- a/postgres-derive/src/overrides.rs
+++ b/postgres-derive/src/overrides.rs
@@ -1,16 +1,22 @@
 use syn::punctuated::Punctuated;
 use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token};
 
+use crate::case::{RenameRule, RENAME_RULES};
+
 pub struct Overrides {
     pub name: Option<String>,
+    pub rename_all: Option<RenameRule>,
     pub transparent: bool,
+    pub allow_mismatch: bool,
 }
 
 impl Overrides {
-    pub fn extract(attrs: &[Attribute]) -> Result<Overrides, Error> {
+    pub fn extract(attrs: &[Attribute], container_attr: bool) -> Result<Overrides, Error> {
         let mut overrides = Overrides {
             name: None,
+            rename_all: None,
             transparent: false,
+            allow_mismatch: false,
         };
 
         for attr in attrs {
@@ -28,7 +34,15 @@ impl Overrides {
             for item in nested {
                 match item {
                     Meta::NameValue(meta) => {
-                        if !meta.path.is_ident("name") {
+                        let name_override = meta.path.is_ident("name");
+                        let rename_all_override = meta.path.is_ident("rename_all");
+                        if !container_attr && rename_all_override {
+                            return Err(Error::new_spanned(
+                                &meta.path,
+                                "rename_all is a container attribute",
+                            ));
+                        }
+                        if !name_override && !rename_all_override {
                             return Err(Error::new_spanned(&meta.path, "unknown override"));
                         }
 
@@ -41,14 +55,46 @@ impl Overrides {
                             }
                         };
 
-                        overrides.name = Some(value);
+                        if name_override {
+                            overrides.name = Some(value);
+                        } else if rename_all_override {
+                            let rename_rule = RenameRule::from_str(&value).ok_or_else(|| {
+                                Error::new_spanned(
+                                    &meta.value,
+                                    format!(
+                                        "invalid rename_all rule, expected one of: {}",
+                                        RENAME_RULES
+                                            .iter()
+                                            .map(|rule| format!("\"{}\"", rule))
+                                            .collect::<Vec<_>>()
+                                            .join(", ")
+                                    ),
+                                )
+                            })?;
+
+                            overrides.rename_all = Some(rename_rule);
+                        }
                     }
                     Meta::Path(path) => {
-                        if !path.is_ident("transparent") {
+                        if path.is_ident("transparent") {
+                            if overrides.allow_mismatch {
+                                return Err(Error::new_spanned(
+                                    path,
+                                    "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]",
+                                ));
+                            }
+                            overrides.transparent = true;
+                        } else if path.is_ident("allow_mismatch") {
+                            if overrides.transparent {
+                                return Err(Error::new_spanned(
+                                    path,
+                                    "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]",
+                                ));
+                            }
+                            overrides.allow_mismatch = true;
+                        } else {
                             return Err(Error::new_spanned(path, "unknown override"));
                         }
-
-                        overrides.transparent = true;
                     }
                     bad => return Err(Error::new_spanned(bad, "unknown attribute")),
                 }
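With this validation in place, attribute combinations like the ones below are rejected at compile time; a sketch mirroring the compile-fail test earlier in this patch (the error strings are the ones constructed above):

// Sketch only: each of these now fails to compile.
use postgres_types::FromSql;

#[derive(FromSql)]
#[postgres(transparent, allow_mismatch)] // error: mutually exclusive options
struct Wrapper(i32);

#[derive(FromSql)]
struct Composite {
    #[postgres(rename_all = "snake_case")] // error: rename_all is a container attribute
    field: i32,
}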
diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs
index e51acc7fd..81d4834bf 100644
--- a/postgres-derive/src/tosql.rs
+++ b/postgres-derive/src/tosql.rs
@@ -13,16 +13,19 @@ use crate::enums::Variant;
 use crate::overrides::Overrides;
 
 pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
-    let overrides = Overrides::extract(&input.attrs)?;
+    let overrides = Overrides::extract(&input.attrs, true)?;
 
-    if overrides.name.is_some() && overrides.transparent {
+    if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent {
         return Err(Error::new_spanned(
             &input,
-            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]",
+            "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]",
         ));
     }
 
-    let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
+    let name = overrides
+        .name
+        .clone()
+        .unwrap_or_else(|| input.ident.to_string());
 
     let (accepts_body, to_sql_body) = if overrides.transparent {
         match input.data {
@@ -41,16 +44,36 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
                 ));
             }
         }
+    } else if overrides.allow_mismatch {
+        match input.data {
+            Data::Enum(ref data) => {
+                let variants = data
+                    .variants
+                    .iter()
+                    .map(|variant| Variant::parse(variant, overrides.rename_all))
+                    .collect::<Result<Vec<_>, _>>()?;
+                (
+                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
+                    enum_body(&input.ident, &variants),
+                )
+            }
+            _ => {
+                return Err(Error::new_spanned(
+                    input,
+                    "#[postgres(allow_mismatch)] may only be applied to enums",
+                ));
+            }
+        }
     } else {
         match input.data {
             Data::Enum(ref data) => {
                 let variants = data
                     .variants
                     .iter()
-                    .map(Variant::parse)
+                    .map(|variant| Variant::parse(variant, overrides.rename_all))
                     .collect::<Result<Vec<_>, _>>()?;
                 (
-                    accepts::enum_body(&name, &variants),
+                    accepts::enum_body(&name, &variants, overrides.allow_mismatch),
                     enum_body(&input.ident, &variants),
                 )
             }
@@ -69,7 +92,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
             let fields = fields
                 .named
                 .iter()
-                .map(Field::parse)
+                .map(|field| Field::parse(field, overrides.rename_all))
                 .collect::<Result<Vec<_>, _>>()?;
             (
                 accepts::composite_body(&name, "ToSql", &fields),
diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml
index 1f2f6385d..02259b3dc 100644
--- a/postgres-native-tls/Cargo.toml
+++ b/postgres-native-tls/Cargo.toml
@@ -3,7 +3,7 @@ name = "postgres-native-tls"
 version = "0.5.0"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
 edition = "2018"
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 description = "TLS support for tokio-postgres via native-tls"
 repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
@@ -19,9 +19,9 @@ runtime = ["tokio-postgres/runtime"]
 native-tls = "0.2"
 tokio = "1.0"
 tokio-native-tls = "0.3"
-tokio-postgres = { version = "0.7.0", path = "../tokio-postgres", default-features = false }
+tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false }
 
 [dev-dependencies]
 futures-util = "0.3"
 tokio = { version = "1.0", features = ["macros", "net", "rt"] }
-postgres = { version = "0.19.0", path = "../postgres" }
+postgres = { version = "0.19.8", path = "../postgres" }
diff --git a/postgres-openssl/Cargo.toml b/postgres-openssl/Cargo.toml
index 8671308af..9013384a2 100644
--- a/postgres-openssl/Cargo.toml
+++ b/postgres-openssl/Cargo.toml
@@ -3,7 +3,7 @@ name = "postgres-openssl"
 version = "0.5.0"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
 edition = "2018"
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 description = "TLS support for tokio-postgres via openssl"
 repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
@@ -19,9 +19,9 @@ runtime = ["tokio-postgres/runtime"]
 openssl = "0.10"
 tokio = "1.0"
 tokio-openssl = "0.6"
-tokio-postgres = { version = "0.7.0", path = "../tokio-postgres", default-features = false }
+tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false }
 
 [dev-dependencies]
 futures-util = "0.3"
 tokio = { version = "1.0", features = ["macros", "net", "rt"] }
-postgres = { version = "0.19.0", path = "../postgres" }
+postgres = { version = "0.19.8", path = "../postgres" }
diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md
index 034fd637c..54dce91b0 100644
--- a/postgres-protocol/CHANGELOG.md
+++ b/postgres-protocol/CHANGELOG.md
@@ -1,5 +1,26 @@
 # Change Log
 
+## v0.6.7 - 2024-07-21
+
+### Deprecated
+
+* Deprecated `ErrorField::value`.
+
+### Added
+
+* Added a `Clone` implementation for `DataRowBody`.
+* Added `ErrorField::value_bytes`.
+
+### Changed
+
+* Upgraded `base64`.
+
+## v0.6.6 - 2023-08-19
+
+### Added
+
+* Added the `js` feature for WASM support.
+
 ## v0.6.5 - 2023-03-27
 
 ### Added
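A migration sketch for the ErrorField::value deprecation noted above: error-field payloads are not guaranteed to be UTF-8 (the deprecated accessor panics on invalid bytes, as the code below shows), so the raw-bytes accessor is the safer entry point:

// Sketch only: prefer value_bytes() over the deprecated value().
use postgres_protocol::message::backend::ErrorField;

fn field_as_str<'a>(field: &'a ErrorField<'a>) -> Option<&'a str> {
    std::str::from_utf8(field.value_bytes()).ok()
}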
diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml
index e32211369..49cf2d59c 100644
--- a/postgres-protocol/Cargo.toml
+++ b/postgres-protocol/Cargo.toml
@@ -1,15 +1,19 @@
 [package]
 name = "postgres-protocol"
-version = "0.6.5"
+version = "0.6.7"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
 edition = "2018"
 description = "Low level Postgres protocol APIs"
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
 
+[features]
+default = []
+js = ["getrandom/js"]
+
 [dependencies]
-base64 = "0.21"
+base64 = "0.22"
 byteorder = "1.0"
 bytes = "1.0"
 fallible-iterator = "0.2"
@@ -19,3 +23,4 @@ memchr = "2.0"
 rand = "0.8"
 sha2 = "0.10"
 stringprep = "0.1"
+getrandom = { version = "0.2", optional = true }
diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs
index 8b6ff508d..e0de3b6c6 100644
--- a/postgres-protocol/src/lib.rs
+++ b/postgres-protocol/src/lib.rs
@@ -9,7 +9,6 @@
 //!
 //! This library assumes that the `client_encoding` backend parameter has been
 //! set to `UTF8`. It will most likely not behave properly if that is not the case.
-#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
 #![warn(missing_docs, rust_2018_idioms, clippy::all)]
 
 use byteorder::{BigEndian, ByteOrder};
@@ -61,7 +60,7 @@ macro_rules! from_usize {
         impl FromUsize for $t {
             #[inline]
             fn from_usize(x: usize) -> io::Result<$t> {
-                if x > <$t>::max_value() as usize {
+                if x > <$t>::MAX as usize {
                     Err(io::Error::new(
                         io::ErrorKind::InvalidInput,
                         "value too large to transmit",
diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs
index fcfcfd260..fdc83fedb 100644
--- a/postgres-protocol/src/message/backend.rs
+++ b/postgres-protocol/src/message/backend.rs
@@ -9,9 +9,8 @@ use std::io::{self, Read};
 use std::ops::Range;
 use std::str;
 
-use crate::{Lsn, Oid};
+use crate::Oid;
 
-// top-level message tags
 pub const PARSE_COMPLETE_TAG: u8 = b'1';
 pub const BIND_COMPLETE_TAG: u8 = b'2';
 pub const CLOSE_COMPLETE_TAG: u8 = b'3';
@@ -35,33 +34,6 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't';
 pub const ROW_DESCRIPTION_TAG: u8 = b'T';
 pub const READY_FOR_QUERY_TAG: u8 = b'Z';
 
-// replication message tags
-pub const XLOG_DATA_TAG: u8 = b'w';
-pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k';
-
-// logical replication message tags
-const BEGIN_TAG: u8 = b'B';
-const COMMIT_TAG: u8 = b'C';
-const ORIGIN_TAG: u8 = b'O';
-const RELATION_TAG: u8 = b'R';
-const TYPE_TAG: u8 = b'Y';
-const INSERT_TAG: u8 = b'I';
-const UPDATE_TAG: u8 = b'U';
-const DELETE_TAG: u8 = b'D';
-const TRUNCATE_TAG: u8 = b'T';
-const TUPLE_NEW_TAG: u8 = b'N';
-const TUPLE_KEY_TAG: u8 = b'K';
-const TUPLE_OLD_TAG: u8 = b'O';
-const TUPLE_DATA_NULL_TAG: u8 = b'n';
-const TUPLE_DATA_TOAST_TAG: u8 = b'u';
-const TUPLE_DATA_TEXT_TAG: u8 = b't';
-
-// replica identity tags
-const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd';
-const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n';
-const REPLICA_IDENTITY_FULL_TAG: u8 = b'f';
-const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i';
-
 #[derive(Debug, Copy, Clone)]
 pub struct Header {
     tag: u8,
@@ -318,59 +290,6 @@ impl Message {
     }
 }
 
-/// An enum representing Postgres backend replication messages.
-#[non_exhaustive] -#[derive(Debug)] -pub enum ReplicationMessage { - XLogData(XLogDataBody), - PrimaryKeepAlive(PrimaryKeepAliveBody), -} - -impl ReplicationMessage { - #[inline] - pub fn parse(buf: &Bytes) -> io::Result { - let mut buf = Buffer { - bytes: buf.clone(), - idx: 0, - }; - - let tag = buf.read_u8()?; - - let replication_message = match tag { - XLOG_DATA_TAG => { - let wal_start = buf.read_u64::()?; - let wal_end = buf.read_u64::()?; - let timestamp = buf.read_i64::()?; - let data = buf.read_all(); - ReplicationMessage::XLogData(XLogDataBody { - wal_start, - wal_end, - timestamp, - data, - }) - } - PRIMARY_KEEPALIVE_TAG => { - let wal_end = buf.read_u64::()?; - let timestamp = buf.read_i64::()?; - let reply = buf.read_u8()?; - ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody { - wal_end, - timestamp, - reply, - }) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(replication_message) - } -} - struct Buffer { bytes: Bytes, idx: usize, @@ -617,10 +536,11 @@ impl CopyOutResponseBody { } } +#[derive(Debug, Clone)] pub struct CopyBothResponseBody { - storage: Bytes, - len: u16, format: u8, + len: u16, + storage: Bytes, } impl CopyBothResponseBody { @@ -638,7 +558,7 @@ impl CopyBothResponseBody { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, len: u16, @@ -747,7 +667,7 @@ impl<'a> FallibleIterator for ErrorFields<'a> { } let value_end = find_null(self.buf, 0)?; - let value = get_str(&self.buf[..value_end])?; + let value = &self.buf[..value_end]; self.buf = &self.buf[value_end + 1..]; Ok(Some(ErrorField { type_, value })) @@ -756,7 +676,7 @@ impl<'a> FallibleIterator for ErrorFields<'a> { pub struct ErrorField<'a> { type_: u8, - value: &'a str, + value: &'a [u8], } impl<'a> ErrorField<'a> { @@ -766,7 +686,13 @@ impl<'a> ErrorField<'a> { } #[inline] + #[deprecated(note = "use value_bytes instead", since = "0.6.7")] pub fn value(&self) -> &str { + str::from_utf8(self.value).expect("error field value contained non-UTF8 bytes") + } + + #[inline] + pub fn value_bytes(&self) -> &[u8] { self.value } } @@ -896,655 +822,6 @@ impl RowDescriptionBody { } } -#[derive(Debug)] -pub struct XLogDataBody { - wal_start: u64, - wal_end: u64, - timestamp: i64, - data: D, -} - -impl XLogDataBody { - #[inline] - pub fn wal_start(&self) -> u64 { - self.wal_start - } - - #[inline] - pub fn wal_end(&self) -> u64 { - self.wal_end - } - - #[inline] - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - pub fn data(&self) -> &D { - &self.data - } - - #[inline] - pub fn into_data(self) -> D { - self.data - } - - pub fn map_data(self, f: F) -> Result, E> - where - F: Fn(D) -> Result, - { - let data = f(self.data)?; - Ok(XLogDataBody { - wal_start: self.wal_start, - wal_end: self.wal_end, - timestamp: self.timestamp, - data, - }) - } -} - -#[derive(Debug)] -pub struct PrimaryKeepAliveBody { - wal_end: u64, - timestamp: i64, - reply: u8, -} - -impl PrimaryKeepAliveBody { - #[inline] - pub fn wal_end(&self) -> u64 { - self.wal_end - } - - #[inline] - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - pub fn reply(&self) -> u8 { - self.reply - } -} - -#[non_exhaustive] -/// A message of the logical replication stream -#[derive(Debug)] -pub enum LogicalReplicationMessage { - /// A BEGIN statement - Begin(BeginBody), - /// A BEGIN statement - Commit(CommitBody), - /// An Origin replication message - /// Note that there can be 
multiple Origin messages inside a single transaction. - Origin(OriginBody), - /// A Relation replication message - Relation(RelationBody), - /// A Type replication message - Type(TypeBody), - /// An INSERT statement - Insert(InsertBody), - /// An UPDATE statement - Update(UpdateBody), - /// A DELETE statement - Delete(DeleteBody), - /// A TRUNCATE statement - Truncate(TruncateBody), -} - -impl LogicalReplicationMessage { - pub fn parse(buf: &Bytes) -> io::Result { - let mut buf = Buffer { - bytes: buf.clone(), - idx: 0, - }; - - let tag = buf.read_u8()?; - - let logical_replication_message = match tag { - BEGIN_TAG => Self::Begin(BeginBody { - final_lsn: buf.read_u64::()?, - timestamp: buf.read_i64::()?, - xid: buf.read_u32::()?, - }), - COMMIT_TAG => Self::Commit(CommitBody { - flags: buf.read_i8()?, - commit_lsn: buf.read_u64::()?, - end_lsn: buf.read_u64::()?, - timestamp: buf.read_i64::()?, - }), - ORIGIN_TAG => Self::Origin(OriginBody { - commit_lsn: buf.read_u64::()?, - name: buf.read_cstr()?, - }), - RELATION_TAG => { - let rel_id = buf.read_u32::()?; - let namespace = buf.read_cstr()?; - let name = buf.read_cstr()?; - let replica_identity = match buf.read_u8()? { - REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, - REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, - REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, - REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replica identity tag `{}`", tag), - )); - } - }; - let column_len = buf.read_i16::()?; - - let mut columns = Vec::with_capacity(column_len as usize); - for _ in 0..column_len { - columns.push(Column::parse(&mut buf)?); - } - - Self::Relation(RelationBody { - rel_id, - namespace, - name, - replica_identity, - columns, - }) - } - TYPE_TAG => Self::Type(TypeBody { - id: buf.read_u32::()?, - namespace: buf.read_cstr()?, - name: buf.read_cstr()?, - }), - INSERT_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let tuple = match tag { - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unexpected tuple tag `{}`", tag), - )); - } - }; - - Self::Insert(InsertBody { rel_id, tuple }) - } - UPDATE_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let mut key_tuple = None; - let mut old_tuple = None; - - let new_tuple = match tag { - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - TUPLE_OLD_TAG | TUPLE_KEY_TAG => { - if tag == TUPLE_OLD_TAG { - old_tuple = Some(Tuple::parse(&mut buf)?); - } else { - key_tuple = Some(Tuple::parse(&mut buf)?); - } - - match buf.read_u8()? 
{ - TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unexpected tuple tag `{}`", tag), - )); - } - } - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown tuple tag `{}`", tag), - )); - } - }; - - Self::Update(UpdateBody { - rel_id, - key_tuple, - old_tuple, - new_tuple, - }) - } - DELETE_TAG => { - let rel_id = buf.read_u32::()?; - let tag = buf.read_u8()?; - - let mut key_tuple = None; - let mut old_tuple = None; - - match tag { - TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), - TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown tuple tag `{}`", tag), - )); - } - } - - Self::Delete(DeleteBody { - rel_id, - key_tuple, - old_tuple, - }) - } - TRUNCATE_TAG => { - let relation_len = buf.read_i32::()?; - let options = buf.read_i8()?; - - let mut rel_ids = Vec::with_capacity(relation_len as usize); - for _ in 0..relation_len { - rel_ids.push(buf.read_u32::()?); - } - - Self::Truncate(TruncateBody { options, rel_ids }) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(logical_replication_message) - } -} - -/// A row as it appears in the replication stream -#[derive(Debug)] -pub struct Tuple(Vec); - -impl Tuple { - #[inline] - /// The tuple data of this tuple - pub fn tuple_data(&self) -> &[TupleData] { - &self.0 - } -} - -impl Tuple { - fn parse(buf: &mut Buffer) -> io::Result { - let col_len = buf.read_i16::()?; - let mut tuple = Vec::with_capacity(col_len as usize); - for _ in 0..col_len { - tuple.push(TupleData::parse(buf)?); - } - - Ok(Tuple(tuple)) - } -} - -/// A column as it appears in the replication stream -#[derive(Debug)] -pub struct Column { - flags: i8, - name: Bytes, - type_id: i32, - type_modifier: i32, -} - -impl Column { - #[inline] - /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as - /// part of the key. - pub fn flags(&self) -> i8 { - self.flags - } - - #[inline] - /// Name of the column. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } - - #[inline] - /// ID of the column's data type. - pub fn type_id(&self) -> i32 { - self.type_id - } - - #[inline] - /// Type modifier of the column (`atttypmod`). - pub fn type_modifier(&self) -> i32 { - self.type_modifier - } -} - -impl Column { - fn parse(buf: &mut Buffer) -> io::Result { - Ok(Self { - flags: buf.read_i8()?, - name: buf.read_cstr()?, - type_id: buf.read_i32::()?, - type_modifier: buf.read_i32::()?, - }) - } -} - -/// The data of an individual column as it appears in the replication stream -#[derive(Debug)] -pub enum TupleData { - /// Represents a NULL value - Null, - /// Represents an unchanged TOASTed value (the actual value is not sent). - UnchangedToast, - /// Column data as text formatted value. 
- Text(Bytes), -} - -impl TupleData { - fn parse(buf: &mut Buffer) -> io::Result { - let type_tag = buf.read_u8()?; - - let tuple = match type_tag { - TUPLE_DATA_NULL_TAG => TupleData::Null, - TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, - TUPLE_DATA_TEXT_TAG => { - let len = buf.read_i32::()?; - let mut data = vec![0; len as usize]; - buf.read_exact(&mut data)?; - TupleData::Text(data.into()) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown replication message tag `{}`", tag), - )); - } - }; - - Ok(tuple) - } -} - -/// A BEGIN statement -#[derive(Debug)] -pub struct BeginBody { - final_lsn: u64, - timestamp: i64, - xid: u32, -} - -impl BeginBody { - #[inline] - /// Gets the final lsn of the transaction - pub fn final_lsn(&self) -> Lsn { - self.final_lsn - } - - #[inline] - /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - /// Xid of the transaction. - pub fn xid(&self) -> u32 { - self.xid - } -} - -/// A COMMIT statement -#[derive(Debug)] -pub struct CommitBody { - flags: i8, - commit_lsn: u64, - end_lsn: u64, - timestamp: i64, -} - -impl CommitBody { - #[inline] - /// The LSN of the commit. - pub fn commit_lsn(&self) -> Lsn { - self.commit_lsn - } - - #[inline] - /// The end LSN of the transaction. - pub fn end_lsn(&self) -> Lsn { - self.end_lsn - } - - #[inline] - /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). - pub fn timestamp(&self) -> i64 { - self.timestamp - } - - #[inline] - /// Flags; currently unused (will be 0). - pub fn flags(&self) -> i8 { - self.flags - } -} - -/// An Origin replication message -/// -/// Note that there can be multiple Origin messages inside a single transaction. -#[derive(Debug)] -pub struct OriginBody { - commit_lsn: u64, - name: Bytes, -} - -impl OriginBody { - #[inline] - /// The LSN of the commit on the origin server. - pub fn commit_lsn(&self) -> Lsn { - self.commit_lsn - } - - #[inline] - /// Name of the origin. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } -} - -/// Describes the REPLICA IDENTITY setting of a table -#[derive(Debug)] -pub enum ReplicaIdentity { - /// default selection for replica identity (primary key or nothing) - Default, - /// no replica identity is logged for this relation - Nothing, - /// all columns are logged as replica identity - Full, - /// An explicitly chosen candidate key's columns are used as replica identity. - /// Note this will still be set if the index has been dropped; in that case it - /// has the same meaning as 'd'. - Index, -} - -/// A Relation replication message -#[derive(Debug)] -pub struct RelationBody { - rel_id: u32, - namespace: Bytes, - name: Bytes, - replica_identity: ReplicaIdentity, - columns: Vec, -} - -impl RelationBody { - #[inline] - /// ID of the relation. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// Namespace (empty string for pg_catalog). - pub fn namespace(&self) -> io::Result<&str> { - get_str(&self.namespace) - } - - #[inline] - /// Relation name. 
- pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } - - #[inline] - /// Replica identity setting for the relation - pub fn replica_identity(&self) -> &ReplicaIdentity { - &self.replica_identity - } - - #[inline] - /// The column definitions of this relation - pub fn columns(&self) -> &[Column] { - &self.columns - } -} - -/// A Type replication message -#[derive(Debug)] -pub struct TypeBody { - id: u32, - namespace: Bytes, - name: Bytes, -} - -impl TypeBody { - #[inline] - /// ID of the data type. - pub fn id(&self) -> Oid { - self.id - } - - #[inline] - /// Namespace (empty string for pg_catalog). - pub fn namespace(&self) -> io::Result<&str> { - get_str(&self.namespace) - } - - #[inline] - /// Name of the data type. - pub fn name(&self) -> io::Result<&str> { - get_str(&self.name) - } -} - -/// An INSERT statement -#[derive(Debug)] -pub struct InsertBody { - rel_id: u32, - tuple: Tuple, -} - -impl InsertBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// The inserted tuple - pub fn tuple(&self) -> &Tuple { - &self.tuple - } -} - -/// An UPDATE statement -#[derive(Debug)] -pub struct UpdateBody { - rel_id: u32, - old_tuple: Option, - key_tuple: Option, - new_tuple: Tuple, -} - -impl UpdateBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// This field is optional and is only present if the update changed data in any of the - /// column(s) that are part of the REPLICA IDENTITY index. - pub fn key_tuple(&self) -> Option<&Tuple> { - self.key_tuple.as_ref() - } - - #[inline] - /// This field is optional and is only present if table in which the update happened has - /// REPLICA IDENTITY set to FULL. - pub fn old_tuple(&self) -> Option<&Tuple> { - self.old_tuple.as_ref() - } - - #[inline] - /// The new tuple - pub fn new_tuple(&self) -> &Tuple { - &self.new_tuple - } -} - -/// A DELETE statement -#[derive(Debug)] -pub struct DeleteBody { - rel_id: u32, - old_tuple: Option, - key_tuple: Option, -} - -impl DeleteBody { - #[inline] - /// ID of the relation corresponding to the ID in the relation message. - pub fn rel_id(&self) -> u32 { - self.rel_id - } - - #[inline] - /// This field is present if the table in which the delete has happened uses an index as - /// REPLICA IDENTITY. - pub fn key_tuple(&self) -> Option<&Tuple> { - self.key_tuple.as_ref() - } - - #[inline] - /// This field is present if the table in which the delete has happened has REPLICA IDENTITY - /// set to FULL. 
- pub fn old_tuple(&self) -> Option<&Tuple> { - self.old_tuple.as_ref() - } -} - -/// A TRUNCATE statement -#[derive(Debug)] -pub struct TruncateBody { - options: i8, - rel_ids: Vec, -} - -impl TruncateBody { - #[inline] - /// The IDs of the relations corresponding to the ID in the relation messages - pub fn rel_ids(&self) -> &[u32] { - &self.rel_ids - } - - #[inline] - /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY - pub fn options(&self) -> i8 { - self.options - } -} - pub struct Fields<'a> { buf: &'a [u8], remaining: u16, diff --git a/postgres-replication/Cargo.toml b/postgres-replication/Cargo.toml new file mode 100644 index 000000000..f24cd6ccf --- /dev/null +++ b/postgres-replication/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "postgres-replication" +version = "0.6.7" +authors = ["Petros Angelatos "] +edition = "2018" +description = "Protocol definitions for the Postgres logical replication protocol" +license = "MIT OR Apache-2.0" +repository = "https://github.com/sfackler/rust-postgres" +readme = "../README.md" + +[features] +default = [] + +[dependencies] +bytes = "1.0" +memchr = "2.0" +byteorder = "1.0" +postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } +postgres-types = { version = "0.2.7", path = "../postgres-types" } +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", features = ["runtime"] } +futures-util = { version = "0.3", features = ["sink"] } +pin-project-lite = "0.2" + +[dev-dependencies] +tokio = { version = "1.0", features = [ + "macros", + "net", + "rt", + "rt-multi-thread", + "time", +] } diff --git a/postgres-replication/LICENSE-APACHE b/postgres-replication/LICENSE-APACHE new file mode 100644 index 000000000..16fe87b06 --- /dev/null +++ b/postgres-replication/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/postgres-replication/LICENSE-MIT b/postgres-replication/LICENSE-MIT new file mode 100644 index 000000000..71803aea1 --- /dev/null +++ b/postgres-replication/LICENSE-MIT @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Steven Fackler + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/tokio-postgres/src/replication.rs b/postgres-replication/src/lib.rs similarity index 95% rename from tokio-postgres/src/replication.rs rename to postgres-replication/src/lib.rs index 1b49dcc42..08d17d4b8 100644 --- a/tokio-postgres/src/replication.rs +++ b/postgres-replication/src/lib.rs @@ -1,15 +1,18 @@ //! Utilities for working with the PostgreSQL replication copy both format. -use crate::copy_both::CopyBothDuplex; -use crate::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; + use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{ready, SinkExt, Stream}; use pin_project_lite::pin_project; -use postgres_protocol::message::backend::LogicalReplicationMessage; -use postgres_protocol::message::backend::ReplicationMessage; use postgres_types::PgLsn; -use std::pin::Pin; -use std::task::{Context, Poll}; +use tokio_postgres::CopyBothDuplex; +use tokio_postgres::Error; + +pub mod protocol; + +use crate::protocol::{LogicalReplicationMessage, ReplicationMessage}; const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; @@ -165,7 +168,6 @@ impl Stream for LogicalReplicationStream { Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) } - Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))), Some(Err(err)) => Poll::Ready(Some(Err(err))), None => Poll::Ready(None), } diff --git a/postgres-replication/src/protocol.rs b/postgres-replication/src/protocol.rs new file mode 100644 index 000000000..d94825014 --- /dev/null +++ b/postgres-replication/src/protocol.rs @@ -0,0 +1,791 @@ +use std::io::{self, Read}; +use std::{cmp, str}; + +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::Bytes; +use memchr::memchr; +use postgres_protocol::{Lsn, Oid}; + +// replication message tags +pub const XLOG_DATA_TAG: u8 = b'w'; +pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; + +// logical replication message tags +const BEGIN_TAG: u8 = b'B'; +const COMMIT_TAG: u8 = b'C'; +const ORIGIN_TAG: u8 = b'O'; +const RELATION_TAG: u8 = b'R'; +const TYPE_TAG: u8 = b'Y'; +const INSERT_TAG: u8 = b'I'; +const UPDATE_TAG: u8 = b'U'; +const DELETE_TAG: u8 = b'D'; +const TRUNCATE_TAG: u8 = b'T'; +const TUPLE_NEW_TAG: u8 = b'N'; +const TUPLE_KEY_TAG: u8 = b'K'; +const TUPLE_OLD_TAG: u8 = b'O'; +const TUPLE_DATA_NULL_TAG: u8 = b'n'; +const TUPLE_DATA_TOAST_TAG: u8 = b'u'; +const TUPLE_DATA_TEXT_TAG: u8 = b't'; + +// replica identity tags +const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; +const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; +const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; +const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; + +/// An enum representing Postgres backend replication messages. 
+#[non_exhaustive]
+#[derive(Debug)]
+pub enum ReplicationMessage<D> {
+    XLogData(XLogDataBody<D>),
+    PrimaryKeepAlive(PrimaryKeepAliveBody),
+}
+
+impl ReplicationMessage<Bytes> {
+    #[inline]
+    pub fn parse(buf: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: buf.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let replication_message = match tag {
+            XLOG_DATA_TAG => {
+                let wal_start = buf.read_u64::<BigEndian>()?;
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let timestamp = buf.read_i64::<BigEndian>()?;
+                let data = buf.read_all();
+                ReplicationMessage::XLogData(XLogDataBody {
+                    wal_start,
+                    wal_end,
+                    timestamp,
+                    data,
+                })
+            }
+            PRIMARY_KEEPALIVE_TAG => {
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let timestamp = buf.read_i64::<BigEndian>()?;
+                let reply = buf.read_u8()?;
+                ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody {
+                    wal_end,
+                    timestamp,
+                    reply,
+                })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(replication_message)
+    }
+}
+
+#[derive(Debug)]
+pub struct XLogDataBody<D> {
+    wal_start: u64,
+    wal_end: u64,
+    timestamp: i64,
+    data: D,
+}
+
+impl<D> XLogDataBody<D> {
+    #[inline]
+    pub fn wal_start(&self) -> u64 {
+        self.wal_start
+    }
+
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn data(&self) -> &D {
+        &self.data
+    }
+
+    #[inline]
+    pub fn into_data(self) -> D {
+        self.data
+    }
+
+    pub fn map_data<F, D2, E>(self, f: F) -> Result<XLogDataBody<D2>, E>
+    where
+        F: Fn(D) -> Result<D2, E>,
+    {
+        let data = f(self.data)?;
+        Ok(XLogDataBody {
+            wal_start: self.wal_start,
+            wal_end: self.wal_end,
+            timestamp: self.timestamp,
+            data,
+        })
+    }
+}
+
+#[derive(Debug)]
+pub struct PrimaryKeepAliveBody {
+    wal_end: u64,
+    timestamp: i64,
+    reply: u8,
+}
+
+impl PrimaryKeepAliveBody {
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn reply(&self) -> u8 {
+        self.reply
+    }
+}
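An editorial aside, not part of the patch: a minimal sketch of how the two wire-level variants above are decoded, assuming the `postgres-replication` crate as introduced here. The frame bytes are constructed by hand purely for illustration:

```rust
use bytes::Bytes;
use postgres_replication::protocol::ReplicationMessage;

fn main() -> std::io::Result<()> {
    // Build a PrimaryKeepAlive frame by hand: tag 'k', then wal_end (u64 BE),
    // timestamp (i64 BE), and the reply-requested flag (u8).
    let mut raw = vec![b'k'];
    raw.extend_from_slice(&42u64.to_be_bytes());
    raw.extend_from_slice(&0i64.to_be_bytes());
    raw.push(1);

    match ReplicationMessage::parse(&Bytes::from(raw))? {
        ReplicationMessage::PrimaryKeepAlive(body) => {
            assert_eq!(body.wal_end(), 42);
            assert_eq!(body.reply(), 1); // server asked for a status update
        }
        ReplicationMessage::XLogData(body) => {
            println!("XLogData starting at {}", body.wal_start());
        }
        _ => {} // the enum is #[non_exhaustive]
    }
    Ok(())
}
```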
+#[non_exhaustive]
+/// A message of the logical replication stream
+#[derive(Debug)]
+pub enum LogicalReplicationMessage {
+    /// A BEGIN statement
+    Begin(BeginBody),
+    /// A COMMIT statement
+    Commit(CommitBody),
+    /// An Origin replication message
+    /// Note that there can be multiple Origin messages inside a single transaction.
+    Origin(OriginBody),
+    /// A Relation replication message
+    Relation(RelationBody),
+    /// A Type replication message
+    Type(TypeBody),
+    /// An INSERT statement
+    Insert(InsertBody),
+    /// An UPDATE statement
+    Update(UpdateBody),
+    /// A DELETE statement
+    Delete(DeleteBody),
+    /// A TRUNCATE statement
+    Truncate(TruncateBody),
+}
+
+impl LogicalReplicationMessage {
+    pub fn parse(buf: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: buf.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let logical_replication_message = match tag {
+            BEGIN_TAG => Self::Begin(BeginBody {
+                final_lsn: buf.read_u64::<BigEndian>()?,
+                timestamp: buf.read_i64::<BigEndian>()?,
+                xid: buf.read_u32::<BigEndian>()?,
+            }),
+            COMMIT_TAG => Self::Commit(CommitBody {
+                flags: buf.read_i8()?,
+                commit_lsn: buf.read_u64::<BigEndian>()?,
+                end_lsn: buf.read_u64::<BigEndian>()?,
+                timestamp: buf.read_i64::<BigEndian>()?,
+            }),
+            ORIGIN_TAG => Self::Origin(OriginBody {
+                commit_lsn: buf.read_u64::<BigEndian>()?,
+                name: buf.read_cstr()?,
+            }),
+            RELATION_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let namespace = buf.read_cstr()?;
+                let name = buf.read_cstr()?;
+                let replica_identity = match buf.read_u8()? {
+                    REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default,
+                    REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing,
+                    REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full,
+                    REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index,
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown replica identity tag `{}`", tag),
+                        ));
+                    }
+                };
+                let column_len = buf.read_i16::<BigEndian>()?;
+
+                let mut columns = Vec::with_capacity(column_len as usize);
+                for _ in 0..column_len {
+                    columns.push(Column::parse(&mut buf)?);
+                }
+
+                Self::Relation(RelationBody {
+                    rel_id,
+                    namespace,
+                    name,
+                    replica_identity,
+                    columns,
+                })
+            }
+            TYPE_TAG => Self::Type(TypeBody {
+                id: buf.read_u32::<BigEndian>()?,
+                namespace: buf.read_cstr()?,
+                name: buf.read_cstr()?,
+            }),
+            INSERT_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let tuple = match tag {
+                    TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unexpected tuple tag `{}`", tag),
+                        ));
+                    }
+                };
+
+                Self::Insert(InsertBody { rel_id, tuple })
+            }
+            UPDATE_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let mut key_tuple = None;
+                let mut old_tuple = None;
+
+                let new_tuple = match tag {
+                    TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                    TUPLE_OLD_TAG | TUPLE_KEY_TAG => {
+                        if tag == TUPLE_OLD_TAG {
+                            old_tuple = Some(Tuple::parse(&mut buf)?);
+                        } else {
+                            key_tuple = Some(Tuple::parse(&mut buf)?);
+                        }
+
+                        match buf.read_u8()? {
+                            TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                            tag => {
+                                return Err(io::Error::new(
+                                    io::ErrorKind::InvalidInput,
+                                    format!("unexpected tuple tag `{}`", tag),
+                                ));
+                            }
+                        }
+                    }
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown tuple tag `{}`", tag),
+                        ));
+                    }
+                };
+
+                Self::Update(UpdateBody {
+                    rel_id,
+                    key_tuple,
+                    old_tuple,
+                    new_tuple,
+                })
+            }
+            DELETE_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let mut key_tuple = None;
+                let mut old_tuple = None;
+
+                match tag {
+                    TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?),
+                    TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?),
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown tuple tag `{}`", tag),
+                        ));
+                    }
+                }
+
+                Self::Delete(DeleteBody {
+                    rel_id,
+                    key_tuple,
+                    old_tuple,
+                })
+            }
+            TRUNCATE_TAG => {
+                let relation_len = buf.read_i32::<BigEndian>()?;
+                let options = buf.read_i8()?;
+
+                let mut rel_ids = Vec::with_capacity(relation_len as usize);
+                for _ in 0..relation_len {
+                    rel_ids.push(buf.read_u32::<BigEndian>()?);
+                }
+
+                Self::Truncate(TruncateBody { options, rel_ids })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(logical_replication_message)
+    }
+}
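For orientation, a hedged sketch of consuming the parsed variants; the match is deliberately partial and the printed handling is invented for illustration:

```rust
use postgres_replication::protocol::LogicalReplicationMessage;

// Illustrative only: route a parsed logical message.
fn apply(message: &LogicalReplicationMessage) {
    match message {
        LogicalReplicationMessage::Begin(body) => {
            println!("begin xid {}", body.xid());
        }
        LogicalReplicationMessage::Commit(body) => {
            println!("commit at lsn {}", body.commit_lsn());
        }
        LogicalReplicationMessage::Insert(body) => {
            println!("insert into relation {}", body.rel_id());
        }
        // Relation, Type, Update, Delete, Truncate, plus any future
        // variants: the enum is #[non_exhaustive].
        _ => {}
    }
}
```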
+/// A row as it appears in the replication stream
+#[derive(Debug)]
+pub struct Tuple(Vec<TupleData>);
+
+impl Tuple {
+    #[inline]
+    /// The tuple data of this tuple
+    pub fn tuple_data(&self) -> &[TupleData] {
+        &self.0
+    }
+}
+
+impl Tuple {
+    fn parse(buf: &mut Buffer) -> io::Result<Tuple> {
+        let col_len = buf.read_i16::<BigEndian>()?;
+        let mut tuple = Vec::with_capacity(col_len as usize);
+        for _ in 0..col_len {
+            tuple.push(TupleData::parse(buf)?);
+        }
+
+        Ok(Tuple(tuple))
+    }
+}
+
+/// A column as it appears in the replication stream
+#[derive(Debug)]
+pub struct Column {
+    flags: i8,
+    name: Bytes,
+    type_id: i32,
+    type_modifier: i32,
+}
+
+impl Column {
+    #[inline]
+    /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column
+    /// as part of the key.
+    pub fn flags(&self) -> i8 {
+        self.flags
+    }
+
+    #[inline]
+    /// Name of the column.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+
+    #[inline]
+    /// ID of the column's data type.
+    pub fn type_id(&self) -> i32 {
+        self.type_id
+    }
+
+    #[inline]
+    /// Type modifier of the column (`atttypmod`).
+    pub fn type_modifier(&self) -> i32 {
+        self.type_modifier
+    }
+}
+
+impl Column {
+    fn parse(buf: &mut Buffer) -> io::Result<Column> {
+        Ok(Self {
+            flags: buf.read_i8()?,
+            name: buf.read_cstr()?,
+            type_id: buf.read_i32::<BigEndian>()?,
+            type_modifier: buf.read_i32::<BigEndian>()?,
+        })
+    }
+}
+
+/// The data of an individual column as it appears in the replication stream
+#[derive(Debug)]
+pub enum TupleData {
+    /// Represents a NULL value
+    Null,
+    /// Represents an unchanged TOASTed value (the actual value is not sent).
+    UnchangedToast,
+    /// Column data as text formatted value.
+    Text(Bytes),
+}
+
+impl TupleData {
+    fn parse(buf: &mut Buffer) -> io::Result<TupleData> {
+        let type_tag = buf.read_u8()?;
+
+        let tuple = match type_tag {
+            TUPLE_DATA_NULL_TAG => TupleData::Null,
+            TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast,
+            TUPLE_DATA_TEXT_TAG => {
+                let len = buf.read_i32::<BigEndian>()?;
+                let mut data = vec![0; len as usize];
+                buf.read_exact(&mut data)?;
+                TupleData::Text(data.into())
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(tuple)
+    }
+}
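Again as an aside, a hedged sketch of how `TupleData` is typically consumed downstream; the `render` helper is hypothetical, not part of this crate. Column values arrive text-formatted, so rendering them is a UTF-8 decode:

```rust
use postgres_replication::protocol::TupleData;

/// Render one replicated column value for logging; `None` means NULL or
/// an unchanged TOAST value that the server did not resend.
fn render(data: &TupleData) -> Option<String> {
    match data {
        TupleData::Null => None,
        TupleData::UnchangedToast => None,
        TupleData::Text(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()),
    }
}
```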
+/// A BEGIN statement
+#[derive(Debug)]
+pub struct BeginBody {
+    final_lsn: u64,
+    timestamp: i64,
+    xid: u32,
+}
+
+impl BeginBody {
+    #[inline]
+    /// Gets the final lsn of the transaction
+    pub fn final_lsn(&self) -> Lsn {
+        self.final_lsn
+    }
+
+    #[inline]
+    /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01).
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    /// Xid of the transaction.
+    pub fn xid(&self) -> u32 {
+        self.xid
+    }
+}
+
+/// A COMMIT statement
+#[derive(Debug)]
+pub struct CommitBody {
+    flags: i8,
+    commit_lsn: u64,
+    end_lsn: u64,
+    timestamp: i64,
+}
+
+impl CommitBody {
+    #[inline]
+    /// The LSN of the commit.
+    pub fn commit_lsn(&self) -> Lsn {
+        self.commit_lsn
+    }
+
+    #[inline]
+    /// The end LSN of the transaction.
+    pub fn end_lsn(&self) -> Lsn {
+        self.end_lsn
+    }
+
+    #[inline]
+    /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01).
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    /// Flags; currently unused (will be 0).
+    pub fn flags(&self) -> i8 {
+        self.flags
+    }
+}
+
+/// An Origin replication message
+///
+/// Note that there can be multiple Origin messages inside a single transaction.
+#[derive(Debug)]
+pub struct OriginBody {
+    commit_lsn: u64,
+    name: Bytes,
+}
+
+impl OriginBody {
+    #[inline]
+    /// The LSN of the commit on the origin server.
+    pub fn commit_lsn(&self) -> Lsn {
+        self.commit_lsn
+    }
+
+    #[inline]
+    /// Name of the origin.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+}
+
+/// Describes the REPLICA IDENTITY setting of a table
+#[derive(Debug)]
+pub enum ReplicaIdentity {
+    /// default selection for replica identity (primary key or nothing)
+    Default,
+    /// no replica identity is logged for this relation
+    Nothing,
+    /// all columns are logged as replica identity
+    Full,
+    /// An explicitly chosen candidate key's columns are used as replica identity.
+    /// Note this will still be set if the index has been dropped; in that case it
+    /// has the same meaning as 'd'.
+    Index,
+}
+
+/// A Relation replication message
+#[derive(Debug)]
+pub struct RelationBody {
+    rel_id: u32,
+    namespace: Bytes,
+    name: Bytes,
+    replica_identity: ReplicaIdentity,
+    columns: Vec<Column>,
+}
+
+impl RelationBody {
+    #[inline]
+    /// ID of the relation.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// Namespace (empty string for pg_catalog).
+    pub fn namespace(&self) -> io::Result<&str> {
+        get_str(&self.namespace)
+    }
+
+    #[inline]
+    /// Relation name.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+
+    #[inline]
+    /// Replica identity setting for the relation
+    pub fn replica_identity(&self) -> &ReplicaIdentity {
+        &self.replica_identity
+    }
+
+    #[inline]
+    /// The column definitions of this relation
+    pub fn columns(&self) -> &[Column] {
+        &self.columns
+    }
+}
+
+/// A Type replication message
+#[derive(Debug)]
+pub struct TypeBody {
+    id: u32,
+    namespace: Bytes,
+    name: Bytes,
+}
+
+impl TypeBody {
+    #[inline]
+    /// ID of the data type.
+    pub fn id(&self) -> Oid {
+        self.id
+    }
+
+    #[inline]
+    /// Namespace (empty string for pg_catalog).
+    pub fn namespace(&self) -> io::Result<&str> {
+        get_str(&self.namespace)
+    }
+
+    #[inline]
+    /// Name of the data type.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+}
+
+/// An INSERT statement
+#[derive(Debug)]
+pub struct InsertBody {
+    rel_id: u32,
+    tuple: Tuple,
+}
+
+impl InsertBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// The inserted tuple
+    pub fn tuple(&self) -> &Tuple {
+        &self.tuple
+    }
+}
+
+/// An UPDATE statement
+#[derive(Debug)]
+pub struct UpdateBody {
+    rel_id: u32,
+    old_tuple: Option<Tuple>,
+    key_tuple: Option<Tuple>,
+    new_tuple: Tuple,
+}
+
+impl UpdateBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// This field is optional and is only present if the update changed data in any of the
+    /// column(s) that are part of the REPLICA IDENTITY index.
+    pub fn key_tuple(&self) -> Option<&Tuple> {
+        self.key_tuple.as_ref()
+    }
+
+    #[inline]
+    /// This field is optional and is only present if the table in which the update happened has
+    /// REPLICA IDENTITY set to FULL.
+    pub fn old_tuple(&self) -> Option<&Tuple> {
+        self.old_tuple.as_ref()
+    }
+
+    #[inline]
+    /// The new tuple
+    pub fn new_tuple(&self) -> &Tuple {
+        &self.new_tuple
+    }
+}
+
+/// A DELETE statement
+#[derive(Debug)]
+pub struct DeleteBody {
+    rel_id: u32,
+    old_tuple: Option<Tuple>,
+    key_tuple: Option<Tuple>,
+}
+
+impl DeleteBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// This field is present if the table in which the delete has happened uses an index as
+    /// REPLICA IDENTITY.
+    pub fn key_tuple(&self) -> Option<&Tuple> {
+        self.key_tuple.as_ref()
+    }
+
+    #[inline]
+    /// This field is present if the table in which the delete has happened has REPLICA IDENTITY
+    /// set to FULL.
+    pub fn old_tuple(&self) -> Option<&Tuple> {
+        self.old_tuple.as_ref()
+    }
+}
+
+/// A TRUNCATE statement
+#[derive(Debug)]
+pub struct TruncateBody {
+    options: i8,
+    rel_ids: Vec<u32>,
+}
+
+impl TruncateBody {
+    #[inline]
+    /// The IDs of the relations corresponding to the ID in the relation messages
+    pub fn rel_ids(&self) -> &[u32] {
+        &self.rel_ids
+    }
+
+    #[inline]
+    /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY
+    pub fn options(&self) -> i8 {
+        self.options
+    }
+}
+
+struct Buffer {
+    bytes: Bytes,
+    idx: usize,
+}
+
+impl Buffer {
+    #[inline]
+    fn slice(&self) -> &[u8] {
+        &self.bytes[self.idx..]
+    }
+
+    #[inline]
+    fn read_cstr(&mut self) -> io::Result<Bytes> {
+        match memchr(0, self.slice()) {
+            Some(pos) => {
+                let start = self.idx;
+                let end = start + pos;
+                let cstr = self.bytes.slice(start..end);
+                self.idx = end + 1;
+                Ok(cstr)
+            }
+            None => Err(io::Error::new(
+                io::ErrorKind::UnexpectedEof,
+                "unexpected EOF",
+            )),
+        }
+    }
+
+    #[inline]
+    fn read_all(&mut self) -> Bytes {
+        let buf = self.bytes.slice(self.idx..);
+        self.idx = self.bytes.len();
+        buf
+    }
+}
+
+impl Read for Buffer {
+    #[inline]
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        let len = {
+            let slice = self.slice();
+            let len = cmp::min(slice.len(), buf.len());
+            buf[..len].copy_from_slice(&slice[..len]);
+            len
+        };
+        self.idx += len;
+        Ok(len)
+    }
+}
+
+#[inline]
+fn get_str(buf: &[u8]) -> io::Result<&str> {
+    str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
+}
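The two layers above compose: an `XLogData` frame from a logical slot carries a serialized logical replication message in its `data` field, and `map_data` is the hook for re-parsing it. A sketch under that assumption (the `decode_logical` function is illustrative only, with the `frame` input assumed to be one copy-both frame):

```rust
use bytes::Bytes;
use postgres_replication::protocol::{
    LogicalReplicationMessage, ReplicationMessage, XLogDataBody,
};

fn decode_logical(
    frame: &Bytes,
) -> std::io::Result<Option<XLogDataBody<LogicalReplicationMessage>>> {
    match ReplicationMessage::parse(frame)? {
        // The XLogData payload of a logical slot is itself a tagged
        // logical replication message; map_data re-parses it in place.
        ReplicationMessage::XLogData(body) => {
            Ok(Some(body.map_data(|buf| LogicalReplicationMessage::parse(&buf))?))
        }
        // Keepalives (and any future variants) carry no logical payload.
        _ => Ok(None),
    }
}
```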
diff --git a/tokio-postgres/tests/test/replication.rs b/postgres-replication/tests/replication.rs
similarity index 92%
rename from tokio-postgres/tests/test/replication.rs
rename to postgres-replication/tests/replication.rs
index 44aae3f22..49700ef8c 100644
--- a/tokio-postgres/tests/test/replication.rs
+++ b/postgres-replication/tests/replication.rs
@@ -1,11 +1,12 @@
-use futures_util::StreamExt;
 use std::time::{Duration, UNIX_EPOCH};
 
-use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert};
-use postgres_protocol::message::backend::ReplicationMessage::*;
-use postgres_protocol::message::backend::TupleData;
+use futures_util::StreamExt;
+
+use postgres_replication::protocol::LogicalReplicationMessage::{Begin, Commit, Insert};
+use postgres_replication::protocol::ReplicationMessage::*;
+use postgres_replication::protocol::TupleData;
+use postgres_replication::LogicalReplicationStream;
 use postgres_types::PgLsn;
-use tokio_postgres::replication::LogicalReplicationStream;
 use tokio_postgres::NoTls;
 use tokio_postgres::SimpleQueryMessage::Row;
@@ -32,7 +33,7 @@ async fn test_replication() {
         .simple_query("SELECT 'test_logical_replication'::regclass::oid")
         .await
         .unwrap();
-    let rel_id: u32 = if let Row(row) = &res[0] {
+    let rel_id: u32 = if let Row(row) = &res[1] {
         row.get("oid").unwrap().parse().unwrap()
     } else {
         panic!("unexpected query message");
@@ -54,7 +55,7 @@ async fn test_replication() {
         slot
     );
     let slot_query = client.simple_query(&query).await.unwrap();
-    let lsn = if let Row(row) = &slot_query[0] {
+    let lsn = if let Row(row) = &slot_query[1] {
         row.get("consistent_point").unwrap()
     } else {
         panic!("unexpected query message");
diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md
index 0f42f3495..1e5cd31d8 100644
---
a/postgres-types/CHANGELOG.md
+++ b/postgres-types/CHANGELOG.md
@@ -1,14 +1,39 @@
 # Change Log
 
+## Unreleased
+
+## v0.2.7 - 2024-07-21
+
+### Added
+
+* Added `Default` implementation for `Json`.
+* Added a `js` feature for WASM compatibility.
+
+### Changed
+
+* `FromStr` implementation for `PgLsn` no longer allocates a `Vec` when splitting an LSN string on its `/`.
+* The `eui48-1` feature no longer enables default features of the `eui48` library.
+
+## v0.2.6 - 2023-08-19
+
+### Fixed
+
+* Fixed serialization to `OIDVECTOR` and `INT2VECTOR`.
+
+### Added
+
+* Removed the `'static` requirement for the `impl BorrowToSql for Box<dyn ToSql>`.
+* Added a `ToSql` implementation for `Cow<[u8]>`.
+
 ## v0.2.5 - 2023-03-27
 
-## Added
+### Added
 
 * Added support for multi-range types.
 
 ## v0.2.4 - 2022-08-20
 
-## Added
+### Added
 
 * Added `ToSql` and `FromSql` implementations for `Box<[T]>`.
 * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature.
diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml
index 35cdd6e7b..e2d21b358 100644
--- a/postgres-types/Cargo.toml
+++ b/postgres-types/Cargo.toml
@@ -1,9 +1,9 @@
 [package]
 name = "postgres-types"
-version = "0.2.5"
+version = "0.2.7"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
 edition = "2018"
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 description = "Conversions between Rust and Postgres values"
 repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
@@ -13,6 +13,7 @@ categories = ["database"]
 [features]
 derive = ["postgres-derive"]
 array-impls = ["array-init"]
+js = ["postgres-protocol/js"]
 with-bit-vec-0_6 = ["bit-vec-06"]
 with-cidr-0_2 = ["cidr-02"]
 with-chrono-0_4 = ["chrono-04"]
@@ -30,8 +31,8 @@ with-time-0_3 = ["time-03"]
 [dependencies]
 bytes = "1.0"
 fallible-iterator = "0.2"
-postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" }
-postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" }
+postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" }
+postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" }
 
 array-init = { version = "2", optional = true }
 bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true }
@@ -39,8 +40,10 @@ chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, features = [
     "clock",
 ], optional = true }
 cidr-02 = { version = "0.2", package = "cidr", optional = true }
+# eui48-04 will stop compiling and support will be removed
+# See https://github.com/sfackler/rust-postgres/issues/1073
 eui48-04 = { version = "0.4", package = "eui48", optional = true }
-eui48-1 = { version = "1.0", package = "eui48", optional = true }
+eui48-1 = { version = "1.0", package = "eui48", optional = true, default-features = false }
 geo-types-06 = { version = "0.6", package = "geo-types", optional = true }
 geo-types-0_7 = { version = "0.7", package = "geo-types", optional = true }
 serde-1 = { version = "1.0", package = "serde", optional = true }
diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs
index b7f4f9a03..6b6406232 100644
--- a/postgres-types/src/chrono_04.rs
+++ b/postgres-types/src/chrono_04.rs
@@ -1,5 +1,7 @@
 use bytes::BytesMut;
-use chrono_04::{DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, Utc};
+use chrono_04::{
+    DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc,
+};
 use postgres_protocol::types;
 use std::error::Error;
 
@@ -40,7 +42,7 @@ impl ToSql for NaiveDateTime {
 impl<'a> FromSql<'a> for DateTime<Utc> {
     fn from_sql(type_: &Type, raw: &[u8]) -> Result<DateTime<Utc>, Box<dyn Error + Sync + Send>> {
         let naive = NaiveDateTime::from_sql(type_, raw)?;
-        Ok(DateTime::from_naive_utc_and_offset(naive, Utc))
+        Ok(Utc.from_utc_datetime(&naive))
     }
 
     accepts!(TIMESTAMPTZ);
diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs
index fa49d99eb..492039766 100644
--- a/postgres-types/src/lib.rs
+++ b/postgres-types/src/lib.rs
@@ -125,9 +125,63 @@
 //!     Happy,
 //! }
 //! ```
-#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
+//!
+//! Alternatively, the `#[postgres(rename_all = "...")]` attribute can be used to rename all fields or variants
+//! with the chosen casing convention. This will not affect the struct or enum's type name. Note that
+//! `#[postgres(name = "...")]` takes precedence when used in conjunction with `#[postgres(rename_all = "...")]`:
+//!
+//! ```rust
+//! # #[cfg(feature = "derive")]
+//! use postgres_types::{ToSql, FromSql};
+//!
+//! # #[cfg(feature = "derive")]
+//! #[derive(Debug, ToSql, FromSql)]
+//! #[postgres(name = "mood", rename_all = "snake_case")]
+//! enum Mood {
+//!     #[postgres(name = "ok")]
+//!     Ok, // ok
+//!     VeryHappy, // very_happy
+//! }
+//! ```
+//!
+//! The following case conventions are supported:
+//! - `"lowercase"`
+//! - `"UPPERCASE"`
+//! - `"PascalCase"`
+//! - `"camelCase"`
+//! - `"snake_case"`
+//! - `"SCREAMING_SNAKE_CASE"`
+//! - `"kebab-case"`
+//! - `"SCREAMING-KEBAB-CASE"`
+//! - `"Train-Case"`
+//!
+//! ## Allowing Enum Mismatches
+//!
+//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum
+//! variants between the Rust and Postgres types.
+//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition:
+//!
+//! ```sql
+//! CREATE TYPE mood AS ENUM (
+//!   'Sad',
+//!   'Ok',
+//!   'Happy'
+//! );
+//! ```
+//!
+//! ```rust
+//! # #[cfg(feature = "derive")]
+//! use postgres_types::{ToSql, FromSql};
+//!
+//! # #[cfg(feature = "derive")]
+//! #[derive(Debug, ToSql, FromSql)]
+//! #[postgres(allow_mismatch)]
+//! enum Mood {
+//!     Happy,
+//!     Meh,
+//! }
+//! ```
 #![warn(clippy::all, rust_2018_idioms, missing_docs)]
-
 use fallible_iterator::FallibleIterator;
 use postgres_protocol::types::{self, ArrayDimension};
 use std::any::type_name;
@@ -910,9 +964,15 @@ impl<'a, T: ToSql> ToSql for &'a [T] {
             _ => panic!("expected array type"),
         };
 
+        // Arrays are normally one indexed by default but oidvector and int2vector *require* zero indexing
+        let lower_bound = match *ty {
+            Type::OID_VECTOR | Type::INT2_VECTOR => 0,
+            _ => 1,
+        };
+
         let dimension = ArrayDimension {
             len: downcast(self.len())?,
-            lower_bound: 1,
+            lower_bound,
         };
 
         types::array_to_sql(
@@ -998,6 +1058,18 @@ impl<T: ToSql> ToSql for Box<[T]> {
     to_sql_checked!();
 }
 
+impl<'a> ToSql for Cow<'a, [u8]> {
+    fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
+        <&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w)
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        <&[u8] as ToSql>::accepts(ty)
+    }
+
+    to_sql_checked!();
+}
+
 impl ToSql for Vec<u8> {
     fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
         <&[u8] as ToSql>::to_sql(&&**self, ty, w)
@@ -1012,28 +1084,20 @@ impl ToSql for Vec<u8> {
 
 impl<'a> ToSql for &'a str {
     fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
-        match *ty {
-            ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
-            ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
-            ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
+        match ty.name() {
+            "ltree" => types::ltree_to_sql(self, w),
+            "lquery" => types::lquery_to_sql(self, w),
+            "ltxtquery" => types::ltxtquery_to_sql(self, w),
             _ => types::text_to_sql(self, w),
         }
 
         Ok(IsNull::No)
     }
 
     fn accepts(ty: &Type) -> bool {
-        match *ty {
-            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
-            ref ty
-                if (ty.name() == "citext"
-                    || ty.name() == "ltree"
-                    || ty.name() == "lquery"
-                    || ty.name() == "ltxtquery") =>
-            {
-                true
-            }
-            _ => false,
-        }
+        matches!(
+            *ty,
+            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN
+        ) || matches!(ty.name(), "citext" | "ltree" | "lquery" | "ltxtquery")
     }
 
     to_sql_checked!();
@@ -1158,7 +1222,7 @@ impl ToSql for IpAddr {
 }
 
 fn downcast(len: usize) -> Result<i32, Box<dyn Error + Sync + Send>> {
-    if len > i32::max_value() as usize {
+    if len > i32::MAX as usize {
         Err("value too large to transmit".into())
     } else {
         Ok(len as i32)
@@ -1186,17 +1250,17 @@ impl BorrowToSql for &dyn ToSql {
     }
 }
 
-impl sealed::Sealed for Box<dyn ToSql + Sync> {}
+impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + 'a> {}
 
-impl BorrowToSql for Box<dyn ToSql + Sync> {
+impl<'a> BorrowToSql for Box<dyn ToSql + Sync + 'a> {
     #[inline]
     fn borrow_to_sql(&self) -> &dyn ToSql {
         self.as_ref()
     }
 }
 
-impl sealed::Sealed for Box<dyn ToSql + Sync + Send> {}
-impl BorrowToSql for Box<dyn ToSql + Sync + Send> {
+impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + Send + 'a> {}
+impl<'a> BorrowToSql for Box<dyn ToSql + Sync + Send + 'a> {
     #[inline]
     fn borrow_to_sql(&self) -> &dyn ToSql {
         self.as_ref()
diff --git a/postgres-types/src/pg_lsn.rs b/postgres-types/src/pg_lsn.rs
index f0bbf4022..f339f9689 100644
--- a/postgres-types/src/pg_lsn.rs
+++ b/postgres-types/src/pg_lsn.rs
@@ -33,16 +33,14 @@ impl FromStr for PgLsn {
     type Err = ParseLsnError;
 
     fn from_str(lsn_str: &str) -> Result<Self, Self::Err> {
-        let split: Vec<&str> = lsn_str.split('/').collect();
-        if split.len() == 2 {
-            let (hi, lo) = (
-                u64::from_str_radix(split[0], 16).map_err(|_| ParseLsnError(()))?,
-                u64::from_str_radix(split[1], 16).map_err(|_| ParseLsnError(()))?,
-            );
-            Ok(PgLsn((hi << 32) | lo))
-        } else {
-            Err(ParseLsnError(()))
-        }
+        let Some((split_hi, split_lo)) = lsn_str.split_once('/') else {
+            return Err(ParseLsnError(()));
+        };
+        let (hi, lo) = (
+            u64::from_str_radix(split_hi, 16).map_err(|_| ParseLsnError(()))?,
+            u64::from_str_radix(split_lo, 16).map_err(|_| ParseLsnError(()))?,
+        );
+        Ok(PgLsn((hi << 32) | lo))
     }
 }
 
diff --git a/postgres-types/src/serde_json_1.rs b/postgres-types/src/serde_json_1.rs
index b98d561d1..715c33f98 100644
--- a/postgres-types/src/serde_json_1.rs
+++ b/postgres-types/src/serde_json_1.rs
@@ -7,7 +7,7 @@ use std::fmt::Debug;
 use std::io::Read;
 
 /// A wrapper type to allow arbitrary `Serialize`/`Deserialize` types to convert to Postgres JSON values.
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Default, Debug, PartialEq, Eq)]
 pub struct Json<T>(pub T);
 
 impl<'a, T> FromSql<'a> for Json<T>
diff --git a/postgres-types/src/special.rs b/postgres-types/src/special.rs
index 1a865287e..d8541bf0e 100644
--- a/postgres-types/src/special.rs
+++ b/postgres-types/src/special.rs
@@ -1,7 +1,6 @@
 use bytes::BytesMut;
 use postgres_protocol::types;
 use std::error::Error;
-use std::{i32, i64};
 
 use crate::{FromSql, IsNull, ToSql, Type};
 
diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md
index b8263a04a..258cdb518 100644
--- a/postgres/CHANGELOG.md
+++ b/postgres/CHANGELOG.md
@@ -1,20 +1,40 @@
 # Change Log
 
+## v0.19.8 - 2024-07-21
+
+### Added
+
+* Added `{Client, Transaction, GenericClient}::query_typed`.
+
+## v0.19.7 - 2023-08-25
+
+### Fixed
+
+* Deferred default username lookup to avoid regressing `Config` behavior.
+
+## v0.19.6 - 2023-08-19
+
+### Added
+
+* Added support for the `hostaddr` config option to bypass DNS lookups.
+* Added support for the `load_balance_hosts` config option to randomize connection ordering.
+* The `user` config option now defaults to the executing process's user.
+
 ## v0.19.5 - 2023-03-27
 
-## Added
+### Added
 
 * Added `keepalives_interval` and `keepalives_retries` config options.
 * Added the `tcp_user_timeout` config option.
 * Added `RowIter::rows_affected`.
 
-## Changed
+### Changed
 
 * Passing an incorrect number of parameters to a query method now returns an error instead of panicking.
 
 ## v0.19.4 - 2022-08-21
 
-## Added
+### Added
 
 * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature.
 * Added support for `smol_str` 0.1 via the `with-smol_str-01` feature.
diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml
index e0b2a249d..ff95c4f14 100644
--- a/postgres/Cargo.toml
+++ b/postgres/Cargo.toml
@@ -1,9 +1,9 @@
 [package]
 name = "postgres"
-version = "0.19.5"
+version = "0.19.8"
 authors = ["Steven Fackler <sfackler@gmail.com>"]
 edition = "2018"
-license = "MIT/Apache-2.0"
+license = "MIT OR Apache-2.0"
 description = "A native, synchronous PostgreSQL client"
 repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
@@ -39,11 +39,9 @@ with-time-0_3 = ["tokio-postgres/with-time-0_3"]
 bytes = "1.0"
 fallible-iterator = "0.2"
 futures-util = { version = "0.3.14", features = ["sink"] }
-tokio-postgres = { version = "0.7.8", path = "../tokio-postgres" }
-
-tokio = { version = "1.0", features = ["rt", "time"] }
 log = "0.4"
+tokio-postgres = { version = "0.7.11", path = "../tokio-postgres" }
+tokio = { version = "1.0", features = ["rt", "time"] }
 
 [dev-dependencies]
-criterion = "0.4"
-tokio = { version = "1.0", features = ["rt-multi-thread"] }
+criterion = "0.5"
diff --git a/postgres/src/client.rs b/postgres/src/client.rs
index c8e14cf81..42ce6dec9 100644
--- a/postgres/src/client.rs
+++ b/postgres/src/client.rs
@@ -257,6 +257,71 @@ impl Client {
         Ok(RowIter::new(self.connection.as_ref(), stream))
     }
 
+    /// Like `query`, but requires the types of query parameters to be explicitly specified.
+    ///
+    /// Compared to `query`, this method allows performing queries without three round trips (for
+    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
+    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
+    /// supported (such as Cloudflare Workers with Hyperdrive).
+    ///
+    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
+    /// parameter of the list provided, 1-indexed.
+    pub fn query_typed(
+        &mut self,
+        query: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error> {
+        self.connection
+            .block_on(self.client.query_typed(query, params))
+    }
+
+    /// The maximally flexible version of [`query_typed`].
+    ///
+    /// Compared to `query`, this method allows performing queries without three round trips (for
+    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
+    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
+    /// supported (such as Cloudflare Workers with Hyperdrive).
+    ///
+    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
+    /// parameter of the list provided, 1-indexed.
+    ///
+    /// [`query_typed`]: #method.query_typed
+    ///
+    /// # Examples
+    /// ```no_run
+    /// # use postgres::{Client, NoTls};
+    /// use postgres::types::{ToSql, Type};
+    /// use fallible_iterator::FallibleIterator;
+    /// # fn main() -> Result<(), postgres::Error> {
+    /// # let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
+    ///
+    /// let params: Vec<(String, Type)> = vec![
+    ///     ("first param".into(), Type::TEXT),
+    ///     ("second param".into(), Type::TEXT),
+    /// ];
+    /// let mut it = client.query_typed_raw(
+    ///     "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
+    ///     params,
+    /// )?;
+    ///
+    /// while let Some(row) = it.next()? {
+    ///     let foo: i32 = row.get("foo");
+    ///     println!("foo: {}", foo);
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn query_typed_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter<'_>, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)>,
+    {
+        let stream = self
+            .connection
+            .block_on(self.client.query_typed_raw(query, params))?;
+        Ok(RowIter::new(self.connection.as_ref(), stream))
+    }
+
     /// Creates a new prepared statement.
     ///
     /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
diff --git a/postgres/src/config.rs b/postgres/src/config.rs
index 2705e3593..91ad3c904 100644
--- a/postgres/src/config.rs
+++ b/postgres/src/config.rs
@@ -1,6 +1,4 @@
 //! Connection configuration.
-//!
-//! Requires the `runtime` Cargo feature (enabled by default).
 
 use crate::connection::Connection;
 use crate::Client;
@@ -13,7 +11,9 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::runtime;
 #[doc(inline)]
-pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
+pub use tokio_postgres::config::{
+    ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
+};
 use tokio_postgres::error::DbError;
 use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
 use tokio_postgres::{Error, Socket};
@@ -29,7 +29,7 @@ use tokio_postgres::{Error, Socket};
 ///
 /// ## Keys
 ///
-/// * `user` - The username to authenticate with. Required.
+/// * `user` - The username to authenticate with. Defaults to the user executing this process.
 /// * `password` - The password to authenticate with.
/// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -47,9 +47,9 @@ use tokio_postgres::{Error, Socket}; /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, -/// - or if host specifies an IP address, that value will be used directly. +/// or if host specifies an IP address, that value will be used directly. /// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications -/// with time constraints. However, a host name is required for verify-full SSL certificate verification. +/// with time constraints. However, a host name is required for TLS certificate verification. /// Specifically: /// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. /// The connection attempt will fail if the authentication method requires a host name; @@ -76,6 +76,15 @@ use tokio_postgres::{Error, Socket}; /// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that /// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server /// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -84,7 +93,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// host=/var/run/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' /// ``` /// /// ```not_rust @@ -98,7 +107,7 @@ use tokio_postgres::{Error, Socket}; /// # Url /// /// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, -/// and the format accept query parameters for all of the key-value pairs described in the section above. Multiple +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple /// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, /// as the path component of the URL specifies the database name. 
 ///
@@ -109,7 +118,7 @@ use tokio_postgres::{Error, Socket};
 /// ```
 ///
 /// ```not_rust
-/// postgresql://user:password@%2Fvar%2Frun%2Fpostgresql/mydb?connect_timeout=10
+/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
 /// ```
 ///
 /// ```not_rust
@@ -117,7 +126,7 @@ use tokio_postgres::{Error, Socket};
 /// ```
 ///
 /// ```not_rust
-/// postgresql:///mydb?user=user&host=/var/run/postgresql
+/// postgresql:///mydb?user=user&host=/var/lib/postgresql
 /// ```
 #[derive(Clone)]
 pub struct Config {
@@ -147,7 +156,7 @@ impl Config {
 
     /// Sets the user to authenticate with.
     ///
-    /// Required.
+    /// If the user is not set, then this defaults to the user executing this process.
     pub fn user(&mut self, user: &str) -> &mut Config {
         self.config.user(user);
         self
@@ -435,6 +444,19 @@ impl Config {
         self.config.get_channel_binding()
     }
 
+    /// Sets the host load balancing behavior.
+    ///
+    /// Defaults to `disable`.
+    pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
+        self.config.load_balance_hosts(load_balance_hosts);
+        self
+    }
+
+    /// Gets the host load balancing behavior.
+    pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
+        self.config.get_load_balance_hosts()
+    }
+
     /// Sets the notice callback.
     ///
     /// This callback will be invoked with the contents of every
diff --git a/postgres/src/generic_client.rs b/postgres/src/generic_client.rs
index 12f07465d..7b534867c 100644
--- a/postgres/src/generic_client.rs
+++ b/postgres/src/generic_client.rs
@@ -44,6 +44,19 @@ pub trait GenericClient: private::Sealed {
         I: IntoIterator<Item = P>,
         I::IntoIter: ExactSizeIterator;
 
+    /// Like [`Client::query_typed`]
+    fn query_typed(
+        &mut self,
+        statement: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error>;
+
+    /// Like [`Client::query_typed_raw`]
+    fn query_typed_raw<P, I>(&mut self, statement: &str, params: I) -> Result<RowIter<'_>, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)> + Sync + Send;
+
     /// Like `Client::prepare`.
     fn prepare(&mut self, query: &str) -> Result<Statement, Error>;
 
@@ -115,6 +128,22 @@ impl GenericClient for Client {
         self.query_raw(query, params)
     }
 
+    fn query_typed(
+        &mut self,
+        statement: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error> {
+        self.query_typed(statement, params)
+    }
+
+    fn query_typed_raw<P, I>(&mut self, statement: &str, params: I) -> Result<RowIter<'_>, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)> + Sync + Send,
+    {
+        self.query_typed_raw(statement, params)
+    }
+
     fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
         self.prepare(query)
     }
@@ -195,6 +224,22 @@ impl GenericClient for Transaction<'_> {
         self.query_raw(query, params)
     }
 
+    fn query_typed(
+        &mut self,
+        statement: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error> {
+        self.query_typed(statement, params)
+    }
+
+    fn query_typed_raw<P, I>(&mut self, statement: &str, params: I) -> Result<RowIter<'_>, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)> + Sync + Send,
+    {
+        self.query_typed_raw(statement, params)
+    }
+
     fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
         self.prepare(query)
     }
diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs
index fbe85cbde..ddf1609ad 100644
--- a/postgres/src/lib.rs
+++ b/postgres/src/lib.rs
@@ -55,7 +55,7 @@
 //! | ------- | ----------- | ------------------ | ------- |
 //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no |
 //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no |
-//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no |
+//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no |
 //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no |
 //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no |
 //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no |
diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs
index 17c49c406..5c8c15973 100644
--- a/postgres/src/transaction.rs
+++ b/postgres/src/transaction.rs
@@ -115,6 +115,35 @@ impl<'a> Transaction<'a> {
         Ok(RowIter::new(self.connection.as_ref(), stream))
     }
 
+    /// Like `Client::query_typed`.
+    pub fn query_typed(
+        &mut self,
+        statement: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error> {
+        self.connection.block_on(
+            self.transaction
+                .as_ref()
+                .unwrap()
+                .query_typed(statement, params),
+        )
+    }
+
+    /// Like `Client::query_typed_raw`.
+    pub fn query_typed_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter<'_>, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)>,
+    {
+        let stream = self.connection.block_on(
+            self.transaction
+                .as_ref()
+                .unwrap()
+                .query_typed_raw(query, params),
+        )?;
+        Ok(RowIter::new(self.connection.as_ref(), stream))
+    }
+
     /// Binds parameters to a statement, creating a "portal".
     ///
     /// Portals can be used with the `query_portal` method to page through the results of a query without being forced
diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md
index 3345a1d43..e0be26296 100644
--- a/tokio-postgres/CHANGELOG.md
+++ b/tokio-postgres/CHANGELOG.md
@@ -1,6 +1,50 @@
 # Change Log
 
-## v0.7.8
+## Unreleased
+
+## v0.7.11 - 2024-07-21
+
+### Fixed
+
+* Fixed handling of non-UTF8 error fields which can be sent after failed handshakes.
+* Fixed cancellation handling of `TransactionBuilder::start` futures.
+
+### Added
+
+* Added `table_oid` and `field_id` fields to `Columns` struct of prepared statements.
+* Added `GenericClient::simple_query`.
+* Added `#[track_caller]` to `Row::get` and `SimpleQueryRow::get`.
+* Added `TargetSessionAttrs::ReadOnly`.
+* Added `Debug` implementation for `Statement`.
+* Added `Clone` implementation for `Row`.
+* Added `SimpleQueryMessage::RowDescription`.
+* Added `{Client, Transaction, GenericClient}::query_typed`.
+
+### Changed
+
+* Disable `rustc-serialize` compatibility of `eui48-1` dependency
+* Config setters now take `impl Into<String>`.
+
+## v0.7.10 - 2023-08-25
+
+### Fixed
+
+* Deferred default username lookup to avoid regressing `Config` behavior.
+
+## v0.7.9 - 2023-08-19
+
+### Fixed
+
+* Fixed builds on OpenBSD.
+
+### Added
+
+* Added the `js` feature for WASM support.
+* Added support for the `hostaddr` config option to bypass DNS lookups.
+* Added support for the `load_balance_hosts` config option to randomize connection ordering.
+* The `user` config option now defaults to the executing process's user.
+ +## v0.7.8 - 2023-05-27 ## Added diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 762caa9b0..f0e7fdb3e 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "tokio-postgres" -version = "0.7.8" +version = "0.7.11" authors = ["Steven Fackler "] edition = "2018" -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" description = "A native, asynchronous PostgreSQL client" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" @@ -40,6 +40,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] with-uuid-1 = ["postgres-types/with-uuid-1"] with-time-0_2 = ["postgres-types/with-time-0_2"] with-time-0_3 = ["postgres-types/with-time-0_3"] +js = ["postgres-protocol/js", "postgres-types/js"] [dependencies] async-trait = "0.1" @@ -53,18 +54,21 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } -postgres-types = { version = "0.2.4", path = "../postgres-types" } +postgres-protocol = { version = "0.6.7", path = "../postgres-protocol" } +postgres-types = { version = "0.2.7", path = "../postgres-types" } serde = { version = "1.0", optional = true } -socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" +whoami = "1.4.1" + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +socket2 = { version = "0.5", features = ["all"] } [dev-dependencies] futures-executor = "0.3" -criterion = "0.4" -env_logger = "0.10" +criterion = "0.5" +env_logger = "0.11" tokio = { version = "1.0", features = [ "macros", "net", @@ -75,8 +79,7 @@ tokio = { version = "1.0", features = [ bit-vec-06 = { version = "0.6", package = "bit-vec" } chrono-04 = { version = "0.4", package = "chrono", default-features = false } -eui48-04 = { version = "0.4", package = "eui48" } -eui48-1 = { version = "1.0", package = "eui48" } +eui48-1 = { version = "1.0", package = "eui48", default-features = false } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } serde_json-1 = { version = "1.0", package = "serde_json" } diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 2cc4256c4..e1d784607 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,7 +1,7 @@ use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; -use crate::copy_both::CopyBothDuplex; +use crate::copy_both::{CopyBothDuplex, CopyBothReceiver}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -21,9 +21,9 @@ use crate::{ use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; -use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; +use futures_util::{future, pin_mut, ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use postgres_protocol::message::{backend::Message, frontend}; +use postgres_protocol::message::backend::Message; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; @@ -31,6 +31,7 @@ use std::fmt; use std::net::IpAddr; #[cfg(feature = "runtime")] use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -42,6 +43,11 @@ pub struct Responses { 
     cur: BackendMessages,
 }
 
+pub struct CopyBothHandles {
+    pub(crate) stream_receiver: mpsc::Receiver<Result<Message, Error>>,
+    pub(crate) sink_sender: mpsc::Sender<FrontendMessage>,
+}
+
 impl Responses {
     pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
         loop {
@@ -63,6 +69,17 @@ impl Responses {
     }
 }
 
+impl Stream for Responses {
+    type Item = Result<Message, Error>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match ready!((*self).poll_next(cx)) {
+            Err(err) if err.is_closed() => Poll::Ready(None),
+            msg => Poll::Ready(Some(msg)),
+        }
+    }
+}
+
 /// A cache of type info and prepared statements for fetching type info
 /// (corresponding to the queries in the [prepare](prepare) module).
 #[derive(Default)]
@@ -105,6 +122,32 @@ impl InnerClient {
         })
     }
 
+    pub fn start_copy_both(&self) -> Result<CopyBothHandles, Error> {
+        let (sender, receiver) = mpsc::channel(1);
+        let (stream_sender, stream_receiver) = mpsc::channel(0);
+        let (sink_sender, sink_receiver) = mpsc::channel(0);
+
+        let responses = Responses {
+            receiver,
+            cur: BackendMessages::empty(),
+        };
+        let messages = RequestMessages::CopyBoth(CopyBothReceiver::new(
+            responses,
+            sink_receiver,
+            stream_sender,
+        ));
+
+        let request = Request { messages, sender };
+        self.sender
+            .unbounded_send(request)
+            .map_err(|_| Error::closed())?;
+
+        Ok(CopyBothHandles {
+            stream_receiver,
+            sink_sender,
+        })
+    }
+
     pub fn typeinfo(&self) -> Option<Statement> {
         self.cached_typeinfo.lock().typeinfo.clone()
     }
@@ -276,19 +319,9 @@ impl Client {
     where
         T: ?Sized + ToStatement,
     {
-        let stream = self.query_raw(statement, slice_iter(params)).await?;
-        pin_mut!(stream);
-
-        let row = match stream.try_next().await? {
-            Some(row) => row,
-            None => return Err(Error::row_count()),
-        };
-
-        if stream.try_next().await?.is_some() {
-            return Err(Error::row_count());
-        }
-
-        Ok(row)
+        self.query_opt(statement, params)
+            .await
+            .and_then(|res| res.ok_or_else(Error::row_count))
     }
 
     /// Executes a statements which returns zero or one rows, returning it.
@@ -312,16 +345,22 @@ impl Client {
         let stream = self.query_raw(statement, slice_iter(params)).await?;
         pin_mut!(stream);
 
-        let row = match stream.try_next().await? {
-            Some(row) => row,
-            None => return Ok(None),
-        };
+        let mut first = None;
+
+        // Originally this was two calls to `try_next().await?`,
+        // once for the first element, and second to error if more than one.
+        //
+        // However, this new form with only one .await in a loop generates
+        // slightly smaller codegen/stack usage for the resulting future.
+        while let Some(row) = stream.try_next().await? {
+            if first.is_some() {
+                return Err(Error::row_count());
+            }
 
-        if stream.try_next().await?.is_some() {
-            return Err(Error::row_count());
+            first = Some(row);
         }
 
-        Ok(Some(row))
+        Ok(first)
     }
 
     /// The maximally flexible version of [`query`].
@@ -339,7 +378,6 @@ impl Client {
     ///
     /// ```no_run
     /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
-    /// use tokio_postgres::types::ToSql;
    /// use futures_util::{pin_mut, TryStreamExt};
     ///
     /// let params: Vec<String> = vec![
@@ -370,6 +408,70 @@ impl Client {
         query::query(&self.inner, statement, params).await
     }
 
+    /// Like `query`, but requires the types of query parameters to be explicitly specified.
+    ///
+    /// Compared to `query`, this method allows performing queries without three round trips (for
+    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
+    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
+    /// supported (such as Cloudflare Workers with Hyperdrive).
+    ///
+    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
+    /// parameter of the list provided, 1-indexed.
+    pub async fn query_typed(
+        &self,
+        query: &str,
+        params: &[(&(dyn ToSql + Sync), Type)],
+    ) -> Result<Vec<Row>, Error> {
+        self.query_typed_raw(query, params.iter().map(|(v, t)| (*v, t.clone())))
+            .await?
+            .try_collect()
+            .await
+    }
+
+    /// The maximally flexible version of [`query_typed`].
+    ///
+    /// Compared to `query`, this method allows performing queries without three round trips (for
+    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
+    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
+    /// supported (such as Cloudflare Workers with Hyperdrive).
+    ///
+    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
+    /// parameter of the list provided, 1-indexed.
+    ///
+    /// [`query_typed`]: #method.query_typed
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
+    /// use futures_util::{pin_mut, TryStreamExt};
+    /// use tokio_postgres::types::Type;
+    ///
+    /// let params: Vec<(String, Type)> = vec![
+    ///     ("first param".into(), Type::TEXT),
+    ///     ("second param".into(), Type::TEXT),
+    /// ];
+    /// let mut it = client.query_typed_raw(
+    ///     "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
+    ///     params,
+    /// ).await?;
+    ///
+    /// pin_mut!(it);
+    /// while let Some(row) = it.try_next().await? {
+    ///     let foo: i32 = row.get("foo");
+    ///     println!("foo: {}", foo);
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
+    where
+        P: BorrowToSql,
+        I: IntoIterator<Item = (P, Type)>,
+    {
+        query::query_typed(&self.inner, query, params).await
+    }
+
     /// Executes a statement, returning the number of rows modified.
     ///
     /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
@@ -497,43 +599,7 @@ impl Client {
     ///
     /// The transaction will roll back by default - use the `commit` method to commit it.
     pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
-        struct RollbackIfNotDone<'me> {
-            client: &'me Client,
-            done: bool,
-        }
-
-        impl<'a> Drop for RollbackIfNotDone<'a> {
-            fn drop(&mut self) {
-                if self.done {
-                    return;
-                }
-
-                let buf = self.client.inner().with_buf(|buf| {
-                    frontend::query("ROLLBACK", buf).unwrap();
-                    buf.split().freeze()
-                });
-                let _ = self
-                    .client
-                    .inner()
-                    .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
-            }
-        }
-
-        // This is done, as `Future` created by this method can be dropped after
-        // `RequestMessages` is synchronously send to the `Connection` by
-        // `batch_execute()`, but before `Responses` is asynchronously polled to
-        // completion. In that case `Transaction` won't be created and thus
-        // won't be rolled back.
-        {
-            let mut cleaner = RollbackIfNotDone {
-                client: self,
-                done: false,
-            };
-            self.batch_execute("BEGIN").await?;
-            cleaner.done = true;
-        }
-
-        Ok(Transaction::new(self))
+        self.build_transaction().start().await
     }
 
     /// Returns a builder for a transaction with custom settings.
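To make the round-trip saving concrete, a hedged usage sketch of the new typed-query path against the async client; the table and column names are invented for illustration:

```rust
use tokio_postgres::types::{ToSql, Type};

async fn lookup(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
    // No prepared statement is created; the parameter type is supplied inline.
    let rows = client
        .query_typed(
            "SELECT name FROM users WHERE id = $1",
            &[(&5i32 as &(dyn ToSql + Sync), Type::INT4)],
        )
        .await?;
    for row in rows {
        let name: &str = row.get(0);
        println!("{}", name);
    }
    Ok(())
}
```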
diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs
index e40ed3e07..e94eac459 100644
--- a/tokio-postgres/src/config.rs
+++ b/tokio-postgres/src/config.rs
@@ -3,6 +3,7 @@
 #[cfg(feature = "runtime")]
 use crate::connect::connect;
 use crate::connect_raw::connect_raw;
+#[cfg(not(target_arch = "wasm32"))]
 use crate::keepalive::KeepaliveConfig;
 #[cfg(feature = "runtime")]
 use crate::tls::MakeTlsConnect;
@@ -19,6 +20,7 @@ use std::ops::Deref;
 use std::os::unix::ffi::OsStrExt;
 #[cfg(unix)]
 use std::path::Path;
+#[cfg(unix)]
 use std::path::PathBuf;
 use std::str;
 use std::str::FromStr;
@@ -34,6 +36,8 @@ pub enum TargetSessionAttrs {
     Any,
     /// The session must allow writes.
     ReadWrite,
+    /// The session must allow only reads.
+    ReadOnly,
 }
 
 /// TLS configuration.
@@ -65,16 +69,6 @@ pub enum ChannelBinding {
     Require,
 }
 
-/// Replication mode configuration.
-#[derive(Debug, Copy, Clone, PartialEq)]
-#[non_exhaustive]
-pub enum ReplicationMode {
-    /// Physical replication.
-    Physical,
-    /// Logical replication.
-    Logical,
-}
-
 /// Load balancing configuration.
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 #[non_exhaustive]
@@ -85,6 +79,21 @@ pub enum LoadBalanceHosts {
     Random,
 }
 
+/// Replication mode configuration.
+///
+/// It is recommended that you use a PostgreSQL server patch version
+/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or
+/// 9.5.25. Earlier patch levels have a bug that doesn't properly
+/// handle pipelined requests after streaming has stopped.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[non_exhaustive]
+pub enum ReplicationMode {
+    /// Physical replication.
+    Physical,
+    /// Logical replication.
+    Logical,
+}
+
 /// A host specification.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Host {
@@ -108,7 +117,7 @@ pub enum Host {
 ///
 /// ## Keys
 ///
-/// * `user` - The username to authenticate with. Required.
+/// * `user` - The username to authenticate with. Defaults to the user executing this process.
 /// * `password` - The password to authenticate with.
 /// * `dbname` - The name of the database to connect to. Defaults to the username.
 /// * `options` - Command line options used to configure the server.
@@ -210,7 +219,7 @@ pub enum Host {
 /// ```not_rust
 /// postgresql:///mydb?user=user&host=/var/lib/postgresql
 /// ```
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Eq)]
 pub struct Config {
     pub(crate) user: Option<String>,
     pub(crate) password: Option<Vec<u8>>,
@@ -227,11 +236,12 @@ pub struct Config {
     pub(crate) connect_timeout: Option<Duration>,
     pub(crate) tcp_user_timeout: Option<Duration>,
     pub(crate) keepalives: bool,
+    #[cfg(not(target_arch = "wasm32"))]
     pub(crate) keepalive_config: KeepaliveConfig,
     pub(crate) target_session_attrs: TargetSessionAttrs,
     pub(crate) channel_binding: ChannelBinding,
-    pub(crate) replication_mode: Option<ReplicationMode>,
     pub(crate) load_balance_hosts: LoadBalanceHosts,
+    pub(crate) replication_mode: Option<ReplicationMode>,
 }
 
 impl Default for Config {
@@ -243,11 +253,6 @@ impl Config {
     /// Creates a new configuration.
pub fn new() -> Config { - let keepalive_config = KeepaliveConfig { - idle: Duration::from_secs(2 * 60 * 60), - interval: None, - retries: None, - }; Config { user: None, password: None, @@ -264,19 +269,24 @@ impl Config { connect_timeout: None, tcp_user_timeout: None, keepalives: true, - keepalive_config, + #[cfg(not(target_arch = "wasm32"))] + keepalive_config: KeepaliveConfig { + idle: Duration::from_secs(2 * 60 * 60), + interval: None, + retries: None, + }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, - replication_mode: None, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, } } /// Sets the user to authenticate with. /// - /// Required. - pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); + /// Defaults to the user executing this process. + pub fn user(&mut self, user: impl Into) -> &mut Config { + self.user = Some(user.into()); self } @@ -304,8 +314,8 @@ impl Config { /// Sets the name of the database to connect to. /// /// Defaults to the user. - pub fn dbname(&mut self, dbname: &str) -> &mut Config { - self.dbname = Some(dbname.to_string()); + pub fn dbname(&mut self, dbname: impl Into) -> &mut Config { + self.dbname = Some(dbname.into()); self } @@ -316,8 +326,8 @@ impl Config { } /// Sets command line options used to configure the server. - pub fn options(&mut self, options: &str) -> &mut Config { - self.options = Some(options.to_string()); + pub fn options(&mut self, options: impl Into) -> &mut Config { + self.options = Some(options.into()); self } @@ -328,8 +338,8 @@ impl Config { } /// Sets the value of the `application_name` runtime parameter. - pub fn application_name(&mut self, application_name: &str) -> &mut Config { - self.application_name = Some(application_name.to_string()); + pub fn application_name(&mut self, application_name: impl Into) -> &mut Config { + self.application_name = Some(application_name.into()); self } @@ -396,7 +406,9 @@ impl Config { /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. /// There must be either no hosts, or the same number of hosts as hostaddrs. - pub fn host(&mut self, host: &str) -> &mut Config { + pub fn host(&mut self, host: impl Into) -> &mut Config { + let host = host.into(); + #[cfg(unix)] { if host.starts_with('/') { @@ -404,7 +416,7 @@ impl Config { } } - self.host.push(Host::Tcp(host.to_string())); + self.host.push(Host::Tcp(host)); self } @@ -413,23 +425,11 @@ impl Config { &self.host } - /// Gets a mutable view of the hosts that have been added to the - /// configuration with `host`. - pub fn get_hosts_mut(&mut self) -> &mut [Host] { - &mut self.host - } - /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. pub fn get_hostaddrs(&self) -> &[IpAddr] { self.hostaddr.deref() } - /// Gets a mutable view of the hostaddrs that have been added to the - /// configuration with `hostaddr`. - pub fn get_hostaddrs_mut(&mut self) -> &mut [IpAddr] { - &mut self.hostaddr - } - /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -513,6 +513,7 @@ impl Config { /// Sets the amount of idle time before a keepalive packet is sent on the connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. 
+ #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { self.keepalive_config.idle = keepalives_idle; self @@ -520,6 +521,7 @@ impl Config { /// Gets the configured amount of idle time before a keepalive packet will /// be sent on the connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_idle(&self) -> Duration { self.keepalive_config.idle } @@ -528,12 +530,14 @@ impl Config { /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { self.keepalive_config.interval = Some(keepalives_interval); self } /// Gets the time interval between TCP keepalive probes. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_interval(&self) -> Option { self.keepalive_config.interval } @@ -541,12 +545,14 @@ impl Config { /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { self.keepalive_config.retries = Some(keepalives_retries); self } /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_retries(&self) -> Option { self.keepalive_config.retries } @@ -581,17 +587,6 @@ impl Config { self.channel_binding } - /// Set replication mode. - pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { - self.replication_mode = Some(replication_mode); - self - } - - /// Get replication mode. - pub fn get_replication_mode(&self) -> Option { - self.replication_mode - } - /// Sets the host load balancing behavior. /// /// Defaults to `disable`. @@ -605,6 +600,22 @@ impl Config { self.load_balance_hosts } + /// Set replication mode. + /// + /// It is recommended that you use a PostgreSQL server patch version + /// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or + /// 9.5.25. Earlier patch levels have a bug that doesn't properly + /// handle pipelined requests after streaming has stopped. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. 
+ pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -706,12 +717,14 @@ impl Config { self.tcp_user_timeout(Duration::from_secs(timeout as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives" => { let keepalives = value .parse::() .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?; self.keepalives(keepalives != 0); } + #[cfg(not(target_arch = "wasm32"))] "keepalives_idle" => { let keepalives_idle = value .parse::() @@ -720,6 +733,7 @@ impl Config { self.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_interval" => { let keepalives_interval = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_interval"))) @@ -728,6 +742,7 @@ impl Config { self.keepalives_interval(Duration::from_secs(keepalives_interval as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_retries" => { let keepalives_retries = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_retries"))) @@ -738,6 +753,7 @@ impl Config { let target_session_attrs = match value { "any" => TargetSessionAttrs::Any, "read-write" => TargetSessionAttrs::ReadWrite, + "read-only" => TargetSessionAttrs::ReadOnly, _ => { return Err(Error::config_parse(Box::new(InvalidValue( "target_session_attrs", @@ -759,17 +775,6 @@ impl Config { }; self.channel_binding(channel_binding); } - "replication" => { - let mode = match value { - "off" => None, - "true" => Some(ReplicationMode::Physical), - "database" => Some(ReplicationMode::Logical), - _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), - }; - if let Some(mode) = mode { - self.replication_mode(mode); - } - } "load_balance_hosts" => { let load_balance_hosts = match value { "disable" => LoadBalanceHosts::Disable, @@ -782,6 +787,17 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -840,7 +856,8 @@ impl fmt::Debug for Config { } } - f.debug_struct("Config") + let mut config_dbg = &mut f.debug_struct("Config"); + config_dbg = config_dbg .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) .field("dbname", &self.dbname) @@ -855,10 +872,17 @@ impl fmt::Debug for Config { .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) - .field("keepalives", &self.keepalives) - .field("keepalives_idle", &self.keepalive_config.idle) - .field("keepalives_interval", &self.keepalive_config.interval) - .field("keepalives_retries", &self.keepalive_config.retries) + .field("keepalives", &self.keepalives); + + #[cfg(not(target_arch = "wasm32"))] + { + config_dbg = config_dbg + .field("keepalives_idle", &self.keepalive_config.idle) + .field("keepalives_interval", &self.keepalive_config.interval) + .field("keepalives_retries", &self.keepalive_config.retries); + } + + config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) .field("replication", 
&self.replication_mode) @@ -1110,7 +1134,7 @@ impl<'a> UrlParser<'a> { let mut it = creds.splitn(2, ':'); let user = self.decode(it.next().unwrap())?; - self.config.user(&user); + self.config.user(user); if let Some(password) = it.next() { let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); @@ -1173,7 +1197,7 @@ impl<'a> UrlParser<'a> { }; if !dbname.is_empty() { - self.config.dbname(&self.decode(dbname)?); + self.config.dbname(self.decode(dbname)?); } Ok(()) diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..8189cb91c 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -160,7 +160,7 @@ where let has_hostname = hostname.is_some(); let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; - if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { + if config.target_session_attrs != TargetSessionAttrs::Any { let rows = client.simple_query_raw("SHOW transaction_read_only"); pin_mut!(rows); @@ -185,11 +185,21 @@ where match next.await.transpose()? { Some(SimpleQueryMessage::Row(row)) => { - if row.try_get(0)? == Some("on") { + let read_only_result = row.try_get(0)?; + if read_only_result == Some("on") + && config.target_session_attrs == TargetSessionAttrs::ReadWrite + { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, "database does not allow writes", ))); + } else if read_only_result == Some("off") + && config.target_session_attrs == TargetSessionAttrs::ReadOnly + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database is not read only", + ))); } else { break; } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 1348828ba..8edf45937 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -13,6 +13,7 @@ use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ScramSha256; use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol::message::frontend; +use std::borrow::Cow; use std::collections::{HashMap, VecDeque}; use std::io; use std::pin::Pin; @@ -96,8 +97,13 @@ where delayed: VecDeque::new(), }; - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; + let user = config + .user + .as_deref() + .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed); + + startup(&mut stream, config, &user).await?; + authenticate(&mut stream, config, &user).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); @@ -107,15 +113,17 @@ where Ok((client, connection)) } -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn startup( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } + params.push(("user", user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -141,7 +149,11 @@ where .map_err(Error::io) } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn authenticate( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + 
    Unpin,
@@ -164,10 +176,6 @@ where
         Some(Message::AuthenticationMd5Password(body)) => {
             can_skip_channel_binding(config)?;
 
-            let user = config
-                .user
-                .as_ref()
-                .ok_or_else(|| Error::config("user missing".into()))?;
             let pass = config
                 .password
                 .as_ref()
diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs
index 67add04ea..f27131178 100644
--- a/tokio-postgres/src/connect_socket.rs
+++ b/tokio-postgres/src/connect_socket.rs
@@ -14,7 +14,9 @@ pub(crate) async fn connect_socket(
     addr: &Addr,
     port: u16,
     connect_timeout: Option<Duration>,
-    tcp_user_timeout: Option<Duration>,
+    #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option<
+        Duration,
+    >,
     keepalive_config: Option<&KeepaliveConfig>,
 ) -> Result<Socket, Error> {
     match addr {
diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs
index 79a7be34a..d3b46eab7 100644
--- a/tokio-postgres/src/copy_both.rs
+++ b/tokio-postgres/src/copy_both.rs
@@ -1,10 +1,9 @@
 use crate::client::{InnerClient, Responses};
 use crate::codec::FrontendMessage;
-use crate::connection::RequestMessages;
 use crate::{simple_query, Error};
 use bytes::{Buf, BufMut, Bytes, BytesMut};
 use futures_channel::mpsc;
-use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt};
+use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
 use log::debug;
 use pin_project_lite::pin_project;
 use postgres_protocol::message::backend::Message;
@@ -14,150 +13,253 @@ use std::marker::{PhantomData, PhantomPinned};
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
-pub(crate) enum CopyBothMessage {
-    Message(FrontendMessage),
-    Done,
+/// The state machine of CopyBothReceiver
+///
+/// ```ignore
+///        CopyBoth
+///        /      \
+///       v        v
+///   CopyOut    CopyIn
+///        \      /
+///         v    v
+///       CopyNone
+///           |
+///           v
+///     CopyComplete
+///           |
+///           v
+///   CommandComplete
+/// ```
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum CopyBothState {
+    /// The state before having entered the CopyBoth mode.
+    Setup,
+    /// Initial state where CopyData messages can go in both directions
+    CopyBoth,
+    /// The server->client stream is closed and we're in CopyIn mode
+    CopyIn,
+    /// The client->server stream is closed and we're in CopyOut mode
+    CopyOut,
+    /// Both directions are closed, and we're waiting for CommandComplete messages
+    CopyNone,
+    /// We have received the first CommandComplete message for the copy
+    CopyComplete,
+    /// We have received the final CommandComplete message for the statement
+    CommandComplete,
 }
 
+/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no
+/// matter what the users do with their CopyBothDuplex handle we're always going to send the
+/// correct messages to the backend in order to restore the connection into a usable state.
+///
+/// ```ignore
+///                                          |
+///                                          |
+///                                          |
+/// pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex
+///                                          |          ^      \
+///                                          |         /        v
+///                                          |       Sink     Stream
+/// ```
 pub struct CopyBothReceiver {
-    receiver: mpsc::Receiver<CopyBothMessage>,
-    done: bool,
+    /// Receiver of backend messages from the underlying [Connection](crate::Connection)
+    responses: Responses,
+    /// Receiver of frontend messages sent by the user using <CopyBothDuplex as Sink>::start_send
+    sink_receiver: mpsc::Receiver<FrontendMessage>,
+    /// Sender of CopyData contents to be consumed by the user using <CopyBothDuplex as Stream>::poll_next
+    stream_sender: mpsc::Sender<Result<Message, Error>>,
+    /// The current state of the subprotocol
+    state: CopyBothState,
+    /// Holds a buffered message until we are ready to send it to the user's stream
+    buffered_message: Option<Result<Message, Error>>,
 }
 
 impl CopyBothReceiver {
-    pub(crate) fn new(receiver: mpsc::Receiver<CopyBothMessage>) -> CopyBothReceiver {
+    pub(crate) fn new(
+        responses: Responses,
+        sink_receiver: mpsc::Receiver<FrontendMessage>,
+        stream_sender: mpsc::Sender<Result<Message, Error>>,
+    ) -> CopyBothReceiver {
         CopyBothReceiver {
-            receiver,
-            done: false,
+            responses,
+            sink_receiver,
+            stream_sender,
+            state: CopyBothState::Setup,
+            buffered_message: None,
         }
     }
-}
 
-impl Stream for CopyBothReceiver {
-    type Item = FrontendMessage;
+    /// Convenience method to set the subprotocol into an unexpected message state
+    fn unexpected_message(&mut self) {
+        self.sink_receiver.close();
+        self.buffered_message = Some(Err(Error::unexpected_message()));
+        self.state = CopyBothState::CommandComplete;
+    }
 
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        if self.done {
-            return Poll::Ready(None);
-        }
+    /// Processes messages from the backend; it will resolve once all backend messages have been
+    /// processed
+    fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+        use CopyBothState::*;
 
-        match ready!(self.receiver.poll_next_unpin(cx)) {
-            Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)),
-            Some(CopyBothMessage::Done) => {
-                self.done = true;
-                let mut buf = BytesMut::new();
-                frontend::copy_done(&mut buf);
-                frontend::sync(&mut buf);
-                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+        loop {
+            // Deliver the buffered message (if any) to the user to ensure we can potentially
+            // buffer a new one in response to a server message
+            if let Some(message) = self.buffered_message.take() {
+                match self.stream_sender.poll_ready(cx) {
+                    Poll::Ready(_) => {
+                        // If the receiver has hung up we'll just drop the message
+                        let _ = self.stream_sender.start_send(message);
+                    }
+                    Poll::Pending => {
+                        // Stash the message and try again later
+                        self.buffered_message = Some(message);
+                        return Poll::Pending;
+                    }
+                }
             }
-            None => {
-                self.done = true;
-                let mut buf = BytesMut::new();
-                frontend::copy_fail("", &mut buf).unwrap();
-                frontend::sync(&mut buf);
-                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+
+            match ready!(self.responses.poll_next_unpin(cx)) {
+                Some(Ok(Message::CopyBothResponse(body))) => match self.state {
+                    Setup => {
+                        self.buffered_message = Some(Ok(Message::CopyBothResponse(body)));
+                        self.state = CopyBoth;
+                    }
+                    _ => self.unexpected_message(),
+                },
+                Some(Ok(Message::CopyData(body))) => match self.state {
+                    CopyBoth | CopyOut => {
+                        self.buffered_message = Some(Ok(Message::CopyData(body)));
+                    }
+                    _ => self.unexpected_message(),
+                },
+                // The server->client stream is done
+                Some(Ok(Message::CopyDone)) => {
+                    match self.state {
+                        CopyBoth => self.state = CopyIn,
+                        CopyOut => self.state = CopyNone,
+                        _ => self.unexpected_message(),
+                    };
+                }
+                Some(Ok(Message::CommandComplete(_))) => {
+                    match self.state {
+                        CopyNone =>
self.state = CopyComplete, + CopyComplete => { + self.stream_sender.close_channel(); + self.sink_receiver.close(); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + // The server indicated an error, terminate our side if we haven't already + Some(Err(err)) => { + match self.state { + Setup | CopyBoth | CopyOut | CopyIn => { + self.sink_receiver.close(); + self.buffered_message = Some(Err(err)); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::ReadyForQuery(_))) => match self.state { + CommandComplete => { + self.sink_receiver.close(); + self.stream_sender.close_channel(); + } + _ => self.unexpected_message(), + }, + Some(Ok(_)) => self.unexpected_message(), + None => return Poll::Ready(()), } } } } -enum SinkState { - Active, - Closing, - Reading, +/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This +/// is the mechanism that drives the CopyBoth subprotocol forward +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use CopyBothState::*; + + match self.poll_backend(cx) { + Poll::Ready(()) => Poll::Ready(None), + Poll::Pending => match self.state { + Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) { + Some(msg) => Poll::Ready(Some(msg)), + None => { + self.state = match self.state { + CopyBoth => CopyOut, + CopyIn => CopyNone, + _ => unreachable!(), + }; + + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + }, + _ => Poll::Pending, + }, + } + } } pin_project! { - /// A sink for `COPY ... FROM STDIN` query data. + /// A duplex stream for consuming streaming replication data. + /// + /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new + /// query. This will ensure that the connection returns into normal processing mode. + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; /// - /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is - /// not, the copy will be aborted. + /// // ⚠️ INCORRECT ⚠️ + /// client.query("SELECT 1", &[]).await; // hangs forever + /// + /// // duplex_stream drop-ed here + /// } + /// ``` + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ✅ CORRECT ✅ + /// drop(duplex_stream); + /// + /// client.query("SELECT 1", &[]).await; + /// } + /// ``` pub struct CopyBothDuplex { #[pin] - sender: mpsc::Sender, - responses: Responses, + sink_sender: mpsc::Sender, + #[pin] + stream_receiver: mpsc::Receiver>, buf: BytesMut, - state: SinkState, #[pin] _p: PhantomPinned, _p2: PhantomData, } } -impl CopyBothDuplex -where - T: Buf + 'static + Send, -{ - pub(crate) fn new(sender: mpsc::Sender, responses: Responses) -> Self { - Self { - sender, - responses, - buf: BytesMut::new(), - state: SinkState::Active, - _p: PhantomPinned, - _p2: PhantomData, - } - } - - /// A poll-based version of `finish`. 
- pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match self.state { - SinkState::Active => { - ready!(self.as_mut().poll_flush(cx))?; - let mut this = self.as_mut().project(); - ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; - this.sender - .start_send(CopyBothMessage::Done) - .map_err(|_| Error::closed())?; - *this.state = SinkState::Closing; - } - SinkState::Closing => { - let this = self.as_mut().project(); - ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; - *this.state = SinkState::Reading; - } - SinkState::Reading => { - let this = self.as_mut().project(); - match ready!(this.responses.poll_next(cx))? { - Message::CommandComplete(body) => { - let rows = body - .tag() - .map_err(Error::parse)? - .rsplit(' ') - .next() - .unwrap() - .parse() - .unwrap_or(0); - return Poll::Ready(Ok(rows)); - } - _ => return Poll::Ready(Err(Error::unexpected_message())), - } - } - } - } - } - - /// Completes the copy, returning the number of rows inserted. - /// - /// The `Sink::close` method is equivalent to `finish`, except that it does not return the - /// number of rows. - pub async fn finish(mut self: Pin<&mut Self>) -> Result { - future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await - } -} - impl Stream for CopyBothDuplex { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match ready!(this.responses.poll_next(cx)?) { - Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), - Message::CopyDone => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), - } + Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) { + Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())), + Some(Ok(_)) => Some(Err(Error::unexpected_message())), + Some(Err(err)) => Some(Err(err)), + None => None, + }) } } @@ -169,7 +271,7 @@ where fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() - .sender + .sink_sender .poll_ready(cx) .map_err(|_| Error::closed()) } @@ -193,8 +295,8 @@ where }; let data = CopyData::new(data).map_err(Error::encode)?; - this.sender - .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + this.sink_sender + .start_send(FrontendMessage::CopyData(data)) .map_err(|_| Error::closed()) } @@ -202,20 +304,23 @@ where let mut this = self.project(); if !this.buf.is_empty() { - ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; let data: Box = Box::new(this.buf.split().freeze()); let data = CopyData::new(data).map_err(Error::encode)?; - this.sender + this.sink_sender .as_mut() - .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .start_send(FrontendMessage::CopyData(data)) .map_err(|_| Error::closed())?; } - this.sender.poll_flush(cx).map_err(|_| Error::closed()) + this.sink_sender.poll_flush(cx).map_err(|_| Error::closed()) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_finish(cx).map_ok(|_| ()) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + this.sink_sender.disconnect(); + Poll::Ready(Ok(())) } } @@ -230,19 +335,24 @@ where let buf = simple_query::encode(client, query)?; - let (mut sender, receiver) = mpsc::channel(1); - let receiver = 
CopyBothReceiver::new(receiver); - let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; + let mut handles = client.start_copy_both()?; - sender - .send(CopyBothMessage::Message(FrontendMessage::Raw(buf))) + handles + .sink_sender + .send(FrontendMessage::Raw(buf)) .await .map_err(|_| Error::closed())?; - match responses.next().await? { - Message::CopyBothResponse(_) => {} + match handles.stream_receiver.next().await.transpose()? { + Some(Message::CopyBothResponse(_)) => {} _ => return Err(Error::unexpected_message()), } - Ok(CopyBothDuplex::new(sender, responses)) + Ok(CopyBothDuplex { + stream_receiver: handles.stream_receiver, + sink_sender: handles.sink_sender, + buf: BytesMut::new(), + _p: PhantomPinned, + _p2: PhantomData, + }) } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f1e2644c6..e35d4d4e4 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -86,7 +86,8 @@ pub struct DbError { } impl DbError { - pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result { + /// Parses the error fields obtained from Postgres into a `DBError`. + pub fn parse(fields: &mut ErrorFields<'_>) -> io::Result { let mut severity = None; let mut parsed_severity = None; let mut code = None; @@ -107,14 +108,15 @@ impl DbError { let mut routine = None; while let Some(field) = fields.next()? { + let value = String::from_utf8_lossy(field.value_bytes()); match field.type_() { - b'S' => severity = Some(field.value().to_owned()), - b'C' => code = Some(SqlState::from_code(field.value())), - b'M' => message = Some(field.value().to_owned()), - b'D' => detail = Some(field.value().to_owned()), - b'H' => hint = Some(field.value().to_owned()), + b'S' => severity = Some(value.into_owned()), + b'C' => code = Some(SqlState::from_code(&value)), + b'M' => message = Some(value.into_owned()), + b'D' => detail = Some(value.into_owned()), + b'H' => hint = Some(value.into_owned()), b'P' => { - normal_position = Some(field.value().parse::().map_err(|_| { + normal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`P` field did not contain an integer", @@ -122,32 +124,32 @@ impl DbError { })?); } b'p' => { - internal_position = Some(field.value().parse::().map_err(|_| { + internal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`p` field did not contain an integer", ) })?); } - b'q' => internal_query = Some(field.value().to_owned()), - b'W' => where_ = Some(field.value().to_owned()), - b's' => schema = Some(field.value().to_owned()), - b't' => table = Some(field.value().to_owned()), - b'c' => column = Some(field.value().to_owned()), - b'd' => datatype = Some(field.value().to_owned()), - b'n' => constraint = Some(field.value().to_owned()), - b'F' => file = Some(field.value().to_owned()), + b'q' => internal_query = Some(value.into_owned()), + b'W' => where_ = Some(value.into_owned()), + b's' => schema = Some(value.into_owned()), + b't' => table = Some(value.into_owned()), + b'c' => column = Some(value.into_owned()), + b'd' => datatype = Some(value.into_owned()), + b'n' => constraint = Some(value.into_owned()), + b'F' => file = Some(value.into_owned()), b'L' => { - line = Some(field.value().parse::().map_err(|_| { + line = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`L` field did not contain an integer", ) })?); } - b'R' => routine = Some(field.value().to_owned()), + b'R' => routine = Some(value.into_owned()), b'V' => { - 
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { + parsed_severity = Some(Severity::from_str(&value).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "`V` field contained an invalid value", @@ -445,7 +447,8 @@ impl Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { + /// Constructs an `UnexpectedMessage` error. + pub fn unexpected_message() -> Error { Error::new(Kind::UnexpectedMessage, None) } @@ -457,7 +460,8 @@ impl Error { } } - pub(crate) fn parse(e: io::Error) -> Error { + /// Constructs a `Parse` error wrapping the provided one. + pub fn parse(e: io::Error) -> Error { Error::new(Kind::Parse, Some(Box::new(e))) } diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..6e7dffeb1 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -1,6 +1,6 @@ use crate::query::RowStream; use crate::types::{BorrowToSql, ToSql, Type}; -use crate::{Client, Error, Row, Statement, ToStatement, Transaction}; +use crate::{Client, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction}; use async_trait::async_trait; mod private { @@ -12,12 +12,12 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. #[async_trait] pub trait GenericClient: private::Sealed { - /// Like `Client::execute`. + /// Like [`Client::execute`]. async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::execute_raw`. + /// Like [`Client::execute_raw`]. async fn execute_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement + Sync + Send, @@ -25,12 +25,12 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; - /// Like `Client::query`. + /// Like [`Client::query`]. async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_one`. + /// Like [`Client::query_one`]. async fn query_one( &self, statement: &T, @@ -39,7 +39,7 @@ pub trait GenericClient: private::Sealed { where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_opt`. + /// Like [`Client::query_opt`]. async fn query_opt( &self, statement: &T, @@ -48,7 +48,7 @@ pub trait GenericClient: private::Sealed { where T: ?Sized + ToStatement + Sync + Send; - /// Like `Client::query_raw`. + /// Like [`Client::query_raw`]. async fn query_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement + Sync + Send, @@ -56,23 +56,39 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; - /// Like `Client::prepare`. + /// Like [`Client::query_typed`] + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error>; + + /// Like [`Client::query_typed_raw`] + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send; + + /// Like [`Client::prepare`]. async fn prepare(&self, query: &str) -> Result; - /// Like `Client::prepare_typed`. + /// Like [`Client::prepare_typed`]. async fn prepare_typed( &self, query: &str, parameter_types: &[Type], ) -> Result; - /// Like `Client::transaction`. + /// Like [`Client::transaction`]. async fn transaction(&mut self) -> Result, Error>; - /// Like `Client::batch_execute`. 
+ /// Like [`Client::batch_execute`]. async fn batch_execute(&self, query: &str) -> Result<(), Error>; - /// Returns a reference to the underlying `Client`. + /// Like [`Client::simple_query`]. + async fn simple_query(&self, query: &str) -> Result, Error>; + + /// Returns a reference to the underlying [`Client`]. fn client(&self) -> &Client; } @@ -136,6 +152,22 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params).await + } + + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -156,6 +188,10 @@ impl GenericClient for Client { self.batch_execute(query).await } + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + fn client(&self) -> &Client { self } @@ -222,6 +258,22 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params).await + } + + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -243,6 +295,10 @@ impl GenericClient for Transaction<'_> { self.batch_execute(query).await } + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + fn client(&self) -> &Client { self.client() } diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 74f453985..7bdd76341 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -12,12 +12,23 @@ impl From<&KeepaliveConfig> for TcpKeepalive { fn from(keepalive_config: &KeepaliveConfig) -> Self { let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); - #[cfg(not(any(target_os = "redox", target_os = "solaris")))] + #[cfg(not(any( + target_os = "aix", + target_os = "redox", + target_os = "solaris", + target_os = "openbsd" + )))] if let Some(interval) = keepalive_config.interval { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows")))] + #[cfg(not(any( + target_os = "aix", + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "openbsd" + )))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 275978cb2..cde9df841 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -107,7 +107,7 @@ //! | `array-impls` | Enables `ToSql` and `FromSql` trait impls for arrays | - | no | //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | -//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no | +//! 
| `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | @@ -116,7 +116,6 @@ //! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | //! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | -#![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] pub use crate::cancel_token::CancelToken; @@ -132,7 +131,7 @@ pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; pub use crate::row::{Row, SimpleQueryRow}; -pub use crate::simple_query::SimpleQueryStream; +pub use crate::simple_query::{SimpleColumn, SimpleQueryStream}; #[cfg(feature = "runtime")] pub use crate::socket::Socket; pub use crate::statement::{Column, Statement}; @@ -143,6 +142,7 @@ pub use crate::to_statement::ToStatement; pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; use crate::types::ToSql; +use std::sync::Arc; pub mod binary_copy; mod bind; @@ -165,12 +165,12 @@ mod copy_in; mod copy_out; pub mod error; mod generic_client; +#[cfg(not(target_arch = "wasm32"))] mod keepalive; mod maybe_tls_stream; mod portal; mod prepare; mod query; -pub mod replication; pub mod row; mod simple_query; #[cfg(feature = "runtime")] @@ -251,6 +251,8 @@ pub enum SimpleQueryMessage { /// /// The number of rows modified or selected is returned. CommandComplete(u64), + /// Column values of the proceeding row values + RowDescription(Arc<[SimpleColumn]>), } fn slice_iter<'a>( diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..1d9bacb16 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -95,7 +95,12 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? 
{ let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; columns.push(column); } } @@ -126,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub(crate) async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..3ab002871 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,21 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Column, Error, Portal, Row, Statement}; use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; +use postgres_types::Type; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -57,6 +61,68 @@ where }) } +pub async fn query_typed<'a, P, I>( + client: &Arc, + query: &str, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, +{ + let buf = { + let params = params.into_iter().collect::>(); + let param_oids = params.iter().map(|(_, t)| t.oid()).collect::>(); + + client.with_buf(|buf| { + frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?; + encode_bind_raw("", params, "", buf)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + + Ok(buf.split().freeze()) + })? + }; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {} + Message::NoData => { + return Ok(RowStream { + statement: Statement::unnamed(vec![], vec![]), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + Message::RowDescription(row_description) => { + let mut columns: Vec = vec![]; + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = get_type(client, field.type_oid()).await?; + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; + columns.push(column); + } + return Ok(RowStream { + statement: Statement::unnamed(vec![], columns), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + _ => return Err(Error::unexpected_message()), + } + } +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, @@ -164,27 +230,42 @@ where I: IntoIterator, I::IntoIter: ExactSizeIterator, { - let param_types = statement.params(); let params = params.into_iter(); - - if param_types.len() != params.len() { - return Err(Error::parameters(params.len(), param_types.len())); + if params.len() != statement.params().len() { + return Err(Error::parameters(params.len(), statement.params().len())); } + encode_bind_raw( + statement.name(), + params.zip(statement.params().iter().cloned()), + portal, + buf, + ) +} + +fn encode_bind_raw( + statement_name: &str, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ let (param_formats, params): (Vec<_>, Vec<_>) = params - .zip(param_types.iter()) - .map(|(p, ty)| (p.borrow_to_sql().encode_format(ty) as i16, p)) + .into_iter() + .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty))) .unzip(); - let params = params.into_iter(); - let mut error_idx = 0; let r = frontend::bind( portal, - statement.name(), + statement_name, param_formats, - params.zip(param_types).enumerate(), - |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) { + params.into_iter().enumerate(), + |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Err(e) => { diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..767c26921 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -95,6 +95,7 @@ where } /// A row of data returned from the database by a query. +#[derive(Clone)] pub struct Row { statement: Statement, body: DataRowBody, @@ -141,6 +142,7 @@ impl Row { /// # Panics /// /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] pub fn get<'a, I, T>(&'a self, idx: I) -> T where I: RowIndex + fmt::Display, @@ -239,6 +241,7 @@ impl SimpleQueryRow { /// # Panics /// /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] pub fn get(&self, idx: I) -> Option<&str> where I: RowIndex + fmt::Display, diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index a97ee126c..a26e43e6e 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -85,35 +85,34 @@ impl Stream for SimpleQueryStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - loop { - match ready!(this.responses.poll_next(cx)?) 
{ - Message::CommandComplete(body) => { - let rows = extract_row_affected(&body)?; - return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))); - } - Message::EmptyQueryResponse => { - return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); - } - Message::RowDescription(body) => { - let columns = body - .fields() - .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) - .collect::>() - .map_err(Error::parse)? - .into(); + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = extract_row_affected(&body)?; + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))) + } + Message::EmptyQueryResponse => { + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))) + } + Message::RowDescription(body) => { + let columns: Arc<[SimpleColumn]> = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); - *this.columns = Some(columns); - } - Message::DataRow(body) => { - let row = match &this.columns { - Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, - None => return Poll::Ready(Some(Err(Error::unexpected_message()))), - }; - return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); - } - Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + *this.columns = Some(columns.clone()); + Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns)))) + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))) } + Message::ReadyForQuery(_) => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..4f7ddaec6 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,10 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; -use std::{ - fmt, - sync::{Arc, Weak}, -}; +use std::sync::{Arc, Weak}; struct StatementInner { client: Weak, @@ -17,6 +14,10 @@ struct StatementInner { impl Drop for StatementInner { fn drop(&mut self) { + if self.name.is_empty() { + // Unnamed statements don't need to be closed + return; + } if let Some(client) = self.client.upgrade() { let buf = client.with_buf(|buf| { frontend::close(b'S', &self.name, buf).unwrap(); @@ -49,6 +50,15 @@ impl Statement { })) } + pub(crate) fn unnamed(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + pub(crate) fn name(&self) -> &str { &self.0.name } @@ -64,33 +74,43 @@ impl Statement { } } +impl std::fmt::Debug for Statement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("Statement") + .field("name", &self.0.name) + .field("params", &self.0.params) + .field("columns", &self.0.columns) + .finish_non_exhaustive() + } +} + /// Information about a column of a query. 
+#[derive(Debug)] pub struct Column { - name: String, - type_: Type, + pub(crate) name: String, + pub(crate) table_oid: Option, + pub(crate) column_id: Option, + pub(crate) r#type: Type, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } - } - /// Returns the name of the column. pub fn name(&self) -> &str { &self.name } - /// Returns the type of the column. - pub fn type_(&self) -> &Type { - &self.type_ + /// Returns the OID of the underlying database table. + pub fn table_oid(&self) -> Option { + self.table_oid + } + + /// Return the column ID within the underlying database table. + pub fn column_id(&self) -> Option { + self.column_id } -} -impl fmt::Debug for Column { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Column") - .field("name", &self.name) - .field("type", &self.type_) - .finish() + /// Returns the type of the column. + pub fn type_(&self) -> &Type { + &self.r#type } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..17a50b60f 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,24 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_typed`. + pub async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.client.query_typed(statement, params).await + } + + /// Like `Client::query_typed_raw`. + pub async fn query_typed_raw(&self, query: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator, + { + self.client.query_typed_raw(query, params).await + } + /// Like `Client::execute`. pub async fn execute( &self, diff --git a/tokio-postgres/src/transaction_builder.rs b/tokio-postgres/src/transaction_builder.rs index 9718ac588..93e9e9801 100644 --- a/tokio-postgres/src/transaction_builder.rs +++ b/tokio-postgres/src/transaction_builder.rs @@ -1,4 +1,6 @@ -use crate::{Client, Error, Transaction}; +use postgres_protocol::message::frontend; + +use crate::{codec::FrontendMessage, connection::RequestMessages, Client, Error, Transaction}; /// The isolation level of a database transaction. #[derive(Debug, Copy, Clone)] @@ -106,7 +108,41 @@ impl<'a> TransactionBuilder<'a> { query.push_str(s); } - self.client.batch_execute(&query).await?; + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl<'a> Drop for RollbackIfNotDone<'a> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously send to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. 
+ { + let mut cleaner = RollbackIfNotDone { + client: self.client, + done: false, + }; + self.client.batch_execute(&query).await?; + cleaner.done = true; + } Ok(Transaction::new(self.client)) } diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs new file mode 100644 index 000000000..2723928ac --- /dev/null +++ b/tokio-postgres/tests/test/copy_both.rs @@ -0,0 +1,125 @@ +use futures_util::{future, StreamExt, TryStreamExt}; +use tokio_postgres::{error::SqlState, Client, SimpleQueryMessage, SimpleQueryRow}; + +async fn q(client: &Client, query: &str) -> Vec { + let msgs = client.simple_query(query).await.unwrap(); + + msgs.into_iter() + .filter_map(|msg| match msg { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) + .collect() +} + +#[tokio::test] +async fn copy_both_error() { + let client = crate::connect("user=postgres replication=database").await; + + let err = client + .copy_both_simple::("START_REPLICATION SLOT undefined LOGICAL 0000/0000") + .await + .err() + .unwrap(); + + assert_eq!(err.code(), Some(&SqlState::UNDEFINED_OBJECT)); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both_stream_error() { + let client = crate::connect("user=postgres replication=true").await; + + q(&client, "CREATE_REPLICATION_SLOT err2 PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err2 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + let mut msgs: Vec<_> = duplex_stream.collect().await; + let result = msgs.pop().unwrap(); + assert_eq!(msgs.len(), 0); + assert!(result.unwrap_err().as_db_error().is_some()); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "DROP_REPLICATION_SLOT err2").await.len(), 0); +} + +#[tokio::test] +async fn copy_both_stream_error_sync() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "CREATE_REPLICATION_SLOT err1 TEMPORARY PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err1 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + // Immediately close our sink to send a CopyDone before receiving the ErrorResponse + drop(duplex_stream); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "DROP TABLE IF EXISTS replication").await; + q(&client, "CREATE TABLE replication (i text)").await; + + let slot_query = "CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\""; + let lsn = q(&client, slot_query).await[0] + .get("consistent_point") + .unwrap() + .to_owned(); + + // We will attempt to read this from the other end + q(&client, "BEGIN").await; + let xid = q(&client, "SELECT txid_current()").await[0] + .get("txid_current") + .unwrap() + .to_owned(); + q(&client, "INSERT INTO replication VALUES ('processed')").await; + q(&client, "COMMIT").await; + + // Insert a second row to generate unprocessed messages in the stream + q(&client, "INSERT INTO replication VALUES ('ignored')").await; + + let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn); + let duplex_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let expected = vec![ + 
+        format!("BEGIN {}", xid),
+        "table public.replication: INSERT: i[text]:'processed'".to_string(),
+        format!("COMMIT {}", xid),
+    ];
+
+    let actual: Vec<_> = duplex_stream
+        // Process only XLogData messages
+        .try_filter(|buf| future::ready(buf[0] == b'w'))
+        // Play back the stream until the first expected message
+        .try_skip_while(|buf| future::ready(Ok(!buf.ends_with(expected[0].as_ref()))))
+        // Take only the expected number of messages; the rest will be discarded by tokio_postgres
+        .take(expected.len())
+        .try_collect()
+        .await
+        .unwrap();
+
+    for (msg, ending) in actual.into_iter().zip(expected.into_iter()) {
+        assert!(msg.ends_with(ending.as_ref()));
+    }
+
+    // Ensure we can continue issuing queries
+    assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1"));
+}
diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs
index cab185ae6..3debf4eba 100644
--- a/tokio-postgres/tests/test/main.rs
+++ b/tokio-postgres/tests/test/main.rs
@@ -20,10 +20,9 @@ use tokio_postgres::{
 };
 
 mod binary_copy;
+mod copy_both;
 mod parse;
 #[cfg(feature = "runtime")]
-mod replication;
-#[cfg(feature = "runtime")]
 mod runtime;
 mod types;
 
@@ -305,6 +304,7 @@ async fn custom_range() {
 }
 
 #[tokio::test]
+#[allow(clippy::get_first)]
 async fn simple_query() {
     let client = connect("user=postgres").await;
 
@@ -329,6 +329,13 @@ async fn simple_query() {
         _ => panic!("unexpected message"),
     }
     match &messages[2] {
+        SimpleQueryMessage::RowDescription(columns) => {
+            assert_eq!(columns.get(0).map(|c| c.name()), Some("id"));
+            assert_eq!(columns.get(1).map(|c| c.name()), Some("name"));
+        }
+        _ => panic!("unexpected message"),
+    }
+    match &messages[3] {
         SimpleQueryMessage::Row(row) => {
             assert_eq!(row.columns().first().map(|c| c.name()), Some("id"));
             assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name"));
@@ -337,7 +344,7 @@
         }
         _ => panic!("unexpected message"),
     }
-    match &messages[3] {
+    match &messages[4] {
         SimpleQueryMessage::Row(row) => {
             assert_eq!(row.columns().first().map(|c| c.name()), Some("id"));
             assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name"));
@@ -346,11 +353,11 @@
         }
         _ => panic!("unexpected message"),
     }
-    match messages[4] {
+    match messages[5] {
         SimpleQueryMessage::CommandComplete(2) => {}
         _ => panic!("unexpected message"),
    }
-    assert_eq!(messages.len(), 5);
+    assert_eq!(messages.len(), 6);
 }
 
 #[tokio::test]
@@ -953,3 +960,123 @@ async fn deferred_constraint() {
         .await
         .unwrap_err();
 }
+
+#[tokio::test]
+async fn query_typed_no_transaction() {
+    let client = connect("user=postgres").await;
+
+    client
+        .batch_execute(
+            "
+            CREATE TEMPORARY TABLE foo (
+                name TEXT,
+                age INT
+            );
+            INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40);
+            ",
+        )
+        .await
+        .unwrap();
+
+    let rows: Vec<Row> = client
+        .query_typed(
+            "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
+            &[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
+        )
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 2);
+    let first_row = &rows[0];
+    assert_eq!(first_row.get::<_, &str>(0), "bob");
+    assert_eq!(first_row.get::<_, i32>(1), 30);
+    assert_eq!(first_row.get::<_, &str>(2), "literal");
+    assert_eq!(first_row.get::<_, i32>(3), 5);
+
+    let second_row = &rows[1];
+    assert_eq!(second_row.get::<_, &str>(0), "carol");
+    assert_eq!(second_row.get::<_, i32>(1), 40);
+    assert_eq!(second_row.get::<_, &str>(2), "literal");
+    assert_eq!(second_row.get::<_, i32>(3), 5);
+
+    // Test for UPDATE that returns no data
+    let updated_rows = client
+        .query_typed("UPDATE foo set age = 33", &[])
+        .await
+        .unwrap();
+    assert_eq!(updated_rows.len(), 0);
+}
+
+#[tokio::test]
+async fn query_typed_with_transaction() {
+    let mut client = connect("user=postgres").await;
+
+    client
+        .batch_execute(
+            "
+            CREATE TEMPORARY TABLE foo (
+                name TEXT,
+                age INT
+            );
+            ",
+        )
+        .await
+        .unwrap();
+
+    let transaction = client.transaction().await.unwrap();
+
+    let rows: Vec<Row> = transaction
+        .query_typed(
+            "INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age",
+            &[
+                (&"alice", Type::TEXT),
+                (&20i32, Type::INT4),
+                (&"bob", Type::TEXT),
+                (&30i32, Type::INT4),
+                (&"carol", Type::TEXT),
+                (&40i32, Type::INT4),
+            ],
+        )
+        .await
+        .unwrap();
+    let inserted_values: Vec<(String, i32)> = rows
+        .iter()
+        .map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1)))
+        .collect();
+    assert_eq!(
+        inserted_values,
+        [
+            ("alice".to_string(), 20),
+            ("bob".to_string(), 30),
+            ("carol".to_string(), 40)
+        ]
+    );
+
+    let rows: Vec<Row> = transaction
+        .query_typed(
+            "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
+            &[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
+        )
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 2);
+    let first_row = &rows[0];
+    assert_eq!(first_row.get::<_, &str>(0), "bob");
+    assert_eq!(first_row.get::<_, i32>(1), 30);
+    assert_eq!(first_row.get::<_, &str>(2), "literal");
+    assert_eq!(first_row.get::<_, i32>(3), 5);
+
+    let second_row = &rows[1];
+    assert_eq!(second_row.get::<_, &str>(0), "carol");
+    assert_eq!(second_row.get::<_, i32>(1), 40);
+    assert_eq!(second_row.get::<_, &str>(2), "literal");
+    assert_eq!(second_row.get::<_, i32>(3), 5);
+
+    // Test for UPDATE that returns no data
+    let updated_rows = transaction
+        .query_typed("UPDATE foo set age = 33", &[])
+        .await
+        .unwrap();
+    assert_eq!(updated_rows.len(), 0);
+}
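The `query_typed` family exercised in the tests above pairs each parameter with an explicit `Type`, so the client does not need a separate prepare/describe step to learn parameter types and the server never has to infer them. A minimal sketch; the connection string is an assumption, and the call shape matches the `Transaction::query_typed` signature shown earlier:

    use tokio_postgres::{types::Type, NoTls};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let (client, connection) = tokio_postgres::connect("user=postgres", NoTls).await?;
        tokio::spawn(connection);

        // $1 and $2 are sent as TEXT because the types are declared inline.
        let rows = client
            .query_typed(
                "SELECT $1 || ' ' || $2",
                &[(&"type", Type::TEXT), (&"safety", Type::TEXT)],
            )
            .await?;
        assert_eq!(rows[0].get::<_, &str>(0), "type safety");
        Ok(())
    }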
diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs
index 2c11899ca..04d422e27 100644
--- a/tokio-postgres/tests/test/parse.rs
+++ b/tokio-postgres/tests/test/parse.rs
@@ -34,6 +34,14 @@ fn settings() {
             .keepalives_idle(Duration::from_secs(30))
             .target_session_attrs(TargetSessionAttrs::ReadWrite),
     );
+    check(
+        "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-only",
+        Config::new()
+            .connect_timeout(Duration::from_secs(3))
+            .keepalives(false)
+            .keepalives_idle(Duration::from_secs(30))
+            .target_session_attrs(TargetSessionAttrs::ReadOnly),
+    );
 }
 
 #[test]
diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs
index c325917aa..b010055ba 100644
--- a/tokio-postgres/tests/test/types/chrono_04.rs
+++ b/tokio-postgres/tests/test/types/chrono_04.rs
@@ -1,4 +1,4 @@
-use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
+use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
 use std::fmt;
 use tokio_postgres::types::{Date, FromSqlOwned, Timestamp};
 use tokio_postgres::Client;
@@ -53,11 +53,9 @@ async fn test_with_special_naive_date_time_params() {
 async fn test_date_time_params() {
     fn make_check(time: &str) -> (Option<DateTime<Utc>>, &str) {
         (
-            Some(
-                NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
-                    .unwrap()
-                    .and_utc(),
-            ),
+            Some(Utc.from_utc_datetime(
+                &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
+            )),
             time,
         )
     }
@@ -77,11 +75,9 @@
 async fn test_with_special_date_time_params() {
     fn make_check(time: &str) -> (Timestamp<DateTime<Utc>>, &str) {
         (
-            Timestamp::Value(
-                NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
-                    .unwrap()
-                    .and_utc(),
-            ),
+            Timestamp::Value(Utc.from_utc_datetime(
+                &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
+            )),
             time,
         )
     }
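The chrono_04.rs edit above swaps `NaiveDateTime::and_utc()` for the equivalent `TimeZone::from_utc_datetime` call, presumably because `and_utc` only exists in newer chrono 0.4.x releases and the older form keeps existing lockfiles building. A standalone sketch of the conversion (written against the plain `chrono` crate name; the test suite renames it to `chrono_04`):

    use chrono::{DateTime, NaiveDateTime, TimeZone, Utc};

    fn main() {
        let naive =
            NaiveDateTime::parse_from_str("2020-01-02 03:04:05.678", "%Y-%m-%d %H:%M:%S%.f")
                .unwrap();
        // Equivalent to naive.and_utc() on recent chrono releases.
        let utc: DateTime<Utc> = Utc.from_utc_datetime(&naive);
        assert_eq!(utc.to_rfc3339(), "2020-01-02T03:04:05.678+00:00");
    }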
diff --git a/tokio-postgres/tests/test/types/eui48_04.rs b/tokio-postgres/tests/test/types/eui48_04.rs
deleted file mode 100644
index 074faa37e..000000000
--- a/tokio-postgres/tests/test/types/eui48_04.rs
+++ /dev/null
@@ -1,18 +0,0 @@
-use eui48_04::MacAddress;
-
-use crate::types::test_type;
-
-#[tokio::test]
-async fn test_eui48_params() {
-    test_type(
-        "MACADDR",
-        &[
-            (
-                Some(MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()),
-                "'12-34-56-ab-cd-ef'",
-            ),
-            (None, "NULL"),
-        ],
-    )
-    .await
-}
diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs
index 452d149fe..62d54372a 100644
--- a/tokio-postgres/tests/test/types/mod.rs
+++ b/tokio-postgres/tests/test/types/mod.rs
@@ -17,8 +17,6 @@ use bytes::BytesMut;
 mod bit_vec_06;
 #[cfg(feature = "with-chrono-0_4")]
 mod chrono_04;
-#[cfg(feature = "with-eui48-0_4")]
-mod eui48_04;
 #[cfg(feature = "with-eui48-1")]
 mod eui48_1;
 #[cfg(feature = "with-geo-types-0_6")]
@@ -739,3 +737,25 @@ async fn ltxtquery_any() {
     )
     .await;
 }
+
+#[tokio::test]
+async fn oidvector() {
+    test_type(
+        "oidvector",
+        // NB: postgres does not support empty oidvectors! All empty arrays are normalized to zero
+        // dimensions, but the oidvectorrecv function requires exactly one dimension.
+        &[(Some(vec![0u32, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")],
+    )
+    .await;
+}
+
+#[tokio::test]
+async fn int2vector() {
+    test_type(
+        "int2vector",
+        // NB: postgres does not support empty int2vectors! All empty arrays are normalized to zero
+        // dimensions, but the int2vectorrecv function requires exactly one dimension.
+        &[(Some(vec![0i16, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")],
+    )
+    .await;
+}
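Per the new tests, `oidvector` and `int2vector` round-trip as `Vec<u32>` and `Vec<i16>`. A hedged end-to-end sketch; the connection string is an assumption and the cast-based binding is just one way to reach these types:

    use tokio_postgres::NoTls;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let (client, connection) = tokio_postgres::connect("user=postgres", NoTls).await?;
        tokio::spawn(connection);

        // Note the caveat from the tests above: empty vectors do not survive
        // the round trip, so only non-empty values are used here.
        let row = client
            .query_one(
                "SELECT $1::oidvector, $2::int2vector",
                &[&vec![1u32, 2, 3], &vec![4i16, 5]],
            )
            .await?;
        assert_eq!(row.get::<_, Vec<u32>>(0), vec![1, 2, 3]);
        assert_eq!(row.get::<_, Vec<i16>>(1), vec![4, 5]);
        Ok(())
    }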