From 809a20709565ca725008e326863f44da5564c506 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 26 Jun 2020 18:01:07 +0200 Subject: [PATCH] Support for rust_decimal::Decimal --- Cargo.lock | 12 + Cargo.toml | 3 +- sqlx-core/Cargo.toml | 5 +- sqlx-core/src/mysql/types/decimal.rs | 29 ++ sqlx-core/src/mysql/types/mod.rs | 10 + sqlx-core/src/postgres/types/decimal.rs | 435 ++++++++++++++++++++++++ sqlx-core/src/postgres/types/mod.rs | 19 +- 7 files changed, 510 insertions(+), 3 deletions(-) create mode 100644 sqlx-core/src/mysql/types/decimal.rs create mode 100644 sqlx-core/src/postgres/types/decimal.rs diff --git a/Cargo.lock b/Cargo.lock index e029f415b7..99364b3952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2006,6 +2006,16 @@ dependencies = [ "crossbeam-utils 0.6.6", ] +[[package]] +name = "rust_decimal" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b5f52edf35045e96b07aa29822bf4ce8495295fd5610270f85ab1f26df7ba5" +dependencies = [ + "num-traits", + "serde", +] + [[package]] name = "rustc-demangle" version = "0.1.16" @@ -2361,6 +2371,7 @@ dependencies = [ "md-5", "memchr", "num-bigint", + "num-traits", "once_cell", "parking_lot 0.11.0", "percent-encoding 2.1.0", @@ -2368,6 +2379,7 @@ dependencies = [ "rand", "regex", "rsa", + "rust_decimal", "serde", "serde_json", "sha-1", diff --git a/Cargo.toml b/Cargo.toml index b623f0e07b..8351c8ad90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ offline = [ "sqlx-macros/offline", "sqlx-core/offline" ] # intended mainly for CI and docs all = [ "tls", "all-databases", "all-types" ] all-databases = [ "mysql", "sqlite", "postgres", "mssql", "any" ] -all-types = [ "bigdecimal", "json", "time", "chrono", "ipnetwork", "uuid" ] +all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid" ] # runtime runtime-async-std = [ "sqlx-core/runtime-async-std", "sqlx-macros/runtime-async-std" ] @@ -65,6 +65,7 @@ mssql = [ "sqlx-core/mssql", "sqlx-macros/mssql" ] # types bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] +decimal = ["sqlx-core/decimal"] chrono = [ "sqlx-core/chrono", "sqlx-macros/chrono" ] ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-macros/ipnetwork" ] uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index bfc3726518..745af0ba5e 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -24,8 +24,9 @@ mssql = [ "uuid", "encoding_rs", "regex" ] any = [] # types -all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ] +all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid" ] bigdecimal = [ "bigdecimal_", "num-bigint" ] +decimal = [ "rust_decimal", "num-bigint", "num-traits" ] json = [ "serde", "serde_json" ] # runtimes @@ -41,6 +42,8 @@ atoi = "0.3.2" sqlx-rt = { path = "../sqlx-rt", version = "0.1.0-pre" } base64 = { version = "0.12.1", default-features = false, optional = true, features = [ "std" ] } bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" } +rust_decimal = { version = "1.6.0", optional = true } +num-traits = { version = "0.2.12", optional = true } bitflags = { version = "1.2.1", default-features = false } bytes = "0.5.4" byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] } diff --git a/sqlx-core/src/mysql/types/decimal.rs b/sqlx-core/src/mysql/types/decimal.rs new file mode 100644 index 0000000000..0826b6bbd9 --- /dev/null +++ b/sqlx-core/src/mysql/types/decimal.rs @@ -0,0 +1,29 @@ +use rust_decimal::Decimal; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::io::MySqlBufMutExt; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; +use crate::types::Type; + +impl Type for Decimal { + fn type_info() -> MySqlTypeInfo { + MySqlTypeInfo::binary(ColumnType::NewDecimal) + } +} + +impl Encode<'_, MySql> for Decimal { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_str_lenenc(&self.to_string()); + + IsNull::No + } +} + +impl Decode<'_, MySql> for Decimal { + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(value.as_str()?.parse()?) + } +} diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 940ff49428..9aa85b2fc1 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -48,6 +48,13 @@ //! |---------------------------------------|------------------------------------------------------| //! | `bigdecimal::BigDecimal` | DECIMAL | //! +//! ### [`decimal`](https://crates.io/crates/rust_decimal) +//! Requires the `decimal` Cargo feature flag. +//! +//! | Rust type | MySQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `rust_decimal::Decimal` | DECIMAL | +//! //! ### [`json`](https://crates.io/crates/json) //! //! Requires the `json` Cargo feature flag. @@ -72,6 +79,9 @@ mod uint; #[cfg(feature = "bigdecimal")] mod bigdecimal; +#[cfg(feature = "decimal")] +mod decimal; + #[cfg(feature = "chrono")] mod chrono; diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs new file mode 100644 index 0000000000..02cd748156 --- /dev/null +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -0,0 +1,435 @@ +use std::convert::{TryFrom, TryInto}; + +use num_bigint::{BigInt, Sign}; +use num_traits::ToPrimitive; +use rust_decimal::{prelude::Zero, Decimal}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::types::numeric::{PgNumeric, PgNumericSign}; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use crate::types::Type; + +impl Type for Decimal { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC + } +} + +impl Type for [Decimal] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC_ARRAY + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[Decimal] as Type>::type_info() + } +} + +impl TryFrom for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: PgNumeric) -> Result { + let (digits, sign, weight) = match numeric { + PgNumeric::Number { + digits, + sign, + weight, + .. + } => (digits, sign, weight), + + PgNumeric::NotANumber => { + return Err("Decimal does not support NaN values".into()); + } + }; + + if digits.is_empty() { + // Postgres returns an empty digit array for 0 but BigInt expects at least one zero + return Ok(0u64.into()); + } + + let sign = match sign { + PgNumericSign::Positive => Sign::Plus, + PgNumericSign::Negative => Sign::Minus, + }; + + // weight is 0 if the decimal point falls after the first base-10000 digit + let scale = (digits.len() as i64 - weight as i64 - 1) * 4; + + // no optimized algorithm for base-10 so use base-100 for faster processing + let mut cents = Vec::with_capacity(digits.len() * 2); + for digit in &digits { + cents.push((digit / 100) as u8); + cents.push((digit % 100) as u8); + } + + let bigint = BigInt::from_radix_be(sign, ¢s, 100) + .ok_or("PgNumeric contained an out-of-range digit")?; + + match bigint.to_i128() { + Some(num) => Ok(Decimal::from_i128_with_scale(num, scale as u32)), + None => Err("Decimal's integer part out of range.".into()), + } + } +} + +impl TryFrom<&'_ Decimal> for PgNumeric { + type Error = BoxDynError; + + fn try_from(decimal: &Decimal) -> Result { + if decimal.is_zero() { + return Ok(PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![], + }); + } + + let scale = decimal.scale() as u16; + let groups_diff = scale & 0x3; + let s = decimal.serialize(); + + let mut mantissa = [ + u32::from_le_bytes(s[4..8].try_into().unwrap()), + u32::from_le_bytes(s[8..12].try_into().unwrap()), + u32::from_le_bytes(s[12..16].try_into().unwrap()), + 0u32, + ]; + + if groups_diff > 0 { + let remainder: u16 = 4 - groups_diff; + let power = 10u32.pow(remainder as u32); + mul_by_u32(&mut mantissa, power); + } + + // array to store max mantissa of Decimal in Postgres decimal format + let mut digits = Vec::with_capacity(8); + + while !mantissa.iter().all(|b| *b == 0) { + let remainder = div_by_u32(&mut mantissa, 10000) as u16; + digits.push(remainder as i16) + } + + digits.reverse(); + + let digits_after_decimal = (scale + 3) as u16 / 4; + let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + + while let Some(&0) = digits.last() { + digits.pop(); + } + + Ok(PgNumeric::Number { + sign: match decimal.is_sign_negative() { + false => PgNumericSign::Positive, + true => PgNumericSign::Negative, + }, + scale: scale as i16, + weight, + digits, + }) + } +} + +/// ### Panics +/// If this `Decimal` cannot be represented by [PgNumeric]. +impl Encode<'_, Postgres> for Decimal { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + PgNumeric::try_from(self) + .expect("Decimal magnitude too great for Postgres NUMERIC type") + .encode(buf); + + IsNull::No + } +} + +impl Decode<'_, Postgres> for Decimal { + fn decode(value: PgValueRef<'_>) -> Result { + match value.format() { + PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), + PgValueFormat::Text => Ok(value.as_str()?.parse::()?), + } + } +} + +// Returns remainder +fn div_by_u32(bits: &mut [u32], divisor: u32) -> u32 { + assert_ne!(0, divisor); + + if divisor == 1 { + // dividend remains unchanged + 0 + } else { + let mut remainder = 0u32; + let divisor = u64::from(divisor); + + for part in bits.iter_mut().rev() { + let temp = (u64::from(remainder) << 32) + u64::from(*part); + remainder = (temp % divisor) as u32; + *part = (temp / divisor) as u32; + } + + remainder + } +} + +fn mul_by_u32(bits: &mut [u32], m: u32) -> u32 { + let mut overflow = 0; + + for num in bits.iter_mut() { + let (lo, hi) = mul_part(*num, m, overflow); + + *num = lo; + overflow = hi; + } + + overflow +} + +fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) { + let result = u64::from(left) * u64::from(right) + u64::from(high); + let hi = (result >> 32) as u32; + let lo = result as u32; + + (lo, hi) +} + +#[cfg(test)] +mod decimal_to_pgnumeric { + use super::{Decimal, PgNumeric, PgNumericSign}; + use std::convert::TryFrom; + + #[test] + fn zero() { + let zero: Decimal = "0".parse().unwrap(); + + assert_eq!( + PgNumeric::try_from(&zero).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![] + } + ); + } + + #[test] + fn one() { + let one: Decimal = "1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1] + } + ); + } + + #[test] + fn ten() { + let ten: Decimal = "10".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![10] + } + ); + } + + #[test] + fn one_hundred() { + let one_hundred: Decimal = "100".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_hundred).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![100] + } + ); + } + + #[test] + fn ten_thousand() { + // Decimal doesn't normalize here + let ten_thousand: Decimal = "10000".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten_thousand).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1] + } + ); + } + + #[test] + fn two_digits() { + let two_digits: Decimal = "12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&two_digits).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn one_tenth() { + let one_tenth: Decimal = "0.1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_tenth).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 1, + weight: -1, + digits: vec![1000] + } + ); + } + + #[test] + fn decimal_1() { + let decimal: Decimal = "1.2345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 4, + weight: 0, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn decimal_2() { + let decimal: Decimal = "0.12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![1234, 5000] + } + ); + } + + #[test] + fn decimal_3() { + let decimal: Decimal = "0.01234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![0123, 4000] + } + ); + } + + #[test] + fn decimal_4() { + let decimal: Decimal = "12345.67890".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: 1, + digits: vec![1, 2345, 6789] + } + ); + } + + #[test] + fn one_digit_decimal() { + let one_digit_decimal: Decimal = "0.00001234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_digit_decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 8, + weight: -2, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let four_digit: Decimal = "1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_negative_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_four_digit: Decimal = "-1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let eight_digit: Decimal = "12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } + + #[test] + fn issue_423_negative_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_eight_digit: Decimal = "-12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index b65c29d49d..ea13ce0dfd 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -13,6 +13,20 @@ //! | `&str`, `String` | VARCHAR, CHAR(N), TEXT, NAME | //! | `&[u8]`, `Vec` | BYTEA | //! +//! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) +//! Requires the `bigdecimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bigdecimal::BigDecimal` | NUMERIC | +//! +//! ### [`decimal`](https://crates.io/crates/rust_decimal) +//! Requires the `decimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `rust_decimal::Decimal` | NUMERIC | +//! //! ### [`chrono`](https://crates.io/crates/chrono) //! //! Requires the `chrono` Cargo feature flag. @@ -145,9 +159,12 @@ mod tuple; #[cfg(feature = "bigdecimal")] mod bigdecimal; -#[cfg(feature = "bigdecimal")] +#[cfg(any(feature = "bigdecimal", feature = "decimal"))] mod numeric; +#[cfg(feature = "decimal")] +mod decimal; + #[cfg(feature = "chrono")] mod chrono;