-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add disjunctive predicate for filtered aggregates #11425
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import com.facebook.presto.sql.planner.plan.AggregationNode; | ||
import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; | ||
import com.facebook.presto.sql.planner.plan.Assignments; | ||
import com.facebook.presto.sql.planner.plan.FilterNode; | ||
import com.facebook.presto.sql.planner.plan.ProjectNode; | ||
import com.facebook.presto.sql.tree.Expression; | ||
import com.facebook.presto.sql.tree.FunctionCall; | ||
|
@@ -30,7 +31,9 @@ | |
import java.util.Optional; | ||
|
||
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; | ||
import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault; | ||
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; | ||
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; | ||
import static com.google.common.base.Verify.verify; | ||
|
||
/** | ||
|
@@ -46,6 +49,7 @@ | |
* - Aggregation | ||
* F1(...) mask ($0) | ||
* F2(...) mask ($1) | ||
* - Filter(mask ($0) OR mask ($1)) | ||
* - Project | ||
* <identity projections for existing fields> | ||
* $0 = C1(...) | ||
|
@@ -78,6 +82,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont | |
{ | ||
Assignments.Builder newAssignments = Assignments.builder(); | ||
ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder(); | ||
ImmutableList.Builder<Expression> maskSymbols = ImmutableList.builder(); | ||
boolean aggregateWithoutFilterPresent = false; | ||
|
||
for (Map.Entry<Symbol, Aggregation> entry : aggregation.getAggregations().entrySet()) { | ||
Symbol output = entry.getKey(); | ||
|
@@ -92,23 +98,37 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont | |
verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern"); | ||
newAssignments.put(symbol, filter); | ||
mask = Optional.of(symbol); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add mask symbol to the list:
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. undo, such empty line can cause the build failure |
||
maskSymbols.add(symbol.toSymbolReference()); | ||
} | ||
else { | ||
aggregateWithoutFilterPresent = true; | ||
} | ||
|
||
aggregations.put(output, new Aggregation( | ||
new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()), | ||
entry.getValue().getSignature(), | ||
mask)); | ||
} | ||
|
||
Expression predicate = TRUE_LITERAL; | ||
if (!aggregation.hasNonEmptyGroupingSet() && !aggregateWithoutFilterPresent) { | ||
predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL); | ||
} | ||
|
||
// identity projection for all existing inputs | ||
newAssignments.putIdentities(aggregation.getSource().getOutputSymbols()); | ||
|
||
return Result.ofPlanNode( | ||
new AggregationNode( | ||
context.getIdAllocator().getNextId(), | ||
new ProjectNode( | ||
new FilterNode( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add
This way we won't do redundant computations of expressions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also needs to consider if any aggregate without a FILTER clause is present, made the change and fixed |
||
context.getIdAllocator().getNextId(), | ||
aggregation.getSource(), | ||
newAssignments.build()), | ||
new ProjectNode( | ||
context.getIdAllocator().getNextId(), | ||
aggregation.getSource(), | ||
newAssignments.build()), | ||
predicate), | ||
aggregations.build(), | ||
aggregation.getGroupingSets(), | ||
ImmutableList.of(), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,23 @@ | |
*/ | ||
package com.facebook.presto.sql.query; | ||
|
||
import com.facebook.presto.sql.planner.LogicalPlanner; | ||
import com.facebook.presto.sql.planner.assertions.BasePlanTest; | ||
import com.facebook.presto.sql.planner.plan.FilterNode; | ||
import com.google.common.collect.ImmutableMap; | ||
import org.testng.annotations.AfterClass; | ||
import org.testng.annotations.BeforeClass; | ||
import org.testng.annotations.Test; | ||
|
||
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; | ||
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; | ||
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; | ||
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; | ||
import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; | ||
import static org.testng.Assert.assertFalse; | ||
|
||
public class TestFilteredAggregations | ||
extends BasePlanTest | ||
{ | ||
private QueryAssertions assertions; | ||
|
||
|
@@ -34,6 +46,22 @@ public void teardown() | |
assertions = null; | ||
} | ||
|
||
@Test | ||
public void testAddPredicateForFilterClauses() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/testNoFilterAddedForConstantValueFilters/testFilterdAggregationPredicatePushdown |
||
{ | ||
assertions.assertQuery( | ||
"SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please test also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would ask , what does t(x) mean ? I have never seen this kind of way of expression. Could you tell me more ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"VALUES (BIGINT '10')"); | ||
|
||
assertions.assertQuery( | ||
"SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)", | ||
"VALUES (BIGINT '18', BIGINT '2')"); | ||
|
||
assertions.assertQuery( | ||
"SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)", | ||
"VALUES (BIGINT '8', BIGINT '10')"); | ||
} | ||
|
||
@Test | ||
public void testGroupAll() | ||
{ | ||
|
@@ -86,4 +114,48 @@ public void testGroupingSets() | |
"(2, BIGINT '4', BIGINT '1'), " + | ||
"(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')"); | ||
} | ||
|
||
@Test | ||
public void rewriteAddFilterWithMultipleFilters() | ||
{ | ||
assertPlan( | ||
"SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders", | ||
anyTree( | ||
filter( | ||
"(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')", | ||
tableScan( | ||
"orders", ImmutableMap.of("totalprice", "totalprice", | ||
"custkey", "custkey"))))); | ||
} | ||
|
||
@Test | ||
public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent() | ||
{ | ||
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders"); | ||
} | ||
|
||
@Test | ||
public void testPushDownConstantFilterPredicate() | ||
{ | ||
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a comment that filter node was optimized |
||
|
||
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders"); | ||
} | ||
|
||
@Test | ||
public void testNoFilterAddedForConstantValueFilters() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/testNoFilterAddedForConstantValueFilters/testNoFilterAddedFoNonDefaultAggregation |
||
{ | ||
assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x"); | ||
|
||
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice"); | ||
} | ||
|
||
private void assertPlanContainsNoFilter(String sql) | ||
{ | ||
assertFalse( | ||
searchFrom(plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot()) | ||
.where(isInstanceOfAny(FilterNode.class)) | ||
.matches(), | ||
"Unexpected node for query: " + sql); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add to the commit message information about rule application space, that this only work only for default aggregation.