From 854e8140e7f4fd51189fec214b7afdb57e9f46be Mon Sep 17 00:00:00 2001 From: praveenkrishna Date: Mon, 11 Mar 2019 15:30:25 +0530 Subject: [PATCH] Skip the creation of DistinctLimitNode for global aggregation node Currently if there is a global aggregation node with no aggregation functions and if we try to merge with a limit node we end up creating a DistinctLimitNode which fails when executed because there is no grouping key. So we skip the creation of `DistinctLimitNode` for such aggregation nodes. --- .../rule/MergeLimitWithDistinct.java | 1 + .../planner/optimizations/LimitPushDown.java | 1 + .../rule/TestMergeLimitWithDistinct.java | 66 +++++++++++++++++++ .../tests/AbstractTestAggregations.java | 6 ++ 4 files changed, 74 insertions(+) create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestMergeLimitWithDistinct.java diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/MergeLimitWithDistinct.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/MergeLimitWithDistinct.java index bff95b223c907..b21c64407df59 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/MergeLimitWithDistinct.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/MergeLimitWithDistinct.java @@ -41,6 +41,7 @@ public class MergeLimitWithDistinct private static boolean isDistinct(AggregationNode node) { return node.getAggregations().isEmpty() && + !node.getGroupingKeys().isEmpty() && node.getOutputSymbols().size() == node.getGroupingKeys().size() && node.getOutputSymbols().containsAll(node.getGroupingKeys()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/LimitPushDown.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/LimitPushDown.java index df368fc120f67..fa788b46816d2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/LimitPushDown.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/LimitPushDown.java @@ -135,6 +135,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext + p.limit( + 1, + p.aggregation(builder -> builder + .singleGroupingSet(p.symbol("foo")) + .source(p.values(p.symbol("foo")))))) + .matches( + node(DistinctLimitNode.class, + node(ValuesNode.class))); + } + + @Test + public void testDoesNotFire() + { + tester().assertThat(new MergeLimitWithDistinct()) + .on(p -> + p.limit( + 1, + p.aggregation(builder -> builder + .addAggregation(p.symbol("c"), expression("count(foo)"), ImmutableList.of(BIGINT)) + .globalGrouping() + .source(p.values(p.symbol("foo")))))) + .doesNotFire(); + + tester().assertThat(new MergeLimitWithDistinct()) + .on(p -> + p.limit( + 1, + p.aggregation(builder -> builder + .globalGrouping() + .source(p.values(p.symbol("foo")))))) + .doesNotFire(); + } +} diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestAggregations.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestAggregations.java index 4f9b9e46421c3..87f0ea2d17a6d 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestAggregations.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestAggregations.java @@ -1269,4 +1269,10 @@ public void testOrderedAggregations() "('5-LOW', 445 , NULL)," + "('1-URGENT', 781 , ('O'))"); } + + @Test + public void testAggregationWithConstantArgumentsOverScalar() + { + assertQuery("SELECT count(1) FROM (SELECT count(custkey) FROM orders LIMIT 10) a"); + } }