Skip to content

Commit

Permalink
feat(expr): support for sqrt function (risingwavelabs#9017)
Browse files Browse the repository at this point in the history
Co-authored-by: root <root@HQ-10MSTD3EY.roblox.local>
  • Loading branch information
lyang24 and root authored Apr 10, 2023
1 parent 95ab15c commit 1e3221b
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions e2e_test/batch/functions/sqrt.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# testing sqrt(double precision)
query T
SELECT abs(sqrt('1004.3') - '31.690692639953454') < 1e-12;
----
t

query T
SELECT abs(sqrt('1.2345678901234e+200') - '1.1111111061110856e+100') < 1e-12;
----
t

query T
SELECT abs(sqrt('1.2345678901234e-200') - '1.1111111061110855e-100') < 1e-12;
----
t

# testing sqrt(numeric)
query T
SELECT abs(sqrt(1004.3) - 31.690692639953453690117860318) < 1e-15;
----
t

query T
SELECT abs(sqrt(82416.3252::decimal) - 287.08243624436518286386154499) < 1e-15;
----
t
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ message ExprNode {
ACOS = 250;
ATAN = 251;
ATAN2 = 252;
SQRT = 253;

// Boolean comparison
IS_TRUE = 301;
Expand Down
1 change: 1 addition & 0 deletions src/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ risingwave_common = { path = "../common" }
risingwave_expr_macro = { path = "macro" }
risingwave_pb = { path = "../prost" }
risingwave_udf = { path = "../udf" }
rust_decimal = { version = "1", features = ["db-postgres", "maths"] }
speedate = "0.7.0"
static_assertions = "1"
thiserror = "1"
Expand Down
72 changes: 70 additions & 2 deletions src/expr/src/vector_op/arithmetic_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use std::convert::TryInto;
use std::fmt::Debug;

use chrono::{Duration, NaiveDateTime};
use num_traits::real::Real;
use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Signed, Zero};
use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Float, Signed, Zero};
use risingwave_common::types::{CheckedAdd, Date, Decimal, Interval, Time, Timestamp, F64};
use risingwave_expr_macro::function;
use rust_decimal::MathematicalOps;

use crate::{ExprError, Result};

Expand Down Expand Up @@ -327,6 +327,39 @@ where
r.mul_float(l).ok_or(ExprError::NumericOutOfRange)
}

#[function("sqrt(float64) -> float64")]
pub fn sqrt_f64(expr: F64) -> Result<F64> {
if expr < F64::from(0.0) {
return Err(ExprError::InvalidParam {
name: "sqrt input",
reason: "input cannot be negative value".to_string(),
});
}
// Edge cases: nan, inf, negative zero should return itself.
match expr.is_nan() || expr == f64::INFINITY || expr.is_negative() {
true => Ok(expr),
false => Ok(expr.sqrt()),
}
}

#[function("sqrt(decimal) -> decimal")]
pub fn sqrt_decimal(expr: Decimal) -> Result<Decimal> {
match expr {
Decimal::NaN | Decimal::PositiveInf => Ok(expr),
Decimal::Normalized(value) => match value.sqrt() {
Some(res) => Ok(Decimal::from(res)),
None => Err(ExprError::InvalidParam {
name: "sqrt input",
reason: "input cannot be negative value".to_string(),
}),
},
Decimal::NegativeInf => Err(ExprError::InvalidParam {
name: "sqrt input",
reason: "input cannot be negative value".to_string(),
}),
}
}

#[cfg(test)]
mod tests {
use std::str::FromStr;
Expand Down Expand Up @@ -428,6 +461,41 @@ mod tests {
NaiveDateTime::parse_from_str("1993-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
)
);
assert_eq!(sqrt_f64(F64::from(25.00)).unwrap(), F64::from(5.0));
assert_eq!(
sqrt_f64(F64::from(107)).unwrap(),
F64::from(10.344080432788601)
);
assert_eq!(
sqrt_f64(F64::from(12.234567)).unwrap(),
F64::from(3.4977945908815173)
);
assert!(sqrt_f64(F64::from(-25.00)).is_err());
// sqrt edge cases.
assert_eq!(sqrt_f64(F64::from(f64::NAN)).unwrap(), F64::from(f64::NAN));
assert_eq!(
sqrt_f64(F64::from(f64::neg_zero())).unwrap(),
F64::from(f64::neg_zero())
);
assert_eq!(
sqrt_f64(F64::from(f64::INFINITY)).unwrap(),
F64::from(f64::INFINITY)
);
assert!(sqrt_f64(F64::from(f64::NEG_INFINITY)).is_err());
assert_eq!(sqrt_decimal(dec("25.0")).unwrap(), dec("5.0"));
assert_eq!(
sqrt_decimal(dec("107")).unwrap(),
dec("10.344080432788600469738599442")
);
assert_eq!(
sqrt_decimal(dec("12.234567")).unwrap(),
dec("3.4977945908815171589625746860")
);
assert!(sqrt_decimal(dec("-25.0")).is_err());
assert_eq!(sqrt_decimal(dec("nan")).unwrap(), dec("nan"));
assert_eq!(sqrt_decimal(dec("inf")).unwrap(), dec("inf"));
assert_eq!(sqrt_decimal(dec("-0")).unwrap(), dec("-0"));
assert!(sqrt_decimal(dec("-inf")).is_err());
}

fn dec(s: &str) -> Decimal {
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ impl Binder {
("asin", raw_call(ExprType::Asin)),
("acos", raw_call(ExprType::Acos)),
("atan", raw_call(ExprType::Atan)),
("atan2", raw_call(ExprType::Atan2)),
("atan2", raw_call(ExprType::Atan2)),
("sqrt", raw_call(ExprType::Sqrt)),

(
"to_timestamp",
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ impl Binder {
UnaryOperator::Plus => {
return self.rewrite_positive(expr);
}
UnaryOperator::PGSquareRoot => ExprType::Sqrt,
_ => {
return Err(ErrorCode::NotImplemented(
format!("unsupported unary expression: {:?}", op),
Expand Down
4 changes: 2 additions & 2 deletions src/tests/regress/data/sql/float8.sql
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ select floor(f1) as floor_f1 from float8_tbl f;
SET extra_float_digits = 0;

-- square root
--@ SELECT sqrt(double precision '64') AS eight;
SELECT sqrt(double precision '64') AS eight;

--@ SELECT |/ double precision '64' AS eight;
SELECT |/ double precision '64' AS eight;

--@ SELECT f.f1, |/f.f1 AS sqrt_f1
--@ FROM FLOAT8_TBL f
Expand Down

0 comments on commit 1e3221b

Please sign in to comment.