diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index b319d5b25f12..d49d53cf8d85 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -36,7 +36,7 @@ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; use datafusion::logical_plan::{DFSchema, Expr}; -use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; +use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::expressions::col; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; @@ -45,7 +45,6 @@ use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; -use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::windows::WindowAggExec; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, @@ -205,76 +204,27 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ) })? .clone(); - let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - var_provider: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - + let ctx_state = ExecutionContextState::new(); let window_agg_expr: Vec<(Expr, String)> = window_agg .window_expr .iter() .zip(window_agg.window_expr_name.iter()) .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) .collect::, _>>()?; - - let mut physical_window_expr = vec![]; - let df_planner = DefaultPhysicalPlanner::default(); - - for (expr, name) in &window_agg_expr { - match expr { - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - .. - } => { - let arg = df_planner - .create_physical_expr( - &args[0], - &physical_schema, - &ctx_state, - ) - .map_err(|e| { - BallistaError::General(format!("{:?}", e)) - })?; - if !partition_by.is_empty() { - return Err(BallistaError::NotImplemented("Window function with partition by is not yet implemented".to_owned())); - } - if !order_by.is_empty() { - return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); - } - if window_frame.is_some() { - return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned())); - } - let window_expr = create_window_expr( - &fun, - &[arg], - &physical_schema, - name.to_owned(), - )?; - physical_window_expr.push(window_expr); - } - _ => { - return Err(BallistaError::General( - "Invalid expression for WindowAggrExec".to_string(), - )); - } - } - } - + let physical_window_expr = window_agg_expr + .iter() + .map(|(expr, name)| { + df_planner.create_window_expr_with_name( + expr, + name.to_string(), + &physical_schema, + &ctx_state, + ) + }) + .collect::, _>>()?; Ok(Arc::new(WindowAggExec::try_new( physical_window_expr, input, @@ -297,7 +247,6 @@ impl TryInto> for &protobuf::PhysicalPlanNode { AggregateMode::FinalPartitioned } }; - let group = hash_agg .group_expr .iter() @@ -306,25 +255,13 @@ impl TryInto> for &protobuf::PhysicalPlanNode { compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) }) .collect::, _>>()?; - let logical_agg_expr: Vec<(Expr, String)> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) .collect::, _>>()?; - - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - var_provider: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - + let ctx_state = ExecutionContextState::new(); let input_schema = hash_agg .input_schema .as_ref() @@ -336,37 +273,18 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .clone(); let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - - let mut physical_aggr_expr = vec![]; - let df_planner = DefaultPhysicalPlanner::default(); - for (expr, name) in &logical_agg_expr { - match expr { - Expr::AggregateFunction { fun, args, .. } => { - let arg = df_planner - .create_physical_expr( - &args[0], - &physical_schema, - &ctx_state, - ) - .map_err(|e| { - BallistaError::General(format!("{:?}", e)) - })?; - physical_aggr_expr.push(create_aggregate_expr( - &fun, - false, - &[arg], - &physical_schema, - name.to_string(), - )?); - } - _ => { - return Err(BallistaError::General( - "Invalid expression for HashAggregateExec".to_string(), - )) - } - } - } + let physical_aggr_expr = logical_agg_expr + .iter() + .map(|(expr, name)| { + df_planner.create_aggregate_expr_with_name( + expr, + name.to_string(), + &physical_schema, + &ctx_state, + ) + }) + .collect::, _>>()?; Ok(Arc::new(HashAggregateExec::try_new( agg_mode, group, @@ -484,15 +402,7 @@ fn compile_expr( schema: &Schema, ) -> Result, BallistaError> { let df_planner = DefaultPhysicalPlanner::default(); - let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - let state = ExecutionContextState { - catalog_list, - scalar_functions: HashMap::new(), - var_provider: HashMap::new(), - aggregate_functions: HashMap::new(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; + let state = ExecutionContextState::new(); let expr: Expr = expr.try_into()?; df_planner .create_physical_expr(&expr, schema, &state) diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d7451c787096..d42948a8666c 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -731,34 +731,82 @@ impl DefaultPhysicalPlanner { } } - /// Create a window expression from a logical expression - pub fn create_window_expr( + /// Create a window expression with a name from a logical expression + pub fn create_window_expr_with_name( &self, e: &Expr, - logical_input_schema: &DFSchema, + name: String, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) over () as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), - }; - match e { - Expr::WindowFunction { fun, args, .. } => { + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { let args = args .iter() .map(|e| { self.create_physical_expr(e, physical_input_schema, ctx_state) }) .collect::>>()?; - // if !order_by.is_empty() { - // return Err(DataFusionError::NotImplemented( - // "Window function with order by is not yet implemented".to_owned(), - // )); - // } - windows::create_window_expr(fun, &args, physical_input_schema, name) + let partition_by = partition_by + .iter() + .map(|e| { + self.create_physical_expr(e, physical_input_schema, ctx_state) + }) + .collect::>>()?; + let order_by = order_by + .iter() + .map(|e| match e { + Expr::Sort { + expr, + asc, + nulls_first, + } => self.create_physical_sort_expr( + expr, + &physical_input_schema, + SortOptions { + descending: !*asc, + nulls_first: *nulls_first, + }, + &ctx_state, + ), + _ => Err(DataFusionError::Plan( + "Sort only accepts sort expressions".to_string(), + )), + }) + .collect::>>()?; + if !partition_by.is_empty() { + return Err(DataFusionError::NotImplemented( + "window expression with non-empty partition by clause is not yet supported" + .to_owned(), + )); + } + if !order_by.is_empty() { + return Err(DataFusionError::NotImplemented( + "window expression with non-empty order by clause is not yet supported" + .to_owned(), + )); + } + if window_frame.is_some() { + return Err(DataFusionError::NotImplemented( + "window expression with window frame definition is not yet supported" + .to_owned(), + )); + } + windows::create_window_expr( + fun, + name, + &args, + &partition_by, + &order_by, + *window_frame, + physical_input_schema, + ) } other => Err(DataFusionError::Internal(format!( "Invalid window expression '{:?}'", @@ -767,20 +815,30 @@ impl DefaultPhysicalPlanner { } } - /// Create an aggregate expression from a logical expression - pub fn create_aggregate_expr( + /// Create a window expression from a logical expression or an alias + pub fn create_window_expr( &self, e: &Expr, logical_input_schema: &DFSchema, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) as total" + ) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), _ => (e.name(logical_input_schema)?, e), }; + self.create_window_expr_with_name(e, name, physical_input_schema, ctx_state) + } + /// Create an aggregate expression with a name from a logical expression + pub fn create_aggregate_expr_with_name( + &self, + e: &Expr, + name: String, + physical_input_schema: &Schema, + ctx_state: &ExecutionContextState, + ) -> Result> { match e { Expr::AggregateFunction { fun, @@ -819,7 +877,23 @@ impl DefaultPhysicalPlanner { } } - /// Create an aggregate expression from a logical expression + /// Create an aggregate expression from a logical expression or an alias + pub fn create_aggregate_expr( + &self, + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + ctx_state: &ExecutionContextState, + ) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (e.name(logical_input_schema)?, e), + }; + self.create_aggregate_expr_with_name(e, name, physical_input_schema, ctx_state) + } + + /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( &self, e: &Expr, diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 9a6b92985b51..565a9eef2857 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -18,9 +18,11 @@ //! Execution plan for window functions use crate::error::{DataFusionError, Result}; + +use crate::logical_plan::window_frames::WindowFrame; use crate::physical_plan::{ aggregates, common, - expressions::{Literal, NthValue, RowNumber}, + expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber}, type_coercion::coerce, window_functions::signature_for_built_in, window_functions::BuiltInWindowFunctionExpr, @@ -61,12 +63,18 @@ pub struct WindowAggExec { /// Create a physical expression for window function pub fn create_window_expr( fun: &WindowFunction, + name: String, args: &[Arc], + // https://github.com/apache/arrow-datafusion/issues/299 + _partition_by: &[Arc], + // https://github.com/apache/arrow-datafusion/issues/360 + _order_by: &[PhysicalSortExpr], + // https://github.com/apache/arrow-datafusion/issues/361 + _window_frame: Option, input_schema: &Schema, - name: String, ) -> Result> { - match fun { - WindowFunction::AggregateFunction(fun) => Ok(Arc::new(AggregateWindowExpr { + Ok(match fun { + WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr { aggregate: aggregates::create_aggregate_expr( fun, false, @@ -74,11 +82,11 @@ pub fn create_window_expr( input_schema, name, )?, - })), - WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr { + }), + WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr { window: create_built_in_window_expr(fun, args, input_schema, name)?, - })), - } + }), + }) } fn create_built_in_window_expr( @@ -537,9 +545,12 @@ mod tests { let window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "count".to_owned(), )?], input, schema.clone(), @@ -567,21 +578,30 @@ mod tests { vec![ create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "count".to_owned(), )?, create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Max), + "max".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "max".to_owned(), )?, create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Min), + "min".to_owned(), &[col("c3")], + &[], + &[], + Some(WindowFrame::default()), schema.as_ref(), - "min".to_owned(), )?, ], input,