Skip to content

Commit

Permalink
Add Disjunctive Predicate For Filtered Default Aggs
Browse files Browse the repository at this point in the history
Filtered aggregates do not make an explicit predicate of the filter
clause which leads to optimizer not being able to push down predicate
to source whenever possible. Note that the pushdown is possible only
for cases when the aggregation is default and no grouping sets are
present
  • Loading branch information
havi-odin authored and sopel39 committed Sep 20, 2018
1 parent eddfa2d commit 70426ef
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
Expand All @@ -30,7 +31,9 @@
import java.util.Optional;

import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.google.common.base.Verify.verify;

/**
Expand All @@ -46,6 +49,7 @@
* - Aggregation
* F1(...) mask ($0)
* F2(...) mask ($1)
* - Filter(mask ($0) OR mask ($1))
* - Project
* <identity projections for existing fields>
* $0 = C1(...)
Expand Down Expand Up @@ -78,6 +82,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
{
Assignments.Builder newAssignments = Assignments.builder();
ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
ImmutableList.Builder<Expression> maskSymbols = ImmutableList.builder();
boolean aggregateWithoutFilterPresent = false;

for (Map.Entry<Symbol, Aggregation> entry : aggregation.getAggregations().entrySet()) {
Symbol output = entry.getKey();
Expand All @@ -92,23 +98,37 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern");
newAssignments.put(symbol, filter);
mask = Optional.of(symbol);

maskSymbols.add(symbol.toSymbolReference());
}
else {
aggregateWithoutFilterPresent = true;
}

aggregations.put(output, new Aggregation(
new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()),
entry.getValue().getSignature(),
mask));
}

Expression predicate = TRUE_LITERAL;
if (!aggregation.hasNonEmptyGroupingSet() && !aggregateWithoutFilterPresent) {
predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL);
}

// identity projection for all existing inputs
newAssignments.putIdentities(aggregation.getSource().getOutputSymbols());

return Result.ofPlanNode(
new AggregationNode(
context.getIdAllocator().getNextId(),
new ProjectNode(
new FilterNode(
context.getIdAllocator().getNextId(),
aggregation.getSource(),
newAssignments.build()),
new ProjectNode(
context.getIdAllocator().getNextId(),
aggregation.getSource(),
newAssignments.build()),
predicate),
aggregations.build(),
aggregation.getGroupingSets(),
ImmutableList.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
*/
package com.facebook.presto.sql.query;

import com.facebook.presto.sql.planner.LogicalPlanner;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.util.MorePredicates.isInstanceOfAny;
import static org.testng.Assert.assertFalse;

public class TestFilteredAggregations
extends BasePlanTest
{
private QueryAssertions assertions;

Expand All @@ -34,6 +46,22 @@ public void teardown()
assertions = null;
}

@Test
public void testAddPredicateForFilterClauses()
{
assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
"VALUES (BIGINT '10')");

assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)",
"VALUES (BIGINT '18', BIGINT '2')");

assertions.assertQuery(
"SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)",
"VALUES (BIGINT '8', BIGINT '10')");
}

@Test
public void testGroupAll()
{
Expand Down Expand Up @@ -86,4 +114,48 @@ public void testGroupingSets()
"(2, BIGINT '4', BIGINT '1'), " +
"(CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')");
}

@Test
public void rewriteAddFilterWithMultipleFilters()
{
assertPlan(
"SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders",
anyTree(
filter(
"(\"totalprice\" > 0E0 OR \"custkey\" > BIGINT '0')",
tableScan(
"orders", ImmutableMap.of("totalprice", "totalprice",
"custkey", "custkey")))));
}

@Test
public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent()
{
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders");
}

@Test
public void testPushDownConstantFilterPredicate()
{
assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders");

assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders");
}

@Test
public void testNoFilterAddedForConstantValueFilters()
{
assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x");

assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice");
}

private void assertPlanContainsNoFilter(String sql)
{
assertFalse(
searchFrom(plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot())
.where(isInstanceOfAny(FilterNode.class))
.matches(),
"Unexpected node for query: " + sql);
}
}

0 comments on commit 70426ef

Please sign in to comment.