Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add random SQL function #303

Merged
merged 3 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
regex = { version = "^1.4.3", optional = true }
lazy_static = { version = "^1.4.0", optional = true }
smallvec = { version = "1.6", features = ["union"] }
rand = "0.8"

[dev-dependencies]
rand = "0.8"
criterion = "0.3"
tempfile = "3"
doc-comment = "0.3"
Expand Down
11 changes: 10 additions & 1 deletion datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub enum BuiltinScalarFunction {
NullIf,
/// octet_length
OctetLength,
/// random
Random,
/// regexp_replace
RegexpReplace,
/// repeat
Expand Down Expand Up @@ -219,7 +221,10 @@ impl BuiltinScalarFunction {
/// an allowlist of functions to take zero arguments, so that they will get special treatment
/// while executing.
fn supports_zero_argument(&self) -> bool {
matches!(self, BuiltinScalarFunction::Now)
matches!(
self,
BuiltinScalarFunction::Random | BuiltinScalarFunction::Now
)
}
}

Expand Down Expand Up @@ -275,6 +280,7 @@ impl FromStr for BuiltinScalarFunction {
"md5" => BuiltinScalarFunction::MD5,
"nullif" => BuiltinScalarFunction::NullIf,
"octet_length" => BuiltinScalarFunction::OctetLength,
"random" => BuiltinScalarFunction::Random,
"regexp_replace" => BuiltinScalarFunction::RegexpReplace,
"repeat" => BuiltinScalarFunction::Repeat,
"replace" => BuiltinScalarFunction::Replace,
Expand Down Expand Up @@ -438,6 +444,7 @@ pub fn return_type(
));
}
}),
BuiltinScalarFunction::Random => Ok(DataType::Float64),
BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
DataType::Utf8 => DataType::Utf8,
Expand Down Expand Up @@ -742,6 +749,7 @@ pub fn create_physical_expr(
BuiltinScalarFunction::Ln => math_expressions::ln,
BuiltinScalarFunction::Log10 => math_expressions::log10,
BuiltinScalarFunction::Log2 => math_expressions::log2,
BuiltinScalarFunction::Random => math_expressions::random,
BuiltinScalarFunction::Round => math_expressions::round,
BuiltinScalarFunction::Signum => math_expressions::signum,
BuiltinScalarFunction::Sin => math_expressions::sin,
Expand Down Expand Up @@ -1307,6 +1315,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
]),
BuiltinScalarFunction::Random => Signature::Exact(vec![]),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down
36 changes: 35 additions & 1 deletion datafusion/src/physical_plan/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
// under the License.

//! Math expressions

use super::{ColumnarValue, ScalarValue};
use crate::error::{DataFusionError, Result};
use arrow::array::{Float32Array, Float64Array};
use arrow::datatypes::DataType;
use rand::{thread_rng, Rng};
use std::iter;
use std::sync::Arc;

macro_rules! downcast_compute_op {
Expand Down Expand Up @@ -100,3 +101,36 @@ math_unary_function!("exp", exp);
math_unary_function!("ln", ln);
math_unary_function!("log2", log2);
math_unary_function!("log10", log10);

/// random SQL function
pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
ColumnarValue::Array(array) => array.len(),
_ => {
return Err(DataFusionError::Internal(
"Expect random function to take no param".to_string(),
))
}
};
let mut rng = thread_rng();
let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
let array = Float64Array::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}

#[cfg(test)]
mod tests {

use super::*;
use arrow::array::{Float64Array, NullArray};

#[test]
fn test_random_expression() {
let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))];
let array = random(&args).expect("fail").into_array(1);
let floats = array.as_any().downcast_ref::<Float64Array>().expect("fail");

assert_eq!(floats.len(), 1);
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
}
}
1 change: 0 additions & 1 deletion datafusion/src/physical_plan/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ pub fn data_types(
if current_types.is_empty() {
return Ok(vec![]);
}

let valid_types = get_valid_types(signature, current_types)?;

if valid_types
Expand Down
11 changes: 11 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,17 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_random_expression() -> Result<()> {
let mut ctx = create_ctx()?;
let sql = "SELECT random() r1";
let actual = execute(&mut ctx, sql).await;
let r1 = actual[0][0].parse::<f64>().unwrap();
assert!(0.0 <= r1);
assert!(r1 < 1.0);
Ok(())
}

#[tokio::test]
async fn test_cast_expressions_error() -> Result<()> {
// sin(utf8) should error
Expand Down