Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disjunctive predicate for filtered aggregates #11425

Merged
merged 1 commit into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.sql.planner.plan.AggregationNode;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add information to the commit message about the rule's application space, i.e. that this works only for default aggregation.

import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
Expand All @@ -30,7 +31,9 @@
import java.util.Optional;

import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.google.common.base.Verify.verify;

/**
Expand All @@ -46,6 +49,7 @@
* - Aggregation
* F1(...) mask ($0)
* F2(...) mask ($1)
* - Filter(mask ($0) OR mask ($1))
* - Project
* <identity projections for existing fields>
* $0 = C1(...)
Expand Down Expand Up @@ -78,6 +82,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
{
Assignments.Builder newAssignments = Assignments.builder();
ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
ImmutableList.Builder<Expression> maskSymbols = ImmutableList.builder();
boolean aggregateWithoutFilterPresent = false;

for (Map.Entry<Symbol, Aggregation> entry : aggregation.getAggregations().entrySet()) {
Symbol output = entry.getKey();
Expand All @@ -92,23 +98,37 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern");
newAssignments.put(symbol, filter);
mask = Optional.of(symbol);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add mask symbol to the list:

maskSymbols.add(symbol.toSymbolReference());


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undo this; an empty line like this can cause a build failure.

maskSymbols.add(symbol.toSymbolReference());
}
else {
aggregateWithoutFilterPresent = true;
}

aggregations.put(output, new Aggregation(
new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()),
entry.getValue().getSignature(),
mask));
}

Expression predicate = TRUE_LITERAL;
if (!aggregation.hasNonEmptyGroupingSet() && !aggregateWithoutFilterPresent) {
predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL);
}

// identity projection for all existing inputs
newAssignments.putIdentities(aggregation.getSource().getOutputSymbols());

return Result.ofPlanNode(
new AggregationNode(
context.getIdAllocator().getNextId(),
new ProjectNode(
new FilterNode(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add FilterNode above ProjectNode using maskSymbols:

Expression predicate = TRUE_LITERAL;
if (!node.hasNonEmptyGroupingSet()) {
    predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL);
}
new FilterNode(
  ...,
  predicate,
...

This way we won't do redundant computations of expressions

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to consider whether any aggregate without a FILTER clause is present. I made the change and fixed it.

context.getIdAllocator().getNextId(),
aggregation.getSource(),
newAssignments.build()),
new ProjectNode(
context.getIdAllocator().getNextId(),
aggregation.getSource(),
newAssignments.build()),
predicate),
aggregations.build(),
aggregation.getGroupingSets(),
ImmutableList.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
*/
package com.facebook.presto.sql.query;

import com.facebook.presto.sql.planner.LogicalPlanner;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.util.MorePredicates.isInstanceOfAny;
import static org.testng.Assert.assertFalse;

public class TestFilteredAggregations
extends BasePlanTest
{
private QueryAssertions assertions;

Expand All @@ -34,6 +46,22 @@ public void teardown()
assertions = null;
}

@Test
public void testAddPredicateForFilterClauses()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/testNoFilterAddedForConstantValueFilters/testFilteredAggregationPredicatePushdown

{
assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test also SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what t(x) means? I have never seen this kind of expression before. Could you tell me more?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t(x) is a table alias with a column alias: it names the VALUES result as a table t with a single column x and rows 1, 1, 0, 2, 3, 3

"VALUES (BIGINT '10')");

assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)",
"VALUES (BIGINT '18', BIGINT '2')");

assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
"VALUES (BIGINT '8', BIGINT '10')");
}

@Test
public void testGroupAll()
{
Expand Down Expand Up @@ -86,4 +114,48 @@ public void testGroupingSets()
"(2, BIGINT '4', BIGINT '1'), " +
"(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')");
}

/**
 * Checks that a query with two filtered aggregates over {@code orders} is planned with a
 * filter whose predicate is the disjunction of both FILTER conditions, with a table scan of
 * {@code orders} as the filter's source pattern.
 */
@Test
public void rewriteAddFilterWithMultipleFilters()
{
    String sql = "SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders";
    // Expected predicate: the two FILTER conditions combined with OR.
    String expectedPredicate = "(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')";

    assertPlan(
            sql,
            anyTree(
                    filter(
                            expectedPredicate,
                            tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "custkey", "custkey")))));
}

/**
 * Checks that no filter node appears in the plan when a filtered aggregate is combined with
 * an aggregate that has no FILTER clause.
 */
@Test
public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent()
{
    String sql = "SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders";
    assertPlanContainsNoFilter(sql);
}

@Test
public void testPushDownConstantFilterPredicate()
{
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment that the filter node was optimized away.


assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders");
}

@Test
public void testNoFilterAddedForConstantValueFilters()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/testNoFilterAddedForConstantValueFilters/testNoFilterAddedForNonDefaultAggregation

{
assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x");

assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice");
}

/**
 * Plans {@code sql} up to the OPTIMIZED stage and fails the test if a search starting at the
 * plan root matches any {@link FilterNode}.
 *
 * @param sql the query to plan
 */
private void assertPlanContainsNoFilter(String sql)
{
    boolean filterFound = searchFrom(plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot())
            .where(isInstanceOfAny(FilterNode.class))
            .matches();
    assertFalse(filterFound, "Unexpected node for query: " + sql);
}
}