diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 703a13a1cb1d..c17192cad1b5 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -671,32 +671,18 @@ mod tests { )]; let sort = sort_exec(sort_exprs.clone(), source); - let window_agg_exec = Arc::new(WindowAggExec::try_new( - vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), - "count".to_owned(), - &[col("non_nullable_col", &schema)?], - &[], - &sort_exprs, - Arc::new(WindowFrame::new(true)), - schema.as_ref(), - )?], - sort.clone(), - sort.schema(), - vec![], - Some(sort_exprs), - )?) as Arc; + let window_agg = window_exec("non_nullable_col", sort_exprs, sort); let sort_exprs = vec![sort_expr_options( "non_nullable_col", - &window_agg_exec.schema(), + &window_agg.schema(), SortOptions { descending: false, nulls_first: false, }, )]; - let sort = sort_exec(sort_exprs.clone(), window_agg_exec); + let sort = sort_exec(sort_exprs.clone(), window_agg); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( @@ -707,21 +693,7 @@ mod tests { ); // let filter_exec = sort_exec; - let physical_plan = Arc::new(WindowAggExec::try_new( - vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), - "count".to_owned(), - &[col("non_nullable_col", &schema)?], - &[], - &sort_exprs, - Arc::new(WindowFrame::new(true)), - schema.as_ref(), - )?], - filter.clone(), - filter.schema(), - vec![], - Some(sort_exprs), - )?) as Arc; + let physical_plan = window_exec("non_nullable_col", sort_exprs, filter); let expected_input = vec![ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", @@ -888,6 +860,35 @@ mod tests { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } + fn window_exec( + col_name: &str, + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs: Vec<_> = sort_exprs.into_iter().collect(); + let schema = input.schema(); + + Arc::new( + WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col(col_name, &schema).unwrap()], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + ) + .unwrap()], + input.clone(), + input.schema(), + vec![], + Some(sort_exprs), + ) + .unwrap(), + ) + } + /// Create a non sorted parquet exec fn parquet_exec(schema: &SchemaRef) -> Arc { Arc::new(ParquetExec::new(