From 767eeb0a8bf17916aafb9a88abd52e7350acb596 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Jun 2021 18:14:25 +0800 Subject: [PATCH] closing up type checks (#506) --- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/core/proto/ballista.proto | 6 +- .../core/src/serde/logical_plan/from_proto.rs | 49 +-- .../core/src/serde/logical_plan/to_proto.rs | 56 ++- .../src/serde/physical_plan/from_proto.rs | 1 + datafusion/src/logical_plan/expr.rs | 50 ++- datafusion/src/optimizer/utils.rs | 5 +- datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/planner.rs | 3 +- datafusion/src/physical_plan/window_frames.rs | 337 ++++++++++++++++++ datafusion/src/sql/planner.rs | 52 ++- datafusion/src/sql/utils.rs | 12 + 12 files changed, 512 insertions(+), 62 deletions(-) create mode 100644 datafusion/src/physical_plan/window_frames.rs diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 99822cfe2aee..1f23a2a42e2a 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -35,7 +35,7 @@ futures = "0.3" log = "0.4" prost = "0.7" serde = {version = "1", features = ["derive"]} -sqlparser = "0.8" +sqlparser = "0.9.0" tokio = "1.0" tonic = "0.4" uuid = { version = "0.8", features = ["v4"] } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 0ed9f243fd0a..38d87e934e5f 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,9 +177,9 @@ message WindowExprNode { // repeated LogicalExprNode partition_by = 5; repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; - // oneof window_frame { - // WindowFrame frame = 8; - // } + oneof window_frame { + WindowFrame frame = 8; + } } message BetweenNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 662d9d0a929a..4a198174a2ba 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -20,12 +20,6 @@ use crate::error::BallistaError; use crate::serde::{proto_error, protobuf}; use crate::{convert_box_required, convert_required}; -use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -use std::{ - convert::{From, TryInto}, - unimplemented, -}; - use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, @@ -33,10 +27,17 @@ use datafusion::logical_plan::{ }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; +use datafusion::physical_plan::window_frames::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, +}; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; use datafusion::scalar::ScalarValue; use protobuf::logical_plan_node::LogicalPlanType; use protobuf::{logical_expr_node::ExprType, scalar_type}; +use std::{ + convert::{From, TryInto}, + unimplemented, +}; // use uuid::Uuid; @@ -83,20 +84,6 @@ impl TryInto for &protobuf::LogicalPlanNode { .iter() .map(|expr| expr.try_into()) .collect::, _>>()?; - - // let partition_by_expr = window - // .partition_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // let order_by_expr = window - // .order_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // // FIXME: add filter by expr - // // FIXME: parse the window_frame data - // let window_frame = None; LogicalPlanBuilder::from(&input) .window(window_expr)? .build() @@ -929,6 +916,15 @@ impl TryInto for &protobuf::LogicalExprNode { .map(|e| e.try_into()) .into_iter() .collect::, _>>()?; + let window_frame = expr + .window_frame + .as_ref() + .map::, _>(|e| match e { + window_expr_node::WindowFrame::Frame(frame) => { + frame.clone().try_into() + } + }) + .transpose()?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) @@ -945,6 +941,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -964,6 +961,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } } @@ -1333,8 +1331,15 @@ impl TryFrom for WindowFrame { ) })? .try_into()?; - // FIXME parse end bound - let end_bound = None; + let end_bound = window + .end_bound + .map(|end_bound| match end_bound { + protobuf::window_frame::EndBound::Bound(end_bound) => { + end_bound.try_into() + } + }) + .transpose()? + .unwrap_or(WindowFrameBound::CurrentRow); Ok(WindowFrame { units, start_bound, diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index d7734f05da56..56270030b59f 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -24,12 +24,17 @@ use std::{ convert::{TryFrom, TryInto}, }; +use super::super::proto_error; use crate::datasource::DfTableAdapter; use crate::serde::{protobuf, BallistaError}; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::datasource::CsvFile; use datafusion::logical_plan::{Expr, JoinType, LogicalPlan}; use datafusion::physical_plan::aggregates::AggregateFunction; +use datafusion::physical_plan::functions::BuiltinScalarFunction; +use datafusion::physical_plan::window_frames::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, +}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -38,10 +43,6 @@ use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType, ScalarListValue, ScalarType, }; -use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; - -use super::super::proto_error; -use datafusion::physical_plan::functions::BuiltinScalarFunction; impl protobuf::IntervalUnit { pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self { @@ -1007,6 +1008,7 @@ impl TryInto for &Expr { ref fun, ref args, ref order_by, + ref window_frame, .. } => { let window_function = match fun { @@ -1026,10 +1028,16 @@ impl TryInto for &Expr { .iter() .map(|e| e.try_into()) .collect::, _>>()?; + let window_frame = window_frame.map(|window_frame| { + protobuf::window_expr_node::WindowFrame::Frame( + window_frame.clone().into(), + ) + }); let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), order_by, + window_frame, }); Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), @@ -1256,23 +1264,35 @@ impl From for protobuf::WindowFrameUnits { } } -impl TryFrom for protobuf::WindowFrameBound { - type Error = BallistaError; - - fn try_from(_bound: WindowFrameBound) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrameBound => protobuf::WindowFrameBound".to_owned(), - )) +impl From for protobuf::WindowFrameBound { + fn from(bound: WindowFrameBound) -> Self { + match bound { + WindowFrameBound::CurrentRow => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow + .into(), + bound_value: None, + }, + WindowFrameBound::Preceding(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + WindowFrameBound::Following(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + } } } -impl TryFrom for protobuf::WindowFrame { - type Error = BallistaError; - - fn try_from(_window: WindowFrame) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrame => protobuf::WindowFrame".to_owned(), - )) +impl From for protobuf::WindowFrame { + fn from(window: WindowFrame) -> Self { + protobuf::WindowFrame { + window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), + start_bound: Some(window.start_bound.into()), + end_bound: Some(protobuf::window_frame::EndBound::Bound( + window.end_bound.into(), + )), + } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 22944313666f..5fcc971527c6 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -237,6 +237,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { fun, args, order_by, + .. } => { let arg = df_planner .create_physical_expr( diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5103d5dc5051..bbc6ffabe928 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -19,22 +19,19 @@ //! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct. pub use super::Operator; - -use std::fmt; -use std::sync::Arc; - -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::{compute::can_cast_types, datatypes::DataType}; - use crate::error::{DataFusionError, Result}; use crate::logical_plan::{DFField, DFSchema}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_functions, + window_frames, window_functions, }; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; +use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; +use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::HashSet; +use std::fmt; +use std::sync::Arc; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -199,6 +196,8 @@ pub enum Expr { args: Vec, /// List of order by expressions order_by: Vec, + /// Window frame + window_frame: Option, }, /// aggregate function AggregateUDF { @@ -735,10 +734,12 @@ impl Expr { args, fun, order_by, + window_frame, } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, order_by: rewrite_vec(order_by, rewriter)?, + window_frame, }, Expr::AggregateFunction { args, @@ -1283,8 +1284,23 @@ impl fmt::Debug for Expr { Expr::ScalarUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) } - Expr::WindowFunction { fun, ref args, .. } => { - fmt_function(f, &fun.to_string(), false, args) + Expr::WindowFunction { + fun, + ref args, + window_frame, + .. + } => { + fmt_function(f, &fun.to_string(), false, args)?; + if let Some(window_frame) = window_frame { + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + )?; + } + Ok(()) } Expr::AggregateFunction { fun, @@ -1401,8 +1417,18 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_name(&fun.name, false, args, input_schema) } - Expr::WindowFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), false, args, input_schema) + Expr::WindowFunction { + fun, + args, + window_frame, + .. + } => { + let fun_name = + create_function_name(&fun.to_string(), false, args, input_schema)?; + Ok(match window_frame { + Some(window_frame) => format!("{} {}", fun_name, window_frame), + None => fun_name, + }) } Expr::AggregateFunction { fun, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 2cb65066feb9..65c95bee20d4 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), - Expr::WindowFunction { fun, .. } => { + Expr::WindowFunction { + fun, window_frame, .. + } => { let index = expressions .iter() .position(|expr| { @@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions[..index].to_vec(), order_by: expressions[index + 1..].to_vec(), + window_frame: *window_frame, }) } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index af6969c43cbd..490e02875c42 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -617,5 +617,6 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; +pub mod window_frames; pub mod window_functions; pub mod windows; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 754ace08de6a..d7451c787096 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -17,8 +17,6 @@ //! Physical query planner -use std::sync::Arc; - use super::{ aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, windows, @@ -56,6 +54,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::{compute::can_cast_types, datatypes::DataType}; use expressions::col; use log::debug; +use std::sync::Arc; /// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`]. pub trait ExtensionPlanner { diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs new file mode 100644 index 000000000000..f0be5a221fbf --- /dev/null +++ b/datafusion/src/physical_plan/window_frames.rs @@ -0,0 +1,337 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window frame +//! +//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: +//! - A frame type - either ROWS, RANGE or GROUPS, +//! - A starting frame boundary, +//! - An ending frame boundary, +//! - An EXCLUDE clause. + +use crate::error::{DataFusionError, Result}; +use sqlparser::ast; +use std::cmp::Ordering; +use std::convert::{From, TryFrom}; +use std::fmt; + +/// The frame-spec determines which output rows are read by an aggregate window function. +/// +/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the +/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to +/// CURRENT ROW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WindowFrame { + /// A frame type - either ROWS, RANGE or GROUPS + pub units: WindowFrameUnits, + /// A starting frame boundary + pub start_bound: WindowFrameBound, + /// An ending frame boundary + pub end_bound: WindowFrameBound, +} + +impl fmt::Display for WindowFrame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} BETWEEN {} AND {}", + self.units, self.start_bound, self.end_bound + )?; + Ok(()) + } +} + +impl TryFrom for WindowFrame { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrame) -> Result { + let start_bound = value.start_bound.into(); + let end_bound = value + .end_bound + .map(WindowFrameBound::from) + .unwrap_or(WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = start_bound { + Err(DataFusionError::Execution( + "Invalid window frame: start bound cannot be unbounded following" + .to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = end_bound { + Err(DataFusionError::Execution( + "Invalid window frame: end bound cannot be unbounded preceding" + .to_owned(), + )) + } else if start_bound > end_bound { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + let units = value.units.into(); + Ok(Self { + units, + start_bound, + end_bound, + }) + } + } +} + +impl Default for WindowFrame { + fn default() -> Self { + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::CurrentRow, + } + } +} + +/// There are five ways to describe starting and ending frame boundaries: +/// +/// 1. UNBOUNDED PRECEDING +/// 2. PRECEDING +/// 3. CURRENT ROW +/// 4. FOLLOWING +/// 5. UNBOUNDED FOLLOWING +/// +/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) +#[derive(Debug, Clone, Copy, Eq)] +pub enum WindowFrameBound { + /// 1. UNBOUNDED PRECEDING + /// The frame boundary is the first row in the partition. + /// + /// 2. PRECEDING + /// must be a non-negative constant numeric expression. The boundary is a row that + /// is "units" prior to the current row. + Preceding(Option), + /// 3. The current row. + /// + /// For RANGE and GROUPS frame types, peers of the current row are also + /// included in the frame, unless specifically excluded by the EXCLUDE clause. + /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame + /// boundary. + CurrentRow, + /// 4. This is the same as " PRECEDING" except that the boundary is units after the + /// current rather than before the current row. + /// + /// 5. UNBOUNDED FOLLOWING + /// The frame boundary is the last row in the partition. + Following(Option), +} + +impl From for WindowFrameBound { + fn from(value: ast::WindowFrameBound) -> Self { + match value { + ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), + ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::CurrentRow => Self::CurrentRow, + } + } +} + +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), + WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), + WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), + WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + } + } +} + +impl PartialEq for WindowFrameBound { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for WindowFrameBound { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WindowFrameBound { + fn cmp(&self, other: &Self) -> Ordering { + self.get_rank().cmp(&other.get_rank()) + } +} + +impl WindowFrameBound { + /// get the rank of this window frame bound. + /// + /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value + /// which requires special handling e.g. with preceding the larger the value the smaller the + /// rank and also for 0 preceding / following it is the same as current row + fn get_rank(&self) -> (u8, u64) { + match self { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } + } +} + +/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the +/// starting and ending boundaries of the frame are measured. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WindowFrameUnits { + /// The ROWS frame type means that the starting and ending boundaries for the frame are + /// determined by counting individual rows relative to the current row. + Rows, + /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one + /// term. Call that term "X". With the RANGE frame type, the elements of the frame are + /// determined by computing the value of expression X for all rows in the partition and framing + /// those rows for which the value of X is within a certain range of the value of X for the + /// current row. + Range, + /// The GROUPS frame type means that the starting and ending boundaries are determine + /// by counting "groups" relative to the current group. A "group" is a set of rows that all have + /// equivalent values for all all terms of the window ORDER BY clause. + Groups, +} + +impl fmt::Display for WindowFrameUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + WindowFrameUnits::Rows => "ROWS", + WindowFrameUnits::Range => "RANGE", + WindowFrameUnits::Groups => "GROUPS", + }) + } +} + +impl From for WindowFrameUnits { + fn from(value: ast::WindowFrameUnits) -> Self { + match value { + ast::WindowFrameUnits::Range => Self::Range, + ast::WindowFrameUnits::Groups => Self::Groups, + ast::WindowFrameUnits::Rows => Self::Rows, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_frame_creation() -> Result<()> { + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Following(None), + end_bound: None, + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(None), + end_bound: Some(ast::WindowFrameBound::Preceding(None)), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(1)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + ); + Ok(()) + } + + #[test] + fn test_eq() { + assert_eq!( + WindowFrameBound::Preceding(Some(0)), + WindowFrameBound::CurrentRow + ); + assert_eq!( + WindowFrameBound::CurrentRow, + WindowFrameBound::Following(Some(0)) + ); + assert_eq!( + WindowFrameBound::Following(Some(2)), + WindowFrameBound::Following(Some(2)) + ); + assert_eq!( + WindowFrameBound::Following(None), + WindowFrameBound::Following(None) + ); + assert_eq!( + WindowFrameBound::Preceding(Some(2)), + WindowFrameBound::Preceding(Some(2)) + ); + assert_eq!( + WindowFrameBound::Preceding(None), + WindowFrameBound::Preceding(None) + ); + } + + #[test] + fn test_ord() { + assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); + // ! yes this is correct! + assert!( + WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + ); + assert!( + WindowFrameBound::Preceding(Some(u64::MAX)) + < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(1000000)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(u64::MAX)) + ); + assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); + assert!( + WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + ); + assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); + assert!( + WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + ); + assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); + assert!( + WindowFrameBound::Following(Some(u64::MAX)) + < WindowFrameBound::Following(None) + ); + } +} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index aa6b5a93f483..6bf7b776c8db 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1121,13 +1121,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() && window.window_frame.is_none() { + if window.partition_by.is_empty() { let order_by = window .order_by .iter() .map(|e| self.order_by_to_sort_expr(e)) .into_iter() .collect::>>()?; + let window_frame = window + .window_frame + .as_ref() + .map(|window_frame| window_frame.clone().try_into()) + .transpose()?; let fun = window_functions::WindowFunction::from_str(&name); if let Ok(window_functions::WindowFunction::AggregateFunction( aggregate_fun, @@ -1140,6 +1145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, order_by, + window_frame, }); } else if let Ok( window_functions::WindowFunction::BuiltInWindowFunction( @@ -1151,8 +1157,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), - args:self.function_args_to_expr(function)?, - order_by + args: self.function_args_to_expr(function)?, + order_by, + window_frame, }); } } @@ -2806,6 +2813,45 @@ mod tests { quick_test(sql, expected); } + #[test] + fn over_order_by_with_window_frame_double_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end_groups() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + /// psql result /// ``` /// QUERY PLAN diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 80a25d04468f..7a5dc0da1b53 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -239,6 +239,7 @@ where fun, args, order_by, + window_frame, } => Ok(Expr::WindowFunction { fun: fun.clone(), args: args @@ -249,6 +250,7 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + window_frame: *window_frame, }), Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { fun: fun.clone(), @@ -453,21 +455,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -500,21 +506,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -551,6 +561,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), @@ -572,6 +583,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, ]; let expected = vec![