Skip to content

Commit

Permalink
add window expr
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed May 21, 2021
1 parent 913bf86 commit 1d3b076
Show file tree
Hide file tree
Showing 21 changed files with 1,236 additions and 103 deletions.
75 changes: 74 additions & 1 deletion ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ message LogicalExprNode {

ScalarValue literal = 3;


// binary expressions
BinaryExprNode binary_expr = 4;

Expand All @@ -60,6 +59,9 @@ message LogicalExprNode {
bool wildcard = 15;
ScalarFunctionNode scalar_function = 16;
TryCastNode try_cast = 17;

// window expressions
WindowExprNode window_expr = 18;
}
}

Expand Down Expand Up @@ -151,6 +153,25 @@ message AggregateExprNode {
LogicalExprNode expr = 2;
}

enum BuiltInWindowFunction {
ROW_NUMBER = 0;
RANK = 1;
DENSE_RANK = 2;
LAG = 3;
LEAD = 4;
FIRST_VALUE = 5;
LAST_VALUE = 6;
}

message WindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
BuiltInWindowFunction built_in_function = 2;
// udaf = 3
}
LogicalExprNode expr = 4;
}

message BetweenNode {
LogicalExprNode expr = 1;
bool negated = 2;
Expand Down Expand Up @@ -200,6 +221,7 @@ message LogicalPlanNode {
EmptyRelationNode empty_relation = 10;
CreateExternalTableNode create_external_table = 11;
ExplainNode explain = 12;
WindowNode window = 13;
}
}

Expand Down Expand Up @@ -288,6 +310,49 @@ message AggregateNode {
repeated LogicalExprNode aggr_expr = 3;
}

message WindowNode {
LogicalPlanNode input = 1;
repeated LogicalExprNode window_expr = 2;
repeated LogicalExprNode partition_by_expr = 3;
repeated LogicalExprNode order_by_expr = 4;
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
oneof window_frame {
WindowFrame frame = 5;
}
}

enum WindowFrameUnits {
ROWS = 0;
RANGE = 1;
GROUPS = 2;
}

message WindowFrame {
WindowFrameUnits window_frame_units = 1;
WindowFrameBound start_bound = 2;
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
oneof end_bound {
WindowFrameBound bound = 3;
}
}

enum WindowFrameBoundType {
CURRENT_ROW = 0;
PRECEDING = 1;
FOLLOWING = 2;
}

message WindowFrameBound {
WindowFrameBoundType window_frame_bound_type = 1;
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
oneof bound_value {
uint64 value = 2;
}
}

enum JoinType {
INNER = 0;
LEFT = 1;
Expand Down Expand Up @@ -334,6 +399,7 @@ message PhysicalPlanNode {
MergeExecNode merge = 14;
UnresolvedShuffleExecNode unresolved = 15;
RepartitionExecNode repartition = 16;
WindowAggExecNode window = 17;
}
}

Expand Down Expand Up @@ -399,6 +465,13 @@ enum AggregateMode {
FINAL_PARTITIONED = 2;
}

message WindowAggExecNode {
PhysicalPlanNode input = 1;
repeated LogicalExprNode window_expr = 2;
repeated string window_expr_name = 3;
Schema input_schema = 4;
}

message HashAggregateExecNode {
repeated LogicalExprNode group_expr = 1;
repeated LogicalExprNode aggr_expr = 2;
Expand Down
190 changes: 179 additions & 11 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@

//! Serde code to convert from protocol buffers to Rust data structures.

use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::{
convert::{From, TryInto},
unimplemented,
};

use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
use datafusion::scalar::ScalarValue;
use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
Expand Down Expand Up @@ -75,6 +76,33 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.build()
.map_err(|e| e.into())
}
LogicalPlanType::Window(window) => {
let input: LogicalPlan = convert_box_required!(window.input)?;
let window_expr = window
.window_expr
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;

// let partition_by_expr = window
// .partition_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// let order_by_expr = window
// .order_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// // FIXME parse the window_frame data
// let window_frame = None;
LogicalPlanBuilder::from(&input)
.window(
window_expr, /*, partition_by_expr, order_by_expr, window_frame*/
)?
.build()
.map_err(|e| e.into())
}
LogicalPlanType::Aggregate(aggregate) => {
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
let group_expr = aggregate
Expand Down Expand Up @@ -871,7 +899,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
type Error = BallistaError;

fn try_into(self) -> Result<Expr, Self::Error> {
use datafusion::physical_plan::window_functions;
use protobuf::logical_expr_node::ExprType;
use protobuf::window_expr_node;
use protobuf::WindowExprNode;

let expr_type = self
.expr_type
Expand All @@ -889,6 +920,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
Ok(Expr::Literal(scalar_value))
}
ExprType::WindowExpr(expr) => {
let window_function = expr
.window_function
.as_ref()
.ok_or_else(|| proto_error("Received empty window function"))?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
.ok_or_else(|| {
proto_error(format!(
"Received an unknown aggregate window function: {}",
i
))
})?;

Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::AggregateFunction(
AggregateFunction::from(aggr_function),
),
args: vec![parse_required_expr(&expr.expr)?],
})
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
let built_in_function =
protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
|| {
proto_error(format!(
"Received an unknown built-in window function: {}",
i
))
},
)?;

Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
BuiltInWindowFunction::from(built_in_function),
),
args: vec![parse_required_expr(&expr.expr)?],
})
}
}
}
ExprType::AggregateExpr(expr) => {
let aggr_function =
protobuf::AggregateFunction::from_i32(expr.aggr_function)
Expand All @@ -898,13 +971,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
expr.aggr_function
))
})?;
let fun = match aggr_function {
protobuf::AggregateFunction::Min => AggregateFunction::Min,
protobuf::AggregateFunction::Max => AggregateFunction::Max,
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
};
let fun = AggregateFunction::from(aggr_function);

Ok(Expr::AggregateFunction {
fun,
Expand Down Expand Up @@ -1152,6 +1219,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
}

use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
use datafusion::physical_plan::{aggregates, windows};
use datafusion::prelude::{
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
};
Expand Down Expand Up @@ -1202,3 +1270,103 @@ fn parse_optional_expr(
None => Ok(None),
}
}

impl From<protobuf::WindowFrameUnits> for WindowFrameUnits {
fn from(units: protobuf::WindowFrameUnits) -> Self {
match units {
protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range,
protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
}
}
}

impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
type Error = BallistaError;

fn try_from(bound: protobuf::WindowFrameBound) -> Result<Self, Self::Error> {
let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| {
proto_error(format!(
"Received a WindowFrameBound message with unknown WindowFrameBoundType {}",
bound.window_frame_bound_type
))
})?.into();
match bound_type {
protobuf::WindowFrameBoundType::CurrentRow => {
Ok(WindowFrameBound::CurrentRow)
}
protobuf::WindowFrameBoundType::Preceding => {
// FIXME implement bound value parsing
Ok(WindowFrameBound::Preceding(Some(1)))
}
protobuf::WindowFrameBoundType::Following => {
// FIXME implement bound value parsing
Ok(WindowFrameBound::Following(Some(1)))
}
}
}
}

impl TryFrom<protobuf::WindowFrame> for WindowFrame {
type Error = BallistaError;

fn try_from(window: protobuf::WindowFrame) -> Result<Self, Self::Error> {
let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units)
.ok_or_else(|| {
proto_error(format!(
"Received a WindowFrame message with unknown WindowFrameUnits {}",
window.window_frame_units
))
})?
.into();
let start_bound = window
.start_bound
.ok_or_else(|| {
proto_error(
"Received a WindowFrame message with no start_bound".to_owned(),
)
})?
.try_into()?;
// FIXME parse end bound
let end_bound = None;
Ok(WindowFrame {
units,
start_bound,
end_bound,
})
}
}

impl From<protobuf::AggregateFunction> for AggregateFunction {
fn from(aggr_function: protobuf::AggregateFunction) -> Self {
match aggr_function {
protobuf::AggregateFunction::Min => AggregateFunction::Min,
protobuf::AggregateFunction::Max => AggregateFunction::Max,
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
}
}
}

impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
match built_in_function {
protobuf::BuiltInWindowFunction::RowNumber => {
BuiltInWindowFunction::RowNumber
}
protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
protobuf::BuiltInWindowFunction::DenseRank => {
BuiltInWindowFunction::DenseRank
}
protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
protobuf::BuiltInWindowFunction::FirstValue => {
BuiltInWindowFunction::FirstValue
}
protobuf::BuiltInWindowFunction::LastValue => {
BuiltInWindowFunction::LastValue
}
}
}
}
Loading

0 comments on commit 1d3b076

Please sign in to comment.