Skip to content

Commit

Permalink
A couple of additional operators
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 30, 2021
1 parent 5f81203 commit df9b924
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 4 deletions.
5 changes: 5 additions & 0 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,9 @@ impl Expr {
/// This function errors when it is impossible to cast the
/// expression to the target [arrow::datatypes::DataType].
pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
// TODO(kszucs): most of the operations do not validate the type correctness
// like all of the binary expressions below. Perhaps Expr should track the
// type of the expression?
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
Ok(self)
Expand Down Expand Up @@ -1417,6 +1420,8 @@ pub fn random() -> Expr {
}
}

// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create an convenience function representing a unary scalar function
macro_rules! unary_scalar_expr {
($ENUM:ident, $FUNC:ident) => {
Expand Down
3 changes: 2 additions & 1 deletion python/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use pyo3::prelude::*;
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
use datafusion::execution::context::ExecutionContext;
use datafusion::prelude::CsvReadOptions;

use crate::catalog::PyCatalog;
Expand All @@ -45,6 +45,7 @@ pub(crate) struct PyExecutionContext {

#[pymethods]
impl PyExecutionContext {
// TODO(kszucs): should expose the configuration options as keyword arguments
#[new]
fn new() -> Self {
PyExecutionContext {
Expand Down
8 changes: 7 additions & 1 deletion python/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use std::sync::Arc;
use pyo3::prelude::*;
use tokio::runtime::Runtime;

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowConvert;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::logical_plan::JoinType;

Expand All @@ -45,6 +45,11 @@ impl PyDataFrame {

#[pymethods]
impl PyDataFrame {
/// Returns the schema from the logical plan
fn schema(&self) -> Schema {
self.df.schema().into()
}

#[args(args = "*")]
fn select(&self, args: Vec<PyExpr>) -> PyResult<Self> {
let expr = args.into_iter().map(|e| e.into()).collect();
Expand All @@ -64,6 +69,7 @@ impl PyDataFrame {
Ok(Self::new(df))
}

#[args(exprs = "*")]
fn sort(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
let exprs = exprs.into_iter().map(|e| e.into()).collect();
let df = self.df.sort(exprs)?;
Expand Down
15 changes: 15 additions & 0 deletions python/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol};
use std::convert::{From, Into};
use std::vec::Vec;

use datafusion::arrow::datatypes::DataType;
use datafusion::logical_plan::{col, lit, Expr};
use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF};
use datafusion::scalar::ScalarValue;
Expand Down Expand Up @@ -60,6 +61,10 @@ impl PyNumberProtocol for PyExpr {
Ok((lhs.expr * rhs.expr).into())
}

fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
Ok(lhs.expr.clone().modulus(rhs.expr).into())
}

fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
Ok(lhs.expr.clone().and(rhs.expr).into())
}
Expand Down Expand Up @@ -104,6 +109,16 @@ impl PyExpr {
pub fn is_null(&self) -> PyExpr {
self.expr.clone().is_null().into()
}

pub fn cast(&self, to: DataType) -> PyExpr {
// self.expr.cast_to() requires DFSchema to validate that the cast
// is supported, omit that for now
let expr = Expr::Cast {
expr: Box::new(self.expr.clone()),
data_type: to,
};
expr.into()
}
}

/// Represents a PyScalarUDF
Expand Down
7 changes: 5 additions & 2 deletions python/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::logical_plan;
use datafusion::logical_plan::Expr;
use datafusion::physical_plan::{
aggregates::AggregateFunction, functions::BuiltinScalarFunction, udaf::AggregateUDF,
udf::ScalarUDF,
aggregates::AggregateFunction, functions::BuiltinScalarFunction,
};

use crate::{
Expand Down Expand Up @@ -170,6 +169,7 @@ scalar_function!(
"Computes the MD5 hash of the argument, with the result written in hexadecimal."
);
scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
scalar_function!(regexp_match, RegexpMatch);
scalar_function!(
regexp_replace,
RegexpReplace,
Expand Down Expand Up @@ -215,6 +215,7 @@ scalar_function!(
ToHex,
"Converts the number to its equivalent hexadecimal representation."
);
scalar_function!(to_timestamp, ToTimestamp);
scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
scalar_function!(trunc, Trunc);
Expand Down Expand Up @@ -315,6 +316,7 @@ pub fn init(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(now))?;
m.add_wrapped(wrap_pyfunction!(octet_length))?;
m.add_wrapped(wrap_pyfunction!(random))?;
m.add_wrapped(wrap_pyfunction!(regexp_match))?;
m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
m.add_wrapped(wrap_pyfunction!(repeat))?;
m.add_wrapped(wrap_pyfunction!(replace))?;
Expand All @@ -337,6 +339,7 @@ pub fn init(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(sum))?;
m.add_wrapped(wrap_pyfunction!(tan))?;
m.add_wrapped(wrap_pyfunction!(to_hex))?;
m.add_wrapped(wrap_pyfunction!(to_timestamp))?;
m.add_wrapped(wrap_pyfunction!(translate))?;
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
Expand Down

0 comments on commit df9b924

Please sign in to comment.