diff --git a/python/Cargo.toml b/python/Cargo.toml index ee99359a82f0..fe84e5234c33 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -31,7 +31,7 @@ libc = "0.2" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" pyo3 = { version = "0.14.1", features = ["extension-module"] } -datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "e4df37a4001423909964348289360da66acdd0a3" } +datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" } [lib] name = "datafusion" diff --git a/python/src/functions.rs b/python/src/functions.rs index 415490743185..23f010a6ae45 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -20,7 +20,7 @@ use crate::udf; use crate::{expression, types::PyDataType}; use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan; -use pyo3::{prelude::*, wrap_pyfunction}; +use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction}; use std::sync::Arc; /// Expression representing a column on the existing plan. @@ -76,11 +76,33 @@ fn now() -> expression::Expression { #[pyfunction] fn random() -> expression::Expression { expression::Expression { - // here lit(0) is a stub for conform to arity - expr: logical_plan::random(logical_plan::lit(0)), + expr: logical_plan::random(), } } +/// Concatenates the text representations of all the arguments. +/// NULL arguments are ignored. +#[pyfunction(args = "*")] +fn concat(args: &PyTuple) -> PyResult { + let expressions = expression::from_tuple(args)?; + let args = expressions.into_iter().map(|e| e.expr).collect::>(); + Ok(expression::Expression { + expr: logical_plan::concat(&args), + }) +} + +/// Concatenates all but the first argument, with separators. +/// The first argument is used as the separator string, and should not be NULL. +/// Other NULL arguments are ignored. +#[pyfunction(sep, args = "*")] +fn concat_ws(sep: String, args: &PyTuple) -> PyResult { + let expressions = expression::from_tuple(args)?; + let args = expressions.into_iter().map(|e| e.expr).collect::>(); + Ok(expression::Expression { + expr: logical_plan::concat_ws(sep, &args), + }) +} + macro_rules! define_unary_function { ($NAME: ident) => { #[doc = "This function is not documented yet"] @@ -132,7 +154,6 @@ define_unary_function!( "Returns number of characters in the string." ); define_unary_function!(chr, "Returns the character with the given code."); -define_unary_function!(concat_ws, "Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored."); define_unary_function!(initcap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); define_unary_function!(left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); define_unary_function!(lower, "Converts the string to all lower case"); @@ -179,15 +200,6 @@ define_unary_function!(min); define_unary_function!(max); define_unary_function!(count); -/* -#[pyfunction] -fn concat(value: Vec) -> expression::Expression { - expression::Expression { - expr: logical_plan::concat(value.into_iter().map(|e| e.expr)), - } -} - */ - pub(crate) fn create_udf( fun: PyObject, input_types: Vec, @@ -250,26 +262,39 @@ fn udaf( } pub fn init(module: &PyModule) -> PyResult<()> { - module.add_function(wrap_pyfunction!(col, module)?)?; - module.add_function(wrap_pyfunction!(lit, module)?)?; - // see https://github.com/apache/arrow-datafusion/issues/226 - //module.add_function(wrap_pyfunction!(concat, module)?)?; - module.add_function(wrap_pyfunction!(udf, module)?)?; + module.add_function(wrap_pyfunction!(abs, module)?)?; + module.add_function(wrap_pyfunction!(acos, module)?)?; module.add_function(wrap_pyfunction!(array, module)?)?; module.add_function(wrap_pyfunction!(ascii, module)?)?; + module.add_function(wrap_pyfunction!(asin, module)?)?; + module.add_function(wrap_pyfunction!(atan, module)?)?; + module.add_function(wrap_pyfunction!(avg, module)?)?; module.add_function(wrap_pyfunction!(bit_length, module)?)?; + module.add_function(wrap_pyfunction!(btrim, module)?)?; + module.add_function(wrap_pyfunction!(ceil, module)?)?; module.add_function(wrap_pyfunction!(character_length, module)?)?; module.add_function(wrap_pyfunction!(chr, module)?)?; - module.add_function(wrap_pyfunction!(btrim, module)?)?; + module.add_function(wrap_pyfunction!(col, module)?)?; module.add_function(wrap_pyfunction!(concat_ws, module)?)?; + module.add_function(wrap_pyfunction!(concat, module)?)?; + module.add_function(wrap_pyfunction!(cos, module)?)?; + module.add_function(wrap_pyfunction!(count, module)?)?; + module.add_function(wrap_pyfunction!(exp, module)?)?; + module.add_function(wrap_pyfunction!(floor, module)?)?; module.add_function(wrap_pyfunction!(in_list, module)?)?; module.add_function(wrap_pyfunction!(initcap, module)?)?; module.add_function(wrap_pyfunction!(left, module)?)?; + module.add_function(wrap_pyfunction!(lit, module)?)?; + module.add_function(wrap_pyfunction!(ln, module)?)?; + module.add_function(wrap_pyfunction!(log10, module)?)?; + module.add_function(wrap_pyfunction!(log2, module)?)?; module.add_function(wrap_pyfunction!(lower, module)?)?; module.add_function(wrap_pyfunction!(lpad, module)?)?; + module.add_function(wrap_pyfunction!(ltrim, module)?)?; + module.add_function(wrap_pyfunction!(max, module)?)?; module.add_function(wrap_pyfunction!(md5, module)?)?; + module.add_function(wrap_pyfunction!(min, module)?)?; module.add_function(wrap_pyfunction!(now, module)?)?; - module.add_function(wrap_pyfunction!(ltrim, module)?)?; module.add_function(wrap_pyfunction!(octet_length, module)?)?; module.add_function(wrap_pyfunction!(random, module)?)?; module.add_function(wrap_pyfunction!(regexp_replace, module)?)?; @@ -277,43 +302,29 @@ pub fn init(module: &PyModule) -> PyResult<()> { module.add_function(wrap_pyfunction!(replace, module)?)?; module.add_function(wrap_pyfunction!(reverse, module)?)?; module.add_function(wrap_pyfunction!(right, module)?)?; + module.add_function(wrap_pyfunction!(round, module)?)?; module.add_function(wrap_pyfunction!(rpad, module)?)?; module.add_function(wrap_pyfunction!(rtrim, module)?)?; module.add_function(wrap_pyfunction!(sha224, module)?)?; module.add_function(wrap_pyfunction!(sha256, module)?)?; module.add_function(wrap_pyfunction!(sha384, module)?)?; module.add_function(wrap_pyfunction!(sha512, module)?)?; + module.add_function(wrap_pyfunction!(signum, module)?)?; + module.add_function(wrap_pyfunction!(sin, module)?)?; module.add_function(wrap_pyfunction!(split_part, module)?)?; + module.add_function(wrap_pyfunction!(sqrt, module)?)?; module.add_function(wrap_pyfunction!(starts_with, module)?)?; module.add_function(wrap_pyfunction!(strpos, module)?)?; module.add_function(wrap_pyfunction!(substr, module)?)?; + module.add_function(wrap_pyfunction!(sum, module)?)?; + module.add_function(wrap_pyfunction!(tan, module)?)?; module.add_function(wrap_pyfunction!(to_hex, module)?)?; module.add_function(wrap_pyfunction!(translate, module)?)?; module.add_function(wrap_pyfunction!(trim, module)?)?; - module.add_function(wrap_pyfunction!(upper, module)?)?; - module.add_function(wrap_pyfunction!(sum, module)?)?; - module.add_function(wrap_pyfunction!(count, module)?)?; - module.add_function(wrap_pyfunction!(min, module)?)?; - module.add_function(wrap_pyfunction!(max, module)?)?; - module.add_function(wrap_pyfunction!(avg, module)?)?; - module.add_function(wrap_pyfunction!(udaf, module)?)?; - module.add_function(wrap_pyfunction!(sqrt, module)?)?; - module.add_function(wrap_pyfunction!(sin, module)?)?; - module.add_function(wrap_pyfunction!(cos, module)?)?; - module.add_function(wrap_pyfunction!(tan, module)?)?; - module.add_function(wrap_pyfunction!(asin, module)?)?; - module.add_function(wrap_pyfunction!(acos, module)?)?; - module.add_function(wrap_pyfunction!(atan, module)?)?; - module.add_function(wrap_pyfunction!(floor, module)?)?; - module.add_function(wrap_pyfunction!(ceil, module)?)?; - module.add_function(wrap_pyfunction!(round, module)?)?; module.add_function(wrap_pyfunction!(trunc, module)?)?; - module.add_function(wrap_pyfunction!(abs, module)?)?; - module.add_function(wrap_pyfunction!(signum, module)?)?; - module.add_function(wrap_pyfunction!(exp, module)?)?; - module.add_function(wrap_pyfunction!(ln, module)?)?; - module.add_function(wrap_pyfunction!(log2, module)?)?; - module.add_function(wrap_pyfunction!(log10, module)?)?; + module.add_function(wrap_pyfunction!(udaf, module)?)?; + module.add_function(wrap_pyfunction!(udf, module)?)?; + module.add_function(wrap_pyfunction!(upper, module)?)?; Ok(()) } diff --git a/python/tests/test_math_functions.py b/python/tests/test_math_functions.py index 56d4824aeb9d..cb03753121fa 100644 --- a/python/tests/test_math_functions.py +++ b/python/tests/test_math_functions.py @@ -44,6 +44,7 @@ def test_math_functions(df): f.ln(col_v + f.lit(1)), f.log2(col_v + f.lit(1)), f.log10(col_v + f.lit(1)), + f.random(), ) result = df.collect() assert len(result) == 1 @@ -58,3 +59,4 @@ def test_math_functions(df): np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0)) np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0)) np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0)) + np.testing.assert_array_less(result.column(10), np.ones_like(values))