diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 9515ac2ff373..c02555d63314 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -309,7 +309,7 @@ impl LogicalPlanBuilder { // FIXME: implement next // window_frame: Option, ) -> Result { - let window_expr = window_expr.into_iter().collect::>(); + let window_expr = window_expr.into_iter().collect::>(); // FIXME: implement next // let partition_by_expr = partition_by_expr.into_iter().collect::>(); // FIXME: implement next diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index a3027e589985..63499aa1abe2 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -17,10 +17,6 @@ //! SQL Query Planner (produces logical plan from SQL AST) -use std::str::FromStr; -use std::sync::Arc; -use std::{convert::TryInto, vec}; - use crate::catalog::TableReference; use crate::datasource::TableProvider; use crate::logical_plan::Expr::Alias; @@ -28,6 +24,7 @@ use crate::logical_plan::{ and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, ToDFSchema, }; +use crate::prelude::JoinType; use crate::scalar::ScalarValue; use crate::{ error::{DataFusionError, Result}, @@ -38,11 +35,8 @@ use crate::{ physical_plan::{aggregates, functions, window_functions}, sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}, }; - use arrow::datatypes::*; use hashbrown::HashMap; - -use crate::prelude::JoinType; use sqlparser::ast::{ BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Query, Select, SelectItem, @@ -52,6 +46,9 @@ use sqlparser::ast::{ use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; +use std::str::FromStr; +use std::sync::Arc; +use std::{convert::TryInto, vec}; use super::{ parser::DFParser, @@ -678,11 +675,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec)> { let plan = LogicalPlanBuilder::from(input) - .window(window_exprs)? + .window(window_exprs.clone())? .build()?; let select_exprs = select_exprs .iter() - .map(|expr| expr_as_column_expr(&expr, &plan)) + .map(|expr| rebase_expr(expr, &window_exprs, &plan)) .into_iter() .collect::>>()?; Ok((plan, select_exprs)) @@ -2710,6 +2707,16 @@ mod tests { quick_test(sql, expected); } + #[test] + fn empty_over_with_alias() { + let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; + let expected = "\ + Projection: #order_id AS oid, #MAX(order_id) AS max_oid\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn empty_over_plus() { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; diff --git a/integration-tests/sqls/simple_window_full_aggregation.sql b/integration-tests/sqls/simple_window_full_aggregation.sql new file mode 100644 index 000000000000..94860bc3b183 --- /dev/null +++ b/integration-tests/sqls/simple_window_full_aggregation.sql @@ -0,0 +1,25 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language gOVERning permissions and +-- limitations under the License. + +SELECT + row_number() OVER () AS row_number, + count(c3) OVER () AS count_c3, + avg(c3) OVER () AS avg, + sum(c3) OVER () AS sum, + max(c3) OVER () AS max, + min(c3) OVER () AS min +FROM test +ORDER BY row_number; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index f4967b8457e4..5bd308180e59 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 4, msg="tests are missed") + self.assertEqual(len(files), 5, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv(