From 63a91d34ef4eab24f00eaa46868ad5d58382027d Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 26 Jun 2020 18:01:07 +0200 Subject: [PATCH 1/9] Support for rust_decimal::Decimal --- Cargo.lock | 12 + Cargo.toml | 3 +- sqlx-core/Cargo.toml | 5 +- sqlx-core/src/mysql/types/decimal.rs | 29 ++ sqlx-core/src/mysql/types/mod.rs | 10 + sqlx-core/src/postgres/types/decimal.rs | 454 ++++++++++++++++++++++++ sqlx-core/src/postgres/types/mod.rs | 19 +- 7 files changed, 529 insertions(+), 3 deletions(-) create mode 100644 sqlx-core/src/mysql/types/decimal.rs create mode 100644 sqlx-core/src/postgres/types/decimal.rs diff --git a/Cargo.lock b/Cargo.lock index 240ca80866..e01dce2745 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2077,6 +2077,16 @@ dependencies = [ "crossbeam-utils 0.6.6", ] +[[package]] +name = "rust_decimal" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b5f52edf35045e96b07aa29822bf4ce8495295fd5610270f85ab1f26df7ba5" +dependencies = [ + "num-traits", + "serde", +] + [[package]] name = "rustc-demangle" version = "0.1.16" @@ -2450,6 +2460,7 @@ dependencies = [ "md-5", "memchr", "num-bigint", + "num-traits", "once_cell", "parking_lot 0.11.0", "percent-encoding 2.1.0", @@ -2457,6 +2468,7 @@ dependencies = [ "rand", "regex", "rsa", + "rust_decimal", "serde", "serde_json", "sha-1", diff --git a/Cargo.toml b/Cargo.toml index 0db5d9c89b..4f89817c3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ offline = [ "sqlx-macros/offline", "sqlx-core/offline" ] # intended mainly for CI and docs all = [ "tls", "all-databases", "all-types" ] all-databases = [ "mysql", "sqlite", "postgres", "mssql", "any" ] -all-types = [ "bigdecimal", "json", "time", "chrono", "ipnetwork", "uuid" ] +all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid" ] # runtime runtime-async-std = [ "sqlx-core/runtime-async-std", "sqlx-macros/runtime-async-std" ] @@ -66,6 +66,7 @@ mssql = [ "sqlx-core/mssql", "sqlx-macros/mssql" ] # types bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] +decimal = ["sqlx-core/decimal"] chrono = [ "sqlx-core/chrono", "sqlx-macros/chrono" ] ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-macros/ipnetwork" ] uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 8fa0dbc55f..89b0c6b26c 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -24,8 +24,9 @@ mssql = [ "uuid", "encoding_rs", "regex" ] any = [] # types -all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ] +all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid" ] bigdecimal = [ "bigdecimal_", "num-bigint" ] +decimal = [ "rust_decimal", "num-bigint", "num-traits" ] json = [ "serde", "serde_json" ] # runtimes @@ -41,6 +42,8 @@ atoi = "0.3.2" sqlx-rt = { path = "../sqlx-rt", version = "0.1.0-pre" } base64 = { version = "0.12.1", default-features = false, optional = true, features = [ "std" ] } bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" } +rust_decimal = { version = "1.6.0", optional = true } +num-traits = { version = "0.2.12", optional = true } bitflags = { version = "1.2.1", default-features = false } bytes = "0.5.4" byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] } diff --git a/sqlx-core/src/mysql/types/decimal.rs b/sqlx-core/src/mysql/types/decimal.rs new file mode 100644 index 0000000000..0826b6bbd9 --- /dev/null +++ b/sqlx-core/src/mysql/types/decimal.rs @@ -0,0 +1,29 @@ +use rust_decimal::Decimal; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::io::MySqlBufMutExt; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; +use crate::types::Type; + +impl Type for Decimal { + fn type_info() -> MySqlTypeInfo { + MySqlTypeInfo::binary(ColumnType::NewDecimal) + } +} + +impl Encode<'_, MySql> for Decimal { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_str_lenenc(&self.to_string()); + + IsNull::No + } +} + +impl Decode<'_, MySql> for Decimal { + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(value.as_str()?.parse()?) + } +} diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 940ff49428..9aa85b2fc1 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -48,6 +48,13 @@ //! |---------------------------------------|------------------------------------------------------| //! | `bigdecimal::BigDecimal` | DECIMAL | //! +//! ### [`decimal`](https://crates.io/crates/rust_decimal) +//! Requires the `decimal` Cargo feature flag. +//! +//! | Rust type | MySQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `rust_decimal::Decimal` | DECIMAL | +//! //! ### [`json`](https://crates.io/crates/json) //! //! Requires the `json` Cargo feature flag. @@ -72,6 +79,9 @@ mod uint; #[cfg(feature = "bigdecimal")] mod bigdecimal; +#[cfg(feature = "decimal")] +mod decimal; + #[cfg(feature = "chrono")] mod chrono; diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs new file mode 100644 index 0000000000..12f6d23bca --- /dev/null +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -0,0 +1,454 @@ +use std::convert::{TryFrom, TryInto}; + +use num_bigint::{BigInt, Sign}; +use num_traits::ToPrimitive; +use rust_decimal::{prelude::Zero, Decimal}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::types::numeric::{PgNumeric, PgNumericSign}; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use crate::types::Type; + +impl Type for Decimal { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC + } +} + +impl Type for [Decimal] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUMERIC_ARRAY + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[Decimal] as Type>::type_info() + } +} + +impl TryFrom for Decimal { + type Error = BoxDynError; + + fn try_from(numeric: PgNumeric) -> Result { + let (digits, sign, weight) = match numeric { + PgNumeric::Number { + digits, + sign, + weight, + .. + } => (digits, sign, weight), + + PgNumeric::NotANumber => { + return Err("Decimal does not support NaN values".into()); + } + }; + + if digits.is_empty() { + // Postgres returns an empty digit array for 0 but BigInt expects at least one zero + return Ok(0u64.into()); + } + + let sign = match sign { + PgNumericSign::Positive => Sign::Plus, + PgNumericSign::Negative => Sign::Minus, + }; + + // weight is 0 if the decimal point falls after the first base-10000 digit + let scale = (digits.len() as i64 - weight as i64 - 1) * 4; + + // no optimized algorithm for base-10 so use base-100 for faster processing + let mut cents = Vec::with_capacity(digits.len() * 2); + for digit in &digits { + cents.push((digit / 100) as u8); + cents.push((digit % 100) as u8); + } + + let bigint = BigInt::from_radix_be(sign, ¢s, 100) + .ok_or("PgNumeric contained an out-of-range digit")?; + + match bigint.to_i128() { + Some(num) => Ok(Decimal::from_i128_with_scale(num, scale as u32)), + None => Err("Decimal's integer part out of range.".into()), + } + } +} + +impl TryFrom<&'_ Decimal> for PgNumeric { + type Error = BoxDynError; + + fn try_from(decimal: &Decimal) -> Result { + if decimal.is_zero() { + return Ok(PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![], + }); + } + + let scale = decimal.scale() as u16; + + // A serialized version of the decimal number. The resulting byte array + // will have the following representation: + // + // Bytes 1-4: flags + // Bytes 5-8: lo portion of m + // Bytes 9-12: mid portion of m + // Bytes 13-16: high portion of m + let s = decimal.serialize(); + + // As u96. + let mut mantissa = [ + // lo + u32::from_le_bytes(s[4..8].try_into().unwrap()), + // mid + u32::from_le_bytes(s[8..12].try_into().unwrap()), + // hi + u32::from_le_bytes(s[12..16].try_into().unwrap()), + // flags + 0u32, + ]; + + // If our scale is not a multiple of 4, we need to go to the next + // multiple. + let groups_diff = scale % 4; + if groups_diff > 0 { + let remainder: u16 = 4 - groups_diff; + let power = 10u32.pow(remainder as u32); + mul_by_u32(&mut mantissa, power); + } + + // Array to store max mantissa of Decimal in Postgres decimal format. + let mut digits = Vec::with_capacity(8); + + // Convert to base-10000. + while !mantissa.iter().all(|b| *b == 0) { + let remainder = div_by_u32(&mut mantissa, 10000) as u16; + digits.push(remainder as i16) + } + + // Change the endianness. + digits.reverse(); + + // Weight is number of digits on the left side of the decimal. + let digits_after_decimal = (scale + 3) as u16 / 4; + let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + + // Remove non-significant zeroes. + while let Some(&0) = digits.last() { + digits.pop(); + } + + Ok(PgNumeric::Number { + sign: match decimal.is_sign_negative() { + false => PgNumericSign::Positive, + true => PgNumericSign::Negative, + }, + scale: scale as i16, + weight, + digits, + }) + } +} + +/// ### Panics +/// If this `Decimal` cannot be represented by [PgNumeric]. +impl Encode<'_, Postgres> for Decimal { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + PgNumeric::try_from(self) + .expect("Decimal magnitude too great for Postgres NUMERIC type") + .encode(buf); + + IsNull::No + } +} + +impl Decode<'_, Postgres> for Decimal { + fn decode(value: PgValueRef<'_>) -> Result { + match value.format() { + PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), + PgValueFormat::Text => Ok(value.as_str()?.parse::()?), + } + } +} + +// Returns remainder +fn div_by_u32(bits: &mut [u32], divisor: u32) -> u32 { + assert_ne!(0, divisor); + + if divisor == 1 { + // dividend remains unchanged + 0 + } else { + let mut remainder = 0u32; + let divisor = u64::from(divisor); + + for part in bits.iter_mut().rev() { + let temp = (u64::from(remainder) << 32) + u64::from(*part); + remainder = (temp % divisor) as u32; + *part = (temp / divisor) as u32; + } + + remainder + } +} + +fn mul_by_u32(bits: &mut [u32], m: u32) -> u32 { + let mut overflow = 0; + + for num in bits.iter_mut() { + let (lo, hi) = mul_part(*num, m, overflow); + + *num = lo; + overflow = hi; + } + + overflow +} + +fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) { + let result = u64::from(left) * u64::from(right) + u64::from(high); + let hi = (result >> 32) as u32; + let lo = result as u32; + + (lo, hi) +} + +#[cfg(test)] +mod decimal_to_pgnumeric { + use super::{Decimal, PgNumeric, PgNumericSign}; + use std::convert::TryFrom; + + #[test] + fn zero() { + let zero: Decimal = "0".parse().unwrap(); + + assert_eq!( + PgNumeric::try_from(&zero).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![] + } + ); + } + + #[test] + fn one() { + let one: Decimal = "1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1] + } + ); + } + + #[test] + fn ten() { + let ten: Decimal = "10".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![10] + } + ); + } + + #[test] + fn one_hundred() { + let one_hundred: Decimal = "100".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_hundred).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![100] + } + ); + } + + #[test] + fn ten_thousand() { + // Decimal doesn't normalize here + let ten_thousand: Decimal = "10000".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&ten_thousand).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1] + } + ); + } + + #[test] + fn two_digits() { + let two_digits: Decimal = "12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&two_digits).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn one_tenth() { + let one_tenth: Decimal = "0.1".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_tenth).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 1, + weight: -1, + digits: vec![1000] + } + ); + } + + #[test] + fn decimal_1() { + let decimal: Decimal = "1.2345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 4, + weight: 0, + digits: vec![1, 2345] + } + ); + } + + #[test] + fn decimal_2() { + let decimal: Decimal = "0.12345".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![1234, 5000] + } + ); + } + + #[test] + fn decimal_3() { + let decimal: Decimal = "0.01234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: -1, + digits: vec![0123, 4000] + } + ); + } + + #[test] + fn decimal_4() { + let decimal: Decimal = "12345.67890".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 5, + weight: 1, + digits: vec![1, 2345, 6789] + } + ); + } + + #[test] + fn one_digit_decimal() { + let one_digit_decimal: Decimal = "0.00001234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&one_digit_decimal).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 8, + weight: -2, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let four_digit: Decimal = "1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_negative_four_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_four_digit: Decimal = "-1234".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_four_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 0, + digits: vec![1234] + } + ); + } + + #[test] + fn issue_423_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let eight_digit: Decimal = "12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } + + #[test] + fn issue_423_negative_eight_digit() { + // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 + let negative_eight_digit: Decimal = "-12345678".parse().unwrap(); + assert_eq!( + PgNumeric::try_from(&negative_eight_digit).unwrap(), + PgNumeric::Number { + sign: PgNumericSign::Negative, + scale: 0, + weight: 1, + digits: vec![1234, 5678] + } + ); + } +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index aaed2d5799..9ab446f845 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -20,6 +20,20 @@ //! [`PgRange`]: struct.PgRange.html //! [`PgMoney`]: struct.PgMoney.html //! +//! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) +//! Requires the `bigdecimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bigdecimal::BigDecimal` | NUMERIC | +//! +//! ### [`decimal`](https://crates.io/crates/rust_decimal) +//! Requires the `decimal` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `rust_decimal::Decimal` | NUMERIC | +//! //! ### [`chrono`](https://crates.io/crates/chrono) //! //! Requires the `chrono` Cargo feature flag. @@ -154,9 +168,12 @@ mod tuple; #[cfg(feature = "bigdecimal")] mod bigdecimal; -#[cfg(feature = "bigdecimal")] +#[cfg(any(feature = "bigdecimal", feature = "decimal"))] mod numeric; +#[cfg(feature = "decimal")] +mod decimal; + #[cfg(feature = "chrono")] mod chrono; From 3c1ca9c7e399439abe8cc727db92afd63799a99c Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 3 Jul 2020 14:33:33 +0200 Subject: [PATCH 2/9] Add money conversions to `Decimal` --- sqlx-core/src/postgres/types/money.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 50e026eadc..7d2f3fdaf1 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -35,6 +35,15 @@ impl PgMoney { bigdecimal::BigDecimal::new(digits, scale) } + + /// Convert the money value into a [`Decimal`] using the correct precision + /// defined in the PostgreSQL settings. The default precision is two. + /// + /// [`Decimal`]: ../../types/struct.BigDecimal.html + #[cfg(feature = "decimal")] + pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { + rust_decimal::Decimal::new(self.0, scale) + } } impl Type for PgMoney { @@ -214,4 +223,13 @@ mod tests { money.to_bigdecimal(2) ); } + + #[test] + #[cfg(feature = "decimal")] + fn conversion_to_decimal_works() { + assert_eq!( + rust_decimal::Decimal::new(12345, 2), + PgMoney(12345).to_decimal(2) + ); + } } From 3ba9b85da1edd03b61ab920ca975331c8d9a0205 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 3 Jul 2020 14:34:46 +0200 Subject: [PATCH 3/9] Test Decimal conversions in my and pg --- sqlx-core/src/types/mod.rs | 4 ++++ tests/mysql/types.rs | 17 ++++++++++++++++- tests/postgres/types.rs | 16 +++++++++++++++- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 1d658117ef..965a7eb4fb 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -39,6 +39,10 @@ pub mod time { #[cfg_attr(docsrs, doc(cfg(feature = "bigdecimal")))] pub use bigdecimal::BigDecimal; +#[cfg(feature = "decimal")] +#[cfg_attr(docsrs, doc(cfg(feature = "decimal")))] +pub use rust_decimal::Decimal; + #[cfg(feature = "ipnetwork")] #[cfg_attr(docsrs, doc(cfg(feature = "ipnetwork")))] pub mod ipnetwork { diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index a8795c0227..57a84d1ae5 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -1,5 +1,8 @@ extern crate time_ as time; +#[cfg(feature = "decimal")] +use std::str::FromStr; + use sqlx::mysql::MySql; use sqlx::{Executor, Row}; use sqlx_test::test_type; @@ -185,7 +188,7 @@ mod time_tests { } #[cfg(feature = "bigdecimal")] -test_type!(decimal( +test_type!(bigdecimal( MySql, "CAST(0 as DECIMAL(0, 0))" == "0".parse::().unwrap(), "CAST(1 AS DECIMAL(1, 0))" == "1".parse::().unwrap(), @@ -196,6 +199,18 @@ test_type!(decimal( "CAST(12345.6789 AS DECIMAL(9, 4))" == "12345.6789".parse::().unwrap(), )); +#[cfg(feature = "decimal")] +test_type!(decimal(MySql, + "CAST(0 as DECIMAL(0, 0))" == sqlx::types::Decimal::from_str("0").unwrap(), + "CAST(1 AS DECIMAL(1, 0))" == sqlx::types::Decimal::from_str("1").unwrap(), + // bug in rust_decimal: https://github.com/paupino/rust-decimal/issues/251 + //"CAST(10000 AS DECIMAL(5, 0))" == sqlx::types::Decimal::from_str("10000").unwrap(), + "CAST(0.1 AS DECIMAL(2, 1))" == sqlx::types::Decimal::from_str("0.1").unwrap(), + "CAST(0.01234 AS DECIMAL(6, 5))" == sqlx::types::Decimal::from_str("0.01234").unwrap(), + "CAST(12.34 AS DECIMAL(4, 2))" == sqlx::types::Decimal::from_str("12.34").unwrap(), + "CAST(12345.6789 AS DECIMAL(9, 4))" == sqlx::types::Decimal::from_str("12345.6789").unwrap(), +)); + #[cfg(feature = "json")] mod json_tests { use super::*; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c9bb49b949..f94ec971ea 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -1,6 +1,8 @@ extern crate time_ as time; use std::ops::Bound; +#[cfg(feature = "decimal")] +use std::str::FromStr; use sqlx::postgres::types::{PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; @@ -324,7 +326,7 @@ mod json { } #[cfg(feature = "bigdecimal")] -test_type!(decimal(Postgres, +test_type!(bigdecimal(Postgres, // https://github.com/launchbadge/sqlx/issues/283 "0::numeric" == "0".parse::().unwrap(), @@ -337,6 +339,18 @@ test_type!(decimal(Postgres, "12345.6789::numeric" == "12345.6789".parse::().unwrap(), )); +#[cfg(feature = "decimal")] +test_type!(decimal(Postgres, + "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), + "1::numeric" == sqlx::types::Decimal::from_str("1").unwrap(), + // bug in rust_decimal: https://github.com/paupino/rust-decimal/issues/251 + //"10000::numeric" == sqlx::types::Decimal::from_str("10000").unwrap(), + "0.1::numeric" == sqlx::types::Decimal::from_str("0.1").unwrap(), + "0.01234::numeric" == sqlx::types::Decimal::from_str("0.01234").unwrap(), + "12.34::numeric" == sqlx::types::Decimal::from_str("12.34").unwrap(), + "12345.6789::numeric" == sqlx::types::Decimal::from_str("12345.6789").unwrap(), +)); + const EXC2: Bound = Bound::Excluded(2); const EXC3: Bound = Bound::Excluded(3); const INC1: Bound = Bound::Included(1); From 2ba944ab78761d71be0e6a134ba76f95ecd7a204 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Sat, 4 Jul 2020 13:02:16 +0200 Subject: [PATCH 4/9] Fixed an overflow with a negative scale --- sqlx-core/src/postgres/types/decimal.rs | 13 ++++++++++--- tests/mysql/types.rs | 3 +-- tests/postgres/types.rs | 3 +-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs index 12f6d23bca..424536e428 100644 --- a/sqlx-core/src/postgres/types/decimal.rs +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -69,9 +69,16 @@ impl TryFrom for Decimal { let bigint = BigInt::from_radix_be(sign, ¢s, 100) .ok_or("PgNumeric contained an out-of-range digit")?; - match bigint.to_i128() { - Some(num) => Ok(Decimal::from_i128_with_scale(num, scale as u32)), - None => Err("Decimal's integer part out of range.".into()), + match (bigint.to_i128(), scale) { + // A negative scale, meaning we have nothing on the right and must + // add zeroes to the left. + (Some(num), scale) if scale < 0 => Ok(Decimal::from_i128_with_scale( + num * 10i128.pow(scale.abs() as u32), + 0, + )), + // A positive scale, so we have decimals on the right. + (Some(num), _) => Ok(Decimal::from_i128_with_scale(num, scale as u32)), + (None, _) => Err("Decimal's integer part out of range.".into()), } } } diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index 57a84d1ae5..3742149d39 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -203,8 +203,7 @@ test_type!(bigdecimal( test_type!(decimal(MySql, "CAST(0 as DECIMAL(0, 0))" == sqlx::types::Decimal::from_str("0").unwrap(), "CAST(1 AS DECIMAL(1, 0))" == sqlx::types::Decimal::from_str("1").unwrap(), - // bug in rust_decimal: https://github.com/paupino/rust-decimal/issues/251 - //"CAST(10000 AS DECIMAL(5, 0))" == sqlx::types::Decimal::from_str("10000").unwrap(), + "CAST(10000 AS DECIMAL(5, 0))" == sqlx::types::Decimal::from_str("10000").unwrap(), "CAST(0.1 AS DECIMAL(2, 1))" == sqlx::types::Decimal::from_str("0.1").unwrap(), "CAST(0.01234 AS DECIMAL(6, 5))" == sqlx::types::Decimal::from_str("0.01234").unwrap(), "CAST(12.34 AS DECIMAL(4, 2))" == sqlx::types::Decimal::from_str("12.34").unwrap(), diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index f94ec971ea..b4302ee1af 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -343,8 +343,7 @@ test_type!(bigdecimal(Postgres, test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), "1::numeric" == sqlx::types::Decimal::from_str("1").unwrap(), - // bug in rust_decimal: https://github.com/paupino/rust-decimal/issues/251 - //"10000::numeric" == sqlx::types::Decimal::from_str("10000").unwrap(), + "10000::numeric" == sqlx::types::Decimal::from_str("10000").unwrap(), "0.1::numeric" == sqlx::types::Decimal::from_str("0.1").unwrap(), "0.01234::numeric" == sqlx::types::Decimal::from_str("0.01234").unwrap(), "12.34::numeric" == sqlx::types::Decimal::from_str("12.34").unwrap(), From 660c23fec97b839fb76d7592fd154f1f32be7433 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Sat, 4 Jul 2020 13:31:45 +0200 Subject: [PATCH 5/9] Set the write part to use simpler math with u128 --- sqlx-core/src/postgres/types/decimal.rs | 77 ++++++------------------- 1 file changed, 18 insertions(+), 59 deletions(-) diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs index 424536e428..4d17908d0c 100644 --- a/sqlx-core/src/postgres/types/decimal.rs +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -107,34 +107,35 @@ impl TryFrom<&'_ Decimal> for PgNumeric { // Bytes 13-16: high portion of m let s = decimal.serialize(); - // As u96. - let mut mantissa = [ - // lo - u32::from_le_bytes(s[4..8].try_into().unwrap()), - // mid - u32::from_le_bytes(s[8..12].try_into().unwrap()), - // hi - u32::from_le_bytes(s[12..16].try_into().unwrap()), - // flags - 0u32, - ]; + // Moving the flags from the beginning of the array to the end, giving + // us a representation of u96 we can convert to u128. + // + // We also just set the flags as zero, so we don't need to chop them off + // from the resulting integer. + let mut mantissa = u128::from_le_bytes([ + s[4], s[5], s[6], s[7], // lo portion + s[8], s[9], s[10], s[11], // mid portion + s[12], s[13], s[14], s[15], // hi portion + 0, 0, 0, 0, // flags (cleared) + ]); // If our scale is not a multiple of 4, we need to go to the next // multiple. let groups_diff = scale % 4; if groups_diff > 0 { - let remainder: u16 = 4 - groups_diff; - let power = 10u32.pow(remainder as u32); - mul_by_u32(&mut mantissa, power); + let remainder = 4 - groups_diff as u32; + let power = 10u32.pow(remainder as u32) as u128; + + mantissa = mantissa * power; } // Array to store max mantissa of Decimal in Postgres decimal format. let mut digits = Vec::with_capacity(8); // Convert to base-10000. - while !mantissa.iter().all(|b| *b == 0) { - let remainder = div_by_u32(&mut mantissa, 10000) as u16; - digits.push(remainder as i16) + while mantissa != 0 { + digits.push((mantissa % 10_000) as i16); + mantissa /= 10_000; } // Change the endianness. @@ -182,48 +183,6 @@ impl Decode<'_, Postgres> for Decimal { } } -// Returns remainder -fn div_by_u32(bits: &mut [u32], divisor: u32) -> u32 { - assert_ne!(0, divisor); - - if divisor == 1 { - // dividend remains unchanged - 0 - } else { - let mut remainder = 0u32; - let divisor = u64::from(divisor); - - for part in bits.iter_mut().rev() { - let temp = (u64::from(remainder) << 32) + u64::from(*part); - remainder = (temp % divisor) as u32; - *part = (temp / divisor) as u32; - } - - remainder - } -} - -fn mul_by_u32(bits: &mut [u32], m: u32) -> u32 { - let mut overflow = 0; - - for num in bits.iter_mut() { - let (lo, hi) = mul_part(*num, m, overflow); - - *num = lo; - overflow = hi; - } - - overflow -} - -fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) { - let result = u64::from(left) * u64::from(right) + u64::from(high); - let hi = (result >> 32) as u32; - let lo = result as u32; - - (lo, hi) -} - #[cfg(test)] mod decimal_to_pgnumeric { use super::{Decimal, PgNumeric, PgNumericSign}; From 55fb0821aefa29d3ce27e0aba93f5fd1e656b9f1 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 7 Jul 2020 19:56:25 +0200 Subject: [PATCH 6/9] Conversions from `Decimal` to `PgMoney` --- sqlx-core/src/postgres/types/money.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 7d2f3fdaf1..5a25739206 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -44,6 +44,20 @@ impl PgMoney { pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { rust_decimal::Decimal::new(self.0, scale) } + + /// Convert a [`Decimal`] value into money using the correct precision + /// defined in the PostgreSQL settings. The default precision is two. + /// + /// [`Decimal`]: ../../types/struct.BigDecimal.html + #[cfg(feature = "decimal")] + pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self { + let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round(); + + let mut buf: [u8; 8] = [0; 8]; + buf.copy_from_slice(¢s.serialize()[4..12]); + + Self(i64::from_le_bytes(buf)) + } } impl Type for PgMoney { @@ -232,4 +246,12 @@ mod tests { PgMoney(12345).to_decimal(2) ); } + + #[test] + #[cfg(feature = "decimal")] + fn conversion_from_decimal_works() { + let dec = rust_decimal::Decimal::new(12345, 2); + + assert_eq!(PgMoney(12345), PgMoney::from_decimal(dec, 2)); + } } From a23fc393c43e99c55ae706f177b50b8955bf1cb0 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Thu, 9 Jul 2020 11:21:13 +0200 Subject: [PATCH 7/9] Simplify mantissa handling --- sqlx-core/src/postgres/types/decimal.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs index 4d17908d0c..2224e1a222 100644 --- a/sqlx-core/src/postgres/types/decimal.rs +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -1,8 +1,7 @@ -use std::convert::{TryFrom, TryInto}; - use num_bigint::{BigInt, Sign}; use num_traits::ToPrimitive; use rust_decimal::{prelude::Zero, Decimal}; +use std::convert::{TryFrom, TryInto}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -105,19 +104,10 @@ impl TryFrom<&'_ Decimal> for PgNumeric { // Bytes 5-8: lo portion of m // Bytes 9-12: mid portion of m // Bytes 13-16: high portion of m - let s = decimal.serialize(); + let mut mantissa = u128::from_le_bytes(decimal.serialize()); - // Moving the flags from the beginning of the array to the end, giving - // us a representation of u96 we can convert to u128. - // - // We also just set the flags as zero, so we don't need to chop them off - // from the resulting integer. - let mut mantissa = u128::from_le_bytes([ - s[4], s[5], s[6], s[7], // lo portion - s[8], s[9], s[10], s[11], // mid portion - s[12], s[13], s[14], s[15], // hi portion - 0, 0, 0, 0, // flags (cleared) - ]); + // chop off the flags + mantissa >>= 32; // If our scale is not a multiple of 4, we need to go to the next // multiple. From f5f053debab745e469031fde5a1fdea338cf1f73 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Thu, 9 Jul 2020 11:27:53 +0200 Subject: [PATCH 8/9] Add macro extensions for `Decimal` --- Cargo.toml | 4 ++-- sqlx-macros/Cargo.toml | 1 + sqlx-macros/src/database/mysql.rs | 3 +++ sqlx-macros/src/database/postgres.rs | 3 +++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4f89817c3d..71f18cee29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,8 +65,8 @@ sqlite = [ "sqlx-core/sqlite", "sqlx-macros/sqlite" ] mssql = [ "sqlx-core/mssql", "sqlx-macros/mssql" ] # types -bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] -decimal = ["sqlx-core/decimal"] +bigdecimal = [ "sqlx-core/bigdecimal", "sqlx-macros/bigdecimal" ] +decimal = [ "sqlx-core/decimal", "sqlx-macros/decimal" ] chrono = [ "sqlx-core/chrono", "sqlx-macros/chrono" ] ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-macros/ipnetwork" ] uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 966b67af36..0566655dba 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -34,6 +34,7 @@ mssql = [ "sqlx-core/mssql" ] # type bigdecimal = [ "sqlx-core/bigdecimal" ] +decimal = [ "sqlx-core/decimal" ] chrono = [ "sqlx-core/chrono" ] time = [ "sqlx-core/time" ] ipnetwork = [ "sqlx-core/ipnetwork" ] diff --git a/sqlx-macros/src/database/mysql.rs b/sqlx-macros/src/database/mysql.rs index 43c5291edc..d583a040c5 100644 --- a/sqlx-macros/src/database/mysql.rs +++ b/sqlx-macros/src/database/mysql.rs @@ -46,6 +46,9 @@ impl_database_ext! { #[cfg(feature = "bigdecimal")] sqlx::types::BigDecimal, + + #[cfg(feature = "decimal")] + sqlx::types::Decimal, }, ParamChecking::Weak, feature-types: info => info.__type_feature_gate(), diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 3394588a4a..48c60be59a 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -46,6 +46,9 @@ impl_database_ext! { #[cfg(feature = "bigdecimal")] sqlx::types::BigDecimal, + #[cfg(feature = "decimal")] + sqlx::types::Decimal, + #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, From 1d6079dd103f30a758643d669eccb39a0746eb26 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Thu, 9 Jul 2020 14:02:47 +0200 Subject: [PATCH 9/9] Document better! --- sqlx-core/src/postgres/types/money.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 5a25739206..1da24bdc2a 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -25,10 +25,11 @@ use std::{ pub struct PgMoney(pub i64); impl PgMoney { - /// Convert the money value into a [`BigDecimal`] using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert the money value into a [`BigDecimal`] using the correct + /// precision defined in the PostgreSQL settings. The default precision is + /// two. /// - /// [`BigDecimal`]: ../../types/struct.BigDecimal.html + /// [`BigDecimal`]: crate::types::BigDecimal #[cfg(feature = "bigdecimal")] pub fn to_bigdecimal(self, scale: i64) -> bigdecimal::BigDecimal { let digits = num_bigint::BigInt::from(self.0); @@ -39,7 +40,7 @@ impl PgMoney { /// Convert the money value into a [`Decimal`] using the correct precision /// defined in the PostgreSQL settings. The default precision is two. /// - /// [`Decimal`]: ../../types/struct.BigDecimal.html + /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { rust_decimal::Decimal::new(self.0, scale) @@ -48,7 +49,9 @@ impl PgMoney { /// Convert a [`Decimal`] value into money using the correct precision /// defined in the PostgreSQL settings. The default precision is two. /// - /// [`Decimal`]: ../../types/struct.BigDecimal.html + /// Conversion may involve a loss of precision. + /// + /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self { let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round();