diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 8cb6882e1d..137813f847 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -13,6 +13,7 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -178,6 +179,7 @@ impl<'a> TryFrom<&'a PgTypeInfo> for AnyTypeInfo { PgType::Float8 => AnyTypeInfoKind::Double, PgType::Bytea => AnyTypeInfoKind::Blob, PgType::Text => AnyTypeInfoKind::Text, + PgType::DeclareWithName(UStr::Static("citext")) => AnyTypeInfoKind::Text, _ => { return Err(sqlx_core::Error::AnyDriverError( format!("Any driver does not support the Postgres type {pg_type:?}").into(), diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index e4ca317ffb..5952291e6e 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -457,6 +457,7 @@ impl PgType { PgType::Int8RangeArray => Oid(3927), PgType::Jsonpath => Oid(4072), PgType::JsonpathArray => Oid(4073), + PgType::Custom(ty) => ty.oid, PgType::DeclareWithOid(oid) => *oid, @@ -874,6 +875,7 @@ impl PgType { PgType::Unknown => None, // There is no `VoidArray` PgType::Void => None, + PgType::Custom(ty) => match &ty.kind { PgTypeKind::Simple => None, PgTypeKind::Pseudo => None, diff --git a/sqlx-postgres/src/types/citext.rs b/sqlx-postgres/src/types/citext.rs new file mode 100644 index 0000000000..9fc131d7ca --- /dev/null +++ b/sqlx-postgres/src/types/citext.rs @@ -0,0 +1,106 @@ +use crate::types::array_compatible; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::Type; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Deref; +use std::str::FromStr; + +/// Case-insensitive text (`citext`) support for Postgres. +/// +/// Note that SQLx considers the `citext` type to be compatible with `String` +/// and its various derivatives, so direct usage of this type is generally unnecessary. +/// +/// However, it may be needed, for example, when binding a `citext[]` array, +/// as Postgres will generally not accept a `text[]` array (mapped from `Vec`) in its place. +/// +/// See [the Postgres manual, Appendix F, Section 10][PG.F.10] for details on using `citext`. +/// +/// [PG.F.10]: https://www.postgresql.org/docs/current/citext.html +/// +/// ### Note: Extension Required +/// The `citext` extension is not enabled by default in Postgres. You will need to do so explicitly: +/// +/// ```ignore +/// CREATE EXTENSION IF NOT EXISTS "citext"; +/// ``` +/// +/// ### Note: `PartialEq` is Case-Sensitive +/// This type derives `PartialEq` which forwards to the implementation on `String`, which +/// is case-sensitive. This impl exists mainly for testing. +/// +/// To properly emulate the case-insensitivity of `citext` would require use of locale-aware +/// functions in `libc`, and even then would require querying the locale of the database server +/// and setting it locally, which is unsafe. +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgCiText(pub String); + +impl Type for PgCiText { + fn type_info() -> PgTypeInfo { + // Since `citext` is enabled by an extension, it does not have a stable OID. + PgTypeInfo::with_name("citext") + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Deref for PgCiText { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_str() + } +} + +impl From for PgCiText { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From for String { + fn from(value: PgCiText) -> Self { + value.0 + } +} + +impl FromStr for PgCiText { + type Err = core::convert::Infallible; + + fn from_str(s: &str) -> Result { + Ok(PgCiText(s.parse()?)) + } +} + +impl Display for PgCiText { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl PgHasArrayType for PgCiText { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_citext") + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + array_compatible::<&str>(ty) + } +} + +impl Encode<'_, Postgres> for PgCiText { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + <&str as Encode>::encode(&**self, buf) + } +} + +impl Decode<'_, Postgres> for PgCiText { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(PgCiText(value.as_str()?.to_owned())) + } +} diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index c0dd5c9141..5a80b8e587 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -11,7 +11,7 @@ //! | `i64` | BIGINT, BIGSERIAL, INT8 | //! | `f32` | REAL, FLOAT4 | //! | `f64` | DOUBLE PRECISION, FLOAT8 | -//! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME | +//! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME, CITEXT | //! | `&[u8]`, `Vec` | BYTEA | //! | `()` | VOID | //! | [`PgInterval`] | INTERVAL | @@ -19,6 +19,11 @@ //! | [`PgMoney`] | MONEY | //! | [`PgLTree`] | LTREE | //! | [`PgLQuery`] | LQUERY | +//! | [`PgCiText`] | CITEXT1 | +//! +//! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., +//! but this wrapper type is available for edge cases, such as `CITEXT[]` which Postgres +//! does not consider to be compatible with `TEXT[]`. //! //! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) //! Requires the `bigdecimal` Cargo feature flag. @@ -175,6 +180,7 @@ pub(crate) use sqlx_core::types::{Json, Type}; mod array; mod bool; mod bytes; +mod citext; mod float; mod int; mod interval; @@ -224,6 +230,7 @@ mod mac_address; mod bit_vec; pub use array::PgHasArrayType; +pub use citext::PgCiText; pub use interval::PgInterval; pub use lquery::PgLQuery; pub use lquery::PgLQueryLevel; diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index d66532dbff..c6938010e5 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -18,6 +18,7 @@ impl Type for str { PgTypeInfo::BPCHAR, PgTypeInfo::VARCHAR, PgTypeInfo::UNKNOWN, + PgTypeInfo::with_name("citext"), ] .contains(ty) } diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs index 81bf0571fc..cfaaf1349d 100644 --- a/tests/postgres/macros.rs +++ b/tests/postgres/macros.rs @@ -611,3 +611,28 @@ async fn test_bind_arg_override_wildcard() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_to_from_citext() -> anyhow::Result<()> { + // Ensure that the macros consider `CITEXT` to be compatible with `String` and friends + + let mut conn = new::().await?; + + let mut tx = conn.begin().await?; + + let foo_in = "Hello, world!"; + + sqlx::query!("insert into test_citext(foo) values ($1)", foo_in) + .execute(&mut *tx) + .await?; + + let foo_out: String = sqlx::query_scalar!("select foo from test_citext") + .fetch_one(&mut *tx) + .await?; + + assert_eq!(foo_in, foo_out); + + tx.rollback().await?; + + Ok(()) +} diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index 039e08d86f..5a415324d8 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -1,6 +1,9 @@ -- https://www.postgresql.org/docs/current/ltree.html CREATE EXTENSION IF NOT EXISTS ltree; +-- https://www.postgresql.org/docs/current/citext.html +CREATE EXTENSION IF NOT EXISTS citext; + -- https://www.postgresql.org/docs/current/sql-createtype.html CREATE TYPE status AS ENUM ('new', 'open', 'closed'); @@ -44,3 +47,7 @@ CREATE TABLE products ( CREATE OR REPLACE PROCEDURE forty_two(INOUT forty_two INT = NULL) LANGUAGE plpgsql AS 'begin forty_two := 42; end;'; + +CREATE TABLE test_citext ( + foo CITEXT NOT NULL +); diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index dbb4092910..445e9fafe9 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -2,7 +2,7 @@ extern crate time_ as time; use std::ops::Bound; -use sqlx::postgres::types::{Oid, PgInterval, PgMoney, PgRange}; +use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; use sqlx_test::{test_decode_type, test_prepared_type, test_type}; @@ -65,6 +65,7 @@ test_type!(str<&str>(Postgres, "'identifier'::name" == "identifier", "'five'::char(4)" == "five", "'more text'::varchar" == "more text", + "'case insensitive searching'::citext" == "case insensitive searching", )); test_type!(string(Postgres, @@ -79,7 +80,7 @@ test_type!(string_vec>(Postgres, == vec!["", "\""], "array['Hello, World', '', 'Goodbye']::text[]" - == vec!["Hello, World", "", "Goodbye"] + == vec!["Hello, World", "", "Goodbye"], )); test_type!(string_array<[String; 3]>(Postgres, @@ -550,6 +551,14 @@ test_prepared_type!(money_vec>(Postgres, "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], )); +test_prepared_type!(citext_array>(Postgres, + "array['one','two','three']::citext[]" == vec![ + PgCiText("one".to_string()), + PgCiText("two".to_string()), + PgCiText("three".to_string()), + ], +)); + // FIXME: needed to disable `ltree` tests in version that don't have a binary format for it // but `PgLTree` should just fall back to text format #[cfg(any(postgres_14, postgres_15))]