diff --git a/CHANGELOG.md b/CHANGELOG.md index c6103bd113..406030f4f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to ### Added - cosmwasm-std: Implement `ops::Rem` for `Uint{64,128,256,512}`. +- cosmwasm-std: Implement `Decimal{,256}::checked_mul` and + `Decimal{,256}::checked_pow`. ### Changed diff --git a/packages/std/src/math/decimal.rs b/packages/std/src/math/decimal.rs index 4617967331..76762e7998 100644 --- a/packages/std/src/math/decimal.rs +++ b/packages/std/src/math/decimal.rs @@ -8,6 +8,7 @@ use std::str::FromStr; use thiserror::Error; use crate::errors::StdError; +use crate::OverflowError; use super::Fraction; use super::Isqrt; @@ -151,6 +152,53 @@ impl Decimal { Self::DECIMAL_PLACES as u32 } + /// Multiplies one `Decimal` by another, returning an `OverflowError` if an overflow occurred. + pub fn checked_mul(self, other: Self) -> Result { + let result_as_uint256 = self.numerator().full_mul(other.numerator()) + / Uint256::from_uint128(Self::DECIMAL_FRACTIONAL); // from_uint128 is a const method and should be "free" + result_as_uint256 + .try_into() + .map(Self) + .map_err(|_| OverflowError { + operation: crate::OverflowOperation::Mul, + operand1: self.to_string(), + operand2: other.to_string(), + }) + } + + /// Raises a value to the power of `exp`, returning an `OverflowError` if an overflow occurred. + pub fn checked_pow(self, exp: u32) -> Result { + // This uses the exponentiation by squaring algorithm: + // https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Basic_method + + fn inner(mut x: Decimal, mut n: u32) -> Result { + if n == 0 { + return Ok(Decimal::one()); + } + + let mut y = Decimal::one(); + + while n > 1 { + if n % 2 == 0 { + x = x.checked_mul(x)?; + n /= 2; + } else { + y = x.checked_mul(y)?; + x = x.checked_mul(x)?; + n = (n - 1) / 2; + } + } + + Ok(x * y) + } + + inner(self, exp).map_err(|_| OverflowError { + operation: crate::OverflowOperation::Pow, + operand1: self.to_string(), + operand2: exp.to_string(), + }) + } + /// Returns the approximate square root as a Decimal. /// /// This should not overflow or panic. @@ -952,6 +1000,37 @@ mod tests { let _value = Decimal::MAX * Decimal::percent(101); } + #[test] + fn decimal_checked_mul() { + let test_data = [ + (Decimal::zero(), Decimal::zero()), + (Decimal::zero(), Decimal::one()), + (Decimal::one(), Decimal::zero()), + (Decimal::percent(10), Decimal::zero()), + (Decimal::percent(10), Decimal::percent(5)), + (Decimal::MAX, Decimal::one()), + (Decimal::MAX / 2u128.into(), Decimal::percent(200)), + (Decimal::permille(6), Decimal::permille(13)), + ]; + + // The regular std::ops::Mul is our source of truth for these tests. + for (x, y) in test_data.iter().cloned() { + assert_eq!(x * y, x.checked_mul(y).unwrap()); + } + } + + #[test] + fn decimal_checked_mul_overflow() { + assert_eq!( + Decimal::MAX.checked_mul(Decimal::percent(200)), + Err(OverflowError { + operation: crate::OverflowOperation::Mul, + operand1: Decimal::MAX.to_string(), + operand2: Decimal::percent(200).to_string(), + }) + ); + } + #[test] // in this test the Decimal is on the right fn uint128_decimal_multiply() { @@ -1068,6 +1147,91 @@ mod tests { ); } + #[test] + fn decimal_checked_pow() { + for exp in 0..10 { + assert_eq!(Decimal::one().checked_pow(exp).unwrap(), Decimal::one()); + } + + // This case is mathematically undefined but we ensure consistency with Rust stdandard types + // https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=20df6716048e77087acd40194b233494 + assert_eq!(Decimal::zero().checked_pow(0).unwrap(), Decimal::one()); + + for exp in 1..10 { + assert_eq!(Decimal::zero().checked_pow(exp).unwrap(), Decimal::zero()); + } + + for num in &[ + Decimal::percent(50), + Decimal::percent(99), + Decimal::percent(200), + ] { + assert_eq!(num.checked_pow(0).unwrap(), Decimal::one()) + } + + assert_eq!( + Decimal::percent(20).checked_pow(2).unwrap(), + Decimal::percent(4) + ); + + assert_eq!( + Decimal::percent(20).checked_pow(3).unwrap(), + Decimal::permille(8) + ); + + assert_eq!( + Decimal::percent(200).checked_pow(4).unwrap(), + Decimal::percent(1600) + ); + + assert_eq!( + Decimal::percent(200).checked_pow(4).unwrap(), + Decimal::percent(1600) + ); + + assert_eq!( + Decimal::percent(700).checked_pow(5).unwrap(), + Decimal::percent(1680700) + ); + + assert_eq!( + Decimal::percent(700).checked_pow(8).unwrap(), + Decimal::percent(576480100) + ); + + assert_eq!( + Decimal::percent(700).checked_pow(10).unwrap(), + Decimal::percent(28247524900) + ); + + assert_eq!( + Decimal::percent(120).checked_pow(123).unwrap(), + Decimal(5486473221892422150877397607u128.into()) + ); + + assert_eq!( + Decimal::percent(10).checked_pow(2).unwrap(), + Decimal(10000000000000000u128.into()) + ); + + assert_eq!( + Decimal::percent(10).checked_pow(18).unwrap(), + Decimal(1u128.into()) + ); + } + + #[test] + fn decimal_checked_pow_overflow() { + assert_eq!( + Decimal::MAX.checked_pow(2), + Err(OverflowError { + operation: crate::OverflowOperation::Pow, + operand1: Decimal::MAX.to_string(), + operand2: "2".to_string(), + }) + ); + } + #[test] fn decimal_to_string() { // Integers diff --git a/packages/std/src/math/decimal256.rs b/packages/std/src/math/decimal256.rs index 2da086762d..34ece9bf45 100644 --- a/packages/std/src/math/decimal256.rs +++ b/packages/std/src/math/decimal256.rs @@ -8,7 +8,7 @@ use std::str::FromStr; use thiserror::Error; use crate::errors::StdError; -use crate::Uint512; +use crate::{OverflowError, Uint512}; use super::Fraction; use super::Isqrt; @@ -164,6 +164,53 @@ impl Decimal256 { Self::DECIMAL_PLACES as u32 } + /// Multiplies one `Decimal256` by another, returning an `OverflowError` if an overflow occurred. + pub fn checked_mul(self, other: Self) -> Result { + let result_as_uint512 = self.numerator().full_mul(other.numerator()) + / Uint512::from_uint256(Self::DECIMAL_FRACTIONAL); // from_uint128 is a const method and should be "free" + result_as_uint512 + .try_into() + .map(Self) + .map_err(|_| OverflowError { + operation: crate::OverflowOperation::Mul, + operand1: self.to_string(), + operand2: other.to_string(), + }) + } + + /// Raises a value to the power of `exp`, returning an `OverflowError` if an overflow occurred. + pub fn checked_pow(self, exp: u32) -> Result { + // This uses the exponentiation by squaring algorithm: + // https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Basic_method + + fn inner(mut x: Decimal256, mut n: u32) -> Result { + if n == 0 { + return Ok(Decimal256::one()); + } + + let mut y = Decimal256::one(); + + while n > 1 { + if n % 2 == 0 { + x = x.checked_mul(x)?; + n /= 2; + } else { + y = x.checked_mul(y)?; + x = x.checked_mul(x)?; + n = (n - 1) / 2; + } + } + + Ok(x * y) + } + + inner(self, exp).map_err(|_| OverflowError { + operation: crate::OverflowOperation::Pow, + operand1: self.to_string(), + operand2: exp.to_string(), + }) + } + /// Returns the approximate square root as a Decimal256. /// /// This should not overflow or panic. @@ -1031,6 +1078,37 @@ mod tests { let _value = Decimal256::MAX * Decimal256::percent(101); } + #[test] + fn decimal256_checked_mul() { + let test_data = [ + (Decimal256::zero(), Decimal256::zero()), + (Decimal256::zero(), Decimal256::one()), + (Decimal256::one(), Decimal256::zero()), + (Decimal256::percent(10), Decimal256::zero()), + (Decimal256::percent(10), Decimal256::percent(5)), + (Decimal256::MAX, Decimal256::one()), + (Decimal256::MAX / 2u128.into(), Decimal256::percent(200)), + (Decimal256::permille(6), Decimal256::permille(13)), + ]; + + // The regular std::ops::Mul is our source of truth for these tests. + for (x, y) in test_data.iter().cloned() { + assert_eq!(x * y, x.checked_mul(y).unwrap()); + } + } + + #[test] + fn decimal256_checked_mul_overflow() { + assert_eq!( + Decimal256::MAX.checked_mul(Decimal256::percent(200)), + Err(OverflowError { + operation: crate::OverflowOperation::Mul, + operand1: Decimal256::MAX.to_string(), + operand2: Decimal256::percent(200).to_string(), + }) + ); + } + #[test] // in this test the Decimal256 is on the right fn uint128_decimal_multiply() { @@ -1151,6 +1229,100 @@ mod tests { ); } + #[test] + fn decimal256_checked_pow() { + for exp in 0..10 { + assert_eq!( + Decimal256::one().checked_pow(exp).unwrap(), + Decimal256::one() + ); + } + + // This case is mathematically undefined but we ensure consistency with Rust stdandard types + // https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=20df6716048e77087acd40194b233494 + assert_eq!( + Decimal256::zero().checked_pow(0).unwrap(), + Decimal256::one() + ); + + for exp in 1..10 { + assert_eq!( + Decimal256::zero().checked_pow(exp).unwrap(), + Decimal256::zero() + ); + } + + for num in &[ + Decimal256::percent(50), + Decimal256::percent(99), + Decimal256::percent(200), + ] { + assert_eq!(num.checked_pow(0).unwrap(), Decimal256::one()) + } + + assert_eq!( + Decimal256::percent(20).checked_pow(2).unwrap(), + Decimal256::percent(4) + ); + + assert_eq!( + Decimal256::percent(20).checked_pow(3).unwrap(), + Decimal256::permille(8) + ); + + assert_eq!( + Decimal256::percent(200).checked_pow(4).unwrap(), + Decimal256::percent(1600) + ); + + assert_eq!( + Decimal256::percent(200).checked_pow(4).unwrap(), + Decimal256::percent(1600) + ); + + assert_eq!( + Decimal256::percent(700).checked_pow(5).unwrap(), + Decimal256::percent(1680700) + ); + + assert_eq!( + Decimal256::percent(700).checked_pow(8).unwrap(), + Decimal256::percent(576480100) + ); + + assert_eq!( + Decimal256::percent(700).checked_pow(10).unwrap(), + Decimal256::percent(28247524900) + ); + + assert_eq!( + Decimal256::percent(120).checked_pow(123).unwrap(), + Decimal256(5486473221892422150877397607u128.into()) + ); + + assert_eq!( + Decimal256::percent(10).checked_pow(2).unwrap(), + Decimal256(10000000000000000u128.into()) + ); + + assert_eq!( + Decimal256::percent(10).checked_pow(18).unwrap(), + Decimal256(1u128.into()) + ); + } + + #[test] + fn decimal256_checked_pow_overflow() { + assert_eq!( + Decimal256::MAX.checked_pow(2), + Err(OverflowError { + operation: crate::OverflowOperation::Pow, + operand1: Decimal256::MAX.to_string(), + operand2: "2".to_string(), + }) + ); + } + #[test] fn decimal256_to_string() { // Integers