Skip to content

Commit

Permalink
closing up type checks (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored Jun 7, 2021
1 parent ee2b9ef commit 767eeb0
Show file tree
Hide file tree
Showing 12 changed files with 512 additions and 62 deletions.
2 changes: 1 addition & 1 deletion ballista/rust/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ futures = "0.3"
log = "0.4"
prost = "0.7"
serde = {version = "1", features = ["derive"]}
sqlparser = "0.8"
sqlparser = "0.9.0"
tokio = "1.0"
tonic = "0.4"
uuid = { version = "0.8", features = ["v4"] }
Expand Down
6 changes: 3 additions & 3 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ message WindowExprNode {
// repeated LogicalExprNode partition_by = 5;
repeated LogicalExprNode order_by = 6;
// repeated LogicalExprNode filter = 7;
// oneof window_frame {
// WindowFrame frame = 8;
// }
oneof window_frame {
WindowFrame frame = 8;
}
}

message BetweenNode {
Expand Down
49 changes: 27 additions & 22 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@
use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::{
convert::{From, TryInto},
unimplemented,
};

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
use datafusion::physical_plan::window_frames::{
WindowFrame, WindowFrameBound, WindowFrameUnits,
};
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
use datafusion::scalar::ScalarValue;
use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
use std::{
convert::{From, TryInto},
unimplemented,
};

// use uuid::Uuid;

Expand Down Expand Up @@ -83,20 +84,6 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;

// let partition_by_expr = window
// .partition_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// let order_by_expr = window
// .order_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// // FIXME: add filter by expr
// // FIXME: parse the window_frame data
// let window_frame = None;
LogicalPlanBuilder::from(&input)
.window(window_expr)?
.build()
Expand Down Expand Up @@ -929,6 +916,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
.map(|e| e.try_into())
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
let window_frame = expr
.window_frame
.as_ref()
.map::<Result<WindowFrame, _>, _>(|e| match e {
window_expr_node::WindowFrame::Frame(frame) => {
frame.clone().try_into()
}
})
.transpose()?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
Expand All @@ -945,6 +941,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
),
args: vec![parse_required_expr(&expr.expr)?],
order_by,
window_frame,
})
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
Expand All @@ -964,6 +961,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
),
args: vec![parse_required_expr(&expr.expr)?],
order_by,
window_frame,
})
}
}
Expand Down Expand Up @@ -1333,8 +1331,15 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
)
})?
.try_into()?;
// FIXME parse end bound
let end_bound = None;
let end_bound = window
.end_bound
.map(|end_bound| match end_bound {
protobuf::window_frame::EndBound::Bound(end_bound) => {
end_bound.try_into()
}
})
.transpose()?
.unwrap_or(WindowFrameBound::CurrentRow);
Ok(WindowFrame {
units,
start_bound,
Expand Down
56 changes: 38 additions & 18 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@ use std::{
convert::{TryFrom, TryInto},
};

use super::super::proto_error;
use crate::datasource::DfTableAdapter;
use crate::serde::{protobuf, BallistaError};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
use datafusion::physical_plan::window_frames::{
WindowFrame, WindowFrameBound, WindowFrameUnits,
};
use datafusion::physical_plan::window_functions::{
BuiltInWindowFunction, WindowFunction,
};
Expand All @@ -38,10 +43,6 @@ use protobuf::{
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType,
ScalarListValue, ScalarType,
};
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};

use super::super::proto_error;
use datafusion::physical_plan::functions::BuiltinScalarFunction;

impl protobuf::IntervalUnit {
pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
Expand Down Expand Up @@ -1007,6 +1008,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
ref fun,
ref args,
ref order_by,
ref window_frame,
..
} => {
let window_function = match fun {
Expand All @@ -1026,10 +1028,16 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<_>, _>>()?;
let window_frame = window_frame.map(|window_frame| {
protobuf::window_expr_node::WindowFrame::Frame(
window_frame.clone().into(),
)
});
let window_expr = Box::new(protobuf::WindowExprNode {
expr: Some(Box::new(arg.try_into()?)),
window_function: Some(window_function),
order_by,
window_frame,
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::WindowExpr(window_expr)),
Expand Down Expand Up @@ -1256,23 +1264,35 @@ impl From<WindowFrameUnits> for protobuf::WindowFrameUnits {
}
}

impl TryFrom<WindowFrameBound> for protobuf::WindowFrameBound {
type Error = BallistaError;

fn try_from(_bound: WindowFrameBound) -> Result<Self, Self::Error> {
Err(BallistaError::NotImplemented(
"WindowFrameBound => protobuf::WindowFrameBound".to_owned(),
))
/// Serializes a [`WindowFrameBound`] into its protobuf message.
///
/// The variant is recorded as the `window_frame_bound_type` tag. The optional
/// offset carried by `Preceding` / `Following` is forwarded as `bound_value`;
/// `CurrentRow` never carries one.
impl From<WindowFrameBound> for protobuf::WindowFrameBound {
    fn from(bound: WindowFrameBound) -> Self {
        // Split the bound into its protobuf tag and optional offset first,
        // then assemble the message in a single place.
        let (bound_type, offset) = match bound {
            WindowFrameBound::CurrentRow => {
                (protobuf::WindowFrameBoundType::CurrentRow, None)
            }
            WindowFrameBound::Preceding(v) => {
                (protobuf::WindowFrameBoundType::Preceding, v)
            }
            WindowFrameBound::Following(v) => {
                (protobuf::WindowFrameBoundType::Following, v)
            }
        };
        protobuf::WindowFrameBound {
            window_frame_bound_type: bound_type.into(),
            // NOTE(review): a missing offset presumably denotes an unbounded
            // frame edge; confirm against the `WindowFrameBound` definition.
            bound_value: offset.map(protobuf::window_frame_bound::BoundValue::Value),
        }
    }
}

impl TryFrom<WindowFrame> for protobuf::WindowFrame {
type Error = BallistaError;

fn try_from(_window: WindowFrame) -> Result<Self, Self::Error> {
Err(BallistaError::NotImplemented(
"WindowFrame => protobuf::WindowFrame".to_owned(),
))
/// Serializes a [`WindowFrame`] into its protobuf message.
///
/// The units and both bounds are converted with their own `From`
/// implementations; start and end bounds are always emitted as `Some`.
impl From<WindowFrame> for protobuf::WindowFrame {
    fn from(window: WindowFrame) -> Self {
        let units = protobuf::WindowFrameUnits::from(window.units);
        let start = Some(window.start_bound.into());
        // The end bound lives in a protobuf `oneof`, hence the extra wrapper.
        let end = protobuf::window_frame::EndBound::Bound(window.end_bound.into());
        protobuf::WindowFrame {
            window_frame_units: units.into(),
            start_bound: start,
            end_bound: Some(end),
        }
    }
}

Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
fun,
args,
order_by,
..
} => {
let arg = df_planner
.create_physical_expr(
Expand Down
50 changes: 38 additions & 12 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,19 @@
//! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct.

pub use super::Operator;

use std::fmt;
use std::sync::Arc;

use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
use arrow::{compute::can_cast_types, datatypes::DataType};

use crate::error::{DataFusionError, Result};
use crate::logical_plan::{DFField, DFSchema};
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
window_functions,
window_frames, window_functions,
};
use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
use arrow::{compute::can_cast_types, datatypes::DataType};
use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;

/// `Expr` is a central struct of DataFusion's query API, and
/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
Expand Down Expand Up @@ -199,6 +196,8 @@ pub enum Expr {
args: Vec<Expr>,
/// List of order by expressions
order_by: Vec<Expr>,
/// Window frame
window_frame: Option<window_frames::WindowFrame>,
},
/// aggregate function
AggregateUDF {
Expand Down Expand Up @@ -735,10 +734,12 @@ impl Expr {
args,
fun,
order_by,
window_frame,
} => Expr::WindowFunction {
args: rewrite_vec(args, rewriter)?,
fun,
order_by: rewrite_vec(order_by, rewriter)?,
window_frame,
},
Expr::AggregateFunction {
args,
Expand Down Expand Up @@ -1283,8 +1284,23 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args)
}
Expr::WindowFunction { fun, ref args, .. } => {
fmt_function(f, &fun.to_string(), false, args)
Expr::WindowFunction {
fun,
ref args,
window_frame,
..
} => {
fmt_function(f, &fun.to_string(), false, args)?;
if let Some(window_frame) = window_frame {
write!(
f,
" {} BETWEEN {} AND {}",
window_frame.units,
window_frame.start_bound,
window_frame.end_bound
)?;
}
Ok(())
}
Expr::AggregateFunction {
fun,
Expand Down Expand Up @@ -1401,8 +1417,18 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_name(&fun.name, false, args, input_schema)
}
Expr::WindowFunction { fun, args, .. } => {
create_function_name(&fun.to_string(), false, args, input_schema)
Expr::WindowFunction {
fun,
args,
window_frame,
..
} => {
let fun_name =
create_function_name(&fun.to_string(), false, args, input_schema)?;
Ok(match window_frame {
Some(window_frame) => format!("{} {}", fun_name, window_frame),
None => fun_name,
})
}
Expr::AggregateFunction {
fun,
Expand Down
5 changes: 4 additions & 1 deletion datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
Expr::WindowFunction { fun, .. } => {
Expr::WindowFunction {
fun, window_frame, ..
} => {
let index = expressions
.iter()
.position(|expr| {
Expand All @@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions[..index].to_vec(),
order_by: expressions[index + 1..].to_vec(),
window_frame: *window_frame,
})
}
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
Expand Down
1 change: 1 addition & 0 deletions datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,5 +617,6 @@ pub mod udf;
#[cfg(feature = "unicode_expressions")]
pub mod unicode_expressions;
pub mod union;
pub mod window_frames;
pub mod window_functions;
pub mod windows;
3 changes: 1 addition & 2 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

//! Physical query planner

use std::sync::Arc;

use super::{
aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary,
functions, hash_join::PartitionMode, udaf, union::UnionExec, windows,
Expand Down Expand Up @@ -56,6 +54,7 @@ use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use expressions::col;
use log::debug;
use std::sync::Arc;

/// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`].
pub trait ExtensionPlanner {
Expand Down
Loading

0 comments on commit 767eeb0

Please sign in to comment.