diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index da289e1a139d..f7e4916a94af 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -522,6 +522,9 @@ impl Expr { /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + // TODO(kszucs): most of the operations do not validate the type correctness + // like all of the binary expressions below. Perhaps Expr should track the + // type of the expression? let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self) @@ -1417,6 +1420,8 @@ pub fn random() -> Expr { } } +// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many +// varying arity functions /// Create an convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { diff --git a/python/src/context.rs b/python/src/context.rs index 8d6c93b91255..e41122f90998 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -27,7 +27,7 @@ use pyo3::prelude::*; use datafusion::arrow::datatypes::{DataType, Schema}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::ExecutionContext; use datafusion::prelude::CsvReadOptions; use crate::catalog::PyCatalog; @@ -45,6 +45,7 @@ pub(crate) struct PyExecutionContext { #[pymethods] impl PyExecutionContext { + // TODO(kszucs): should expose the configuration options as keyword arguments #[new] fn new() -> Self { PyExecutionContext { diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 69e6edf52fe2..20a52a8b9a18 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use pyo3::prelude::*; use tokio::runtime::Runtime; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::PyArrowConvert; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::dataframe::DataFrame; use datafusion::logical_plan::JoinType; @@ -45,6 +45,11 @@ impl PyDataFrame { #[pymethods] impl PyDataFrame { + /// Returns the schema from the logical plan + fn schema(&self) -> Schema { + self.df.schema().into() + } + #[args(args = "*")] fn select(&self, args: Vec) -> PyResult { let expr = args.into_iter().map(|e| e.into()).collect(); @@ -64,6 +69,7 @@ impl PyDataFrame { Ok(Self::new(df)) } + #[args(exprs = "*")] fn sort(&self, exprs: Vec) -> PyResult { let exprs = exprs.into_iter().map(|e| e.into()).collect(); let df = self.df.sort(exprs)?; diff --git a/python/src/expression.rs b/python/src/expression.rs index c55b248fd51d..63e9d7b1d665 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -19,6 +19,7 @@ use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; use std::convert::{From, Into}; use std::vec::Vec; +use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan::{col, lit, Expr}; use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; use datafusion::scalar::ScalarValue; @@ -60,6 +61,10 @@ impl PyNumberProtocol for PyExpr { Ok((lhs.expr * rhs.expr).into()) } + fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok(lhs.expr.clone().modulus(rhs.expr).into()) + } + fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult { Ok(lhs.expr.clone().and(rhs.expr).into()) } @@ -104,6 +109,16 @@ impl PyExpr { pub fn is_null(&self) -> PyExpr { self.expr.clone().is_null().into() } + + pub fn cast(&self, to: DataType) -> PyExpr { + // self.expr.cast_to() requires DFSchema to validate that the cast + // is supported, omit that for now + let expr = Expr::Cast { + expr: Box::new(self.expr.clone()), + data_type: to, + }; + expr.into() + } } /// Represents a PyScalarUDF diff --git a/python/src/functions.rs b/python/src/functions.rs index 087fa041591b..cf150816cb5d 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -23,8 +23,7 @@ use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan; use datafusion::logical_plan::Expr; use datafusion::physical_plan::{ - aggregates::AggregateFunction, functions::BuiltinScalarFunction, udaf::AggregateUDF, - udf::ScalarUDF, + aggregates::AggregateFunction, functions::BuiltinScalarFunction, }; use crate::{ @@ -170,6 +169,7 @@ scalar_function!( "Computes the MD5 hash of the argument, with the result written in hexadecimal." ); scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); +scalar_function!(regexp_match, RegexpMatch); scalar_function!( regexp_replace, RegexpReplace, @@ -215,6 +215,7 @@ scalar_function!( ToHex, "Converts the number to its equivalent hexadecimal representation." ); +scalar_function!(to_timestamp, ToTimestamp); scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); scalar_function!(trunc, Trunc); @@ -315,6 +316,7 @@ pub fn init(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(octet_length))?; m.add_wrapped(wrap_pyfunction!(random))?; + m.add_wrapped(wrap_pyfunction!(regexp_match))?; m.add_wrapped(wrap_pyfunction!(regexp_replace))?; m.add_wrapped(wrap_pyfunction!(repeat))?; m.add_wrapped(wrap_pyfunction!(replace))?; @@ -337,6 +339,7 @@ pub fn init(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(sum))?; m.add_wrapped(wrap_pyfunction!(tan))?; m.add_wrapped(wrap_pyfunction!(to_hex))?; + m.add_wrapped(wrap_pyfunction!(to_timestamp))?; m.add_wrapped(wrap_pyfunction!(translate))?; m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?;