diff --git a/Cargo.toml b/Cargo.toml index ba95924064d..f8e285c1c64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ hashbrown = { version = ">= 0.9, < 0.14", optional = true } indexmap = { version = "1.6", optional = true } num-bigint = { version = "0.4", optional = true } num-complex = { version = ">= 0.2, < 0.5", optional = true } +rust_decimal = { version = "1.0.0", default-features = false, optional = true } serde = { version = "1.0", optional = true } [dev-dependencies] @@ -53,6 +54,7 @@ send_wrapper = "0.6" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.61" rayon = "1.0.2" +rust_decimal = { version = "1.8.0", features = ["std"] } widestring = "0.5.1" [build-dependencies] @@ -110,6 +112,7 @@ full = [ "eyre", "anyhow", "experimental-inspect", + "rust_decimal", ] [[bench]] @@ -120,6 +123,11 @@ harness = false name = "bench_err" harness = false +[[bench]] +name = "bench_decimal" +harness = false +required-features = ["rust_decimal"] + [[bench]] name = "bench_dict" harness = false @@ -173,5 +181,5 @@ members = [ [package.metadata.docs.rs] no-default-features = true -features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono"] +features = ["macros", "num-bigint", "num-complex", "hashbrown", "serde", "multiple-pymethods", "indexmap", "eyre", "chrono", "rust_decimal"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/benches/bench_decimal.rs b/benches/bench_decimal.rs new file mode 100644 index 00000000000..7a370ac3505 --- /dev/null +++ b/benches/bench_decimal.rs @@ -0,0 +1,32 @@ +use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion}; + +use pyo3::prelude::*; +use pyo3::types::PyDict; +use rust_decimal::Decimal; + +fn decimal_via_extract(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + let locals = PyDict::new(py); + py.run( + r#" +import decimal +py_dec = decimal.Decimal("0.0") +"#, + None, + Some(locals), + ) + .unwrap(); + let py_dec = locals.get_item("py_dec").unwrap(); + + b.iter(|| { + let _: Decimal = black_box(py_dec).extract().unwrap(); + }); + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("decimal_via_extract", decimal_via_extract); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/guide/src/features.md b/guide/src/features.md index f24bb2f109f..d9a051e4641 100644 --- a/guide/src/features.md +++ b/guide/src/features.md @@ -129,6 +129,10 @@ Adds a dependency on [num-bigint](https://docs.rs/num-bigint) and enables conver Adds a dependency on [num-complex](https://docs.rs/num-complex) and enables conversions into its [`Complex`](https://docs.rs/num-complex/latest/num_complex/struct.Complex.html) type. +### `rust_decimal` + +Adds a dependency on [rust_decimal](https://docs.rs/rust_decimal) and enables conversions into its [`Decimal`](https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html) type. + ### `serde` Enables (de)serialization of `Py` objects via [serde](https://serde.rs/). diff --git a/newsfragments/3016.added.md b/newsfragments/3016.added.md new file mode 100644 index 00000000000..e421dadd6ca --- /dev/null +++ b/newsfragments/3016.added.md @@ -0,0 +1 @@ +Added support for converting to and from Python's `decimal.Decimal` and `rust_decimal::Decimal`. diff --git a/noxfile.py b/noxfile.py index 0d02e0c3f48..ff7ab7f0ac9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -422,6 +422,10 @@ def set_minimal_package_versions(session: nox.Session, venv_backend="none"): "examples/word-count", ) min_pkg_versions = { + # newer versions of rust_decimal want newer arrayvec + "rust_decimal": "1.18.0", + # newer versions of arrayvec use const generics (Rust 1.51+) + "arrayvec": "0.5.2", "csv": "1.1.6", "indexmap": "1.6.2", "inventory": "0.3.4", diff --git a/src/conversions/mod.rs b/src/conversions/mod.rs index 20d19e81234..5544dc23532 100644 --- a/src/conversions/mod.rs +++ b/src/conversions/mod.rs @@ -7,5 +7,6 @@ pub mod hashbrown; pub mod indexmap; pub mod num_bigint; pub mod num_complex; +pub mod rust_decimal; pub mod serde; mod std; diff --git a/src/conversions/rust_decimal.rs b/src/conversions/rust_decimal.rs new file mode 100644 index 00000000000..3e42a352b87 --- /dev/null +++ b/src/conversions/rust_decimal.rs @@ -0,0 +1,221 @@ +#![cfg(feature = "rust_decimal")] +//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type. +//! +//! This is useful for converting Python's decimal.Decimal into and from a native Rust type. +//! +//! # Setup +//! +//! To use this feature, add to your **`Cargo.toml`**: +//! +//! ```toml +//! [dependencies] +//! rust_decimal = "1.0" +// workaround for `extended_key_value_attributes`: https://github.com/rust-lang/rust/issues/82768#issuecomment-803935643 +#![cfg_attr(docsrs, cfg_attr(docsrs, doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")))] +#![cfg_attr( + not(docsrs), + doc = "pyo3 = { version = ..., features = [\"rust_decimal\"] }" +)] +//! ``` +//! +//! Note that you must use a compatible version of rust_decimal and PyO3. +//! The required rust_decimal version may vary based on the version of PyO3. +//! +//! # Example +//! +//! Rust code to create a function that adds one to a Decimal +//! +//! ```rust +//! use rust_decimal::Decimal; +//! use pyo3::prelude::*; +//! +//! #[pyfunction] +//! fn add_one(d: Decimal) -> Decimal { +//! d + Decimal::ONE +//! } +//! +//! #[pymodule] +//! fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +//! m.add_function(wrap_pyfunction!(add_one, m)?)?; +//! Ok(()) +//! } +//! ``` +//! +//! Python code that validates the functionality +//! +//! +//! ```python +//! from my_module import add_one +//! from decimal import Decimal +//! +//! d = Decimal("2") +//! value = add_one(d) +//! +//! assert d + 1 == value +//! ``` + +use crate::exceptions::PyValueError; +use crate::once_cell::GILOnceCell; +use crate::types::PyType; +use crate::{intern, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject}; +use rust_decimal::Decimal; +use std::str::FromStr; + +impl FromPyObject<'_> for Decimal { + fn extract(obj: &PyAny) -> PyResult { + // use the string representation to not be lossy + if let Ok(val) = obj.extract() { + Ok(Decimal::new(val, 0)) + } else { + Decimal::from_str(obj.str()?.to_str()?) + .map_err(|e| PyValueError::new_err(e.to_string())) + } + } +} + +static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); + +fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> { + DECIMAL_CLS + .get_or_try_init(py, || { + py.import(intern!(py, "decimal"))? + .getattr(intern!(py, "Decimal"))? + .extract() + }) + .map(|ty| ty.as_ref(py)) +} + +impl ToPyObject for Decimal { + fn to_object(&self, py: Python<'_>) -> PyObject { + // TODO: handle error gracefully when ToPyObject can error + // look up the decimal.Decimal + let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal"); + // now call the constructor with the Rust Decimal string-ified + // to not be lossy + let ret = dec_cls + .call1((self.to_string(),)) + .expect("failed to call decimal.Decimal(value)"); + ret.to_object(py) + } +} + +impl IntoPy for Decimal { + fn into_py(self, py: Python<'_>) -> PyObject { + self.to_object(py) + } +} + +#[cfg(test)] +mod test_rust_decimal { + use super::*; + use crate::err::PyErr; + use crate::types::PyDict; + use rust_decimal::Decimal; + + #[cfg(not(target_arch = "wasm32"))] + use proptest::prelude::*; + + macro_rules! convert_constants { + ($name:ident, $rs:expr, $py:literal) => { + #[test] + fn $name() { + Python::with_gil(|py| { + let rs_orig = $rs; + let rs_dec = rs_orig.into_py(py); + let locals = PyDict::new(py); + locals.set_item("rs_dec", &rs_dec).unwrap(); + // Checks if Rust Decimal -> Python Decimal conversion is correct + py.run( + &format!( + "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec", + $py + ), + None, + Some(locals), + ) + .unwrap(); + // Checks if Python Decimal -> Rust Decimal conversion is correct + let py_dec = locals.get_item("py_dec").unwrap(); + let py_result: Decimal = FromPyObject::extract(py_dec).unwrap(); + assert_eq!(rs_orig, py_result); + }) + } + }; + } + + convert_constants!(convert_zero, Decimal::ZERO, "0"); + convert_constants!(convert_one, Decimal::ONE, "1"); + convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1"); + convert_constants!(convert_two, Decimal::TWO, "2"); + convert_constants!(convert_ten, Decimal::TEN, "10"); + convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100"); + convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000"); + + #[cfg(not(target_arch = "wasm32"))] + proptest! { + #[test] + fn test_roundtrip( + lo in any::(), + mid in any::(), + high in any::(), + negative in any::(), + scale in 0..28u32 + ) { + let num = Decimal::from_parts(lo, mid, high, negative, scale); + Python::with_gil(|py| { + let rs_dec = num.into_py(py); + let locals = PyDict::new(py); + locals.set_item("rs_dec", &rs_dec).unwrap(); + py.run( + &format!( + "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec", + num), + None, Some(locals)).unwrap(); + let roundtripped: Decimal = rs_dec.extract(py).unwrap(); + assert_eq!(num, roundtripped); + }) + } + + #[test] + fn test_integers(num in any::()) { + Python::with_gil(|py| { + let py_num = num.into_py(py); + let roundtripped: Decimal = py_num.extract(py).unwrap(); + let rs_dec = Decimal::new(num, 0); + assert_eq!(rs_dec, roundtripped); + }) + } + } + + #[test] + fn test_nan() { + Python::with_gil(|py| { + let locals = PyDict::new(py); + py.run( + "import decimal\npy_dec = decimal.Decimal(\"NaN\")", + None, + Some(locals), + ) + .unwrap(); + let py_dec = locals.get_item("py_dec").unwrap(); + let roundtripped: Result = FromPyObject::extract(py_dec); + assert!(roundtripped.is_err()); + }) + } + + #[test] + fn test_infinity() { + Python::with_gil(|py| { + let locals = PyDict::new(py); + py.run( + "import decimal\npy_dec = decimal.Decimal(\"Infinity\")", + None, + Some(locals), + ) + .unwrap(); + let py_dec = locals.get_item("py_dec").unwrap(); + let roundtripped: Result = FromPyObject::extract(py_dec); + assert!(roundtripped.is_err()); + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index ae37d65b1f2..1a33a135476 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,6 +99,8 @@ //! [`BigUint`] types. //! - [`num-complex`]: Enables conversions between Python objects and [num-complex]'s [`Complex`] //! type. +//! - [`rust_decimal`]: Enables conversions between Python's decimal.Decimal and [rust_decimal]'s +//! [`Decimal`] type. //! - [`serde`]: Allows implementing [serde]'s [`Serialize`] and [`Deserialize`] traits for //! [`Py`]`` for all `T` that implement [`Serialize`] and [`Deserialize`]. //! @@ -275,6 +277,9 @@ //! [`num-bigint`]: ./num_bigint/index.html "Documentation about the `num-bigint` feature." //! [`num-complex`]: ./num_complex/index.html "Documentation about the `num-complex` feature." //! [`pyo3-build-config`]: https://docs.rs/pyo3-build-config +//! [rust_decimal]: https://docs.rs/rust_decimal +//! [`rust_decimal`]: ./rust_decimal/index.html "Documenation about the `rust_decimal` feature." +//! [`Decimal`]: https://docs.rs/rust_decimal/latest/rust_decimal/struct.Decimal.html //! [`serde`]: <./serde/index.html> "Documentation about the `serde` feature." //! [calling_rust]: https://pyo3.rs/latest/python_from_rust.html "Calling Python from Rust - PyO3 user guide" //! [examples subdirectory]: https://github.com/PyO3/pyo3/tree/main/examples