Skip to content

Commit

Permalink
Make PushPredicateIntoTableScan top level rule
Browse files Browse the repository at this point in the history
It now contains a single rule, so no point in
having it return a rule set.
  • Loading branch information
martint committed Mar 9, 2019
1 parent deff7b1 commit 64f93a9
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
new PushPredicateIntoTableScan(metadata, sqlParser).rules()),
ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))),
new PruneUnreferencedOutputs(),
new IterativeOptimizer(
ruleStats,
Expand Down Expand Up @@ -407,7 +407,7 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
new PushPredicateIntoTableScan(metadata, sqlParser).rules()),
ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))),
projectionPushDown,
new PruneUnreferencedOutputs(),
new IterativeOptimizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.matching.Capture;
Expand Down Expand Up @@ -61,7 +60,6 @@
import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts;
import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.sql.planner.iterative.rule.PreconditionRules.checkRulesAreFiredBeforeAddExchangesRule;
import static io.prestosql.sql.planner.plan.Patterns.filter;
import static io.prestosql.sql.planner.plan.Patterns.source;
import static io.prestosql.sql.planner.plan.Patterns.tableScan;
Expand All @@ -74,7 +72,13 @@
* chosen by AddExchanges
*/
public class PushPredicateIntoTableScan
implements Rule<FilterNode>
{
private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

private static final Pattern<FilterNode> PATTERN = filter().with(source().matching(
tableScan().capturedAs(TABLE_SCAN)));

private final Metadata metadata;
private final SqlParser parser;
private final DomainTranslator domainTranslator;
Expand All @@ -86,90 +90,60 @@ public PushPredicateIntoTableScan(Metadata metadata, SqlParser parser)
this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde()));
}

public Set<Rule<?>> rules()
@Override
public Pattern<FilterNode> getPattern()
{
return ImmutableSet.of(pickTableLayoutForPredicate());
return PATTERN;
}

public PickTableLayoutForPredicate pickTableLayoutForPredicate()
@Override
public boolean isEnabled(Session session)
{
return new PickTableLayoutForPredicate(metadata, parser, domainTranslator);
return isNewOptimizerEnabled(session);
}

private static final class PickTableLayoutForPredicate
implements Rule<FilterNode>
@Override
public Result apply(FilterNode filterNode, Captures captures, Context context)
{
private final Metadata metadata;
private final SqlParser parser;
private final DomainTranslator domainTranslator;
TableScanNode tableScan = captures.get(TABLE_SCAN);

PlanNode rewritten = pushFilterIntoTableScan(
tableScan,
filterNode.getPredicate(),
false,
context.getSession(),
context.getSymbolAllocator().getTypes(),
context.getIdAllocator(),
metadata,
parser,
domainTranslator);

private PickTableLayoutForPredicate(Metadata metadata, SqlParser parser, DomainTranslator domainTranslator)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.parser = requireNonNull(parser, "parser is null");
this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
if (arePlansSame(filterNode, tableScan, rewritten)) {
return Result.empty();
}

private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

private static final Pattern<FilterNode> PATTERN = filter().with(source().matching(
tableScan().capturedAs(TABLE_SCAN)));
return Result.ofPlanNode(rewritten);
}

@Override
public Pattern<FilterNode> getPattern()
{
return PATTERN;
private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten)
{
if (!(rewritten instanceof FilterNode)) {
return false;
}

@Override
public boolean isEnabled(Session session)
{
return isNewOptimizerEnabled(session);
FilterNode rewrittenFilter = (FilterNode) rewritten;
if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) {
return false;
}

@Override
public Result apply(FilterNode filterNode, Captures captures, Context context)
{
TableScanNode tableScan = captures.get(TABLE_SCAN);

PlanNode rewritten = pushFilterIntoTableScan(
tableScan,
filterNode.getPredicate(),
false,
context.getSession(),
context.getSymbolAllocator().getTypes(),
context.getIdAllocator(),
metadata,
parser,
domainTranslator);

if (arePlansSame(filterNode, tableScan, rewritten)) {
return Result.empty();
}

return Result.ofPlanNode(rewritten);
if (!(rewrittenFilter.getSource() instanceof TableScanNode)) {
return false;
}

private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten)
{
if (!(rewritten instanceof FilterNode)) {
return false;
}
TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource();

FilterNode rewrittenFilter = (FilterNode) rewritten;
if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) {
return false;
}

if (!(rewrittenFilter.getSource() instanceof TableScanNode)) {
return false;
}

TableScanNode rewrittenTableScan = (TableScanNode) rewrittenFilter.getSource();

return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint())
&& Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint());
}
return Objects.equals(tableScan.getCurrentConstraint(), rewrittenTableScan.getCurrentConstraint())
&& Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint());
}

public static PlanNode pushFilterIntoTableScan(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -76,17 +75,15 @@ public void setUpBeforeClass()
@Test
public void doesNotFireIfNoTableScan()
{
for (Rule<?> rule : pushPredicateIntoTableScan.rules()) {
tester().assertThat(rule)
.on(p -> p.values(p.symbol("a", BIGINT)))
.doesNotFire();
}
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.values(p.symbol("a", BIGINT)))
.doesNotFire();
}

@Test
public void eliminateTableScanWhenNoLayoutExist()
{
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("orderstatus = 'G'"),
p.tableScan(
ordersTableHandle,
Expand All @@ -99,7 +96,7 @@ public void eliminateTableScanWhenNoLayoutExist()
public void replaceWithExistsWhenNoLayoutExist()
{
ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT);
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("nationkey = BIGINT '44'"),
p.tableScan(
nationTableHandle,
Expand All @@ -113,7 +110,7 @@ public void replaceWithExistsWhenNoLayoutExist()
@Test
public void doesNotFireIfRuleNotChangePlan()
{
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'"),
p.tableScan(
nationTableHandle,
Expand All @@ -130,7 +127,7 @@ public void ruleAddedTableLayoutToFilterTableScan()
Map<String, Domain> filterConstraint = ImmutableMap.<String, Domain>builder()
.put("orderstatus", singleValue(createVarcharType(1), utf8Slice("F")))
.build();
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("orderstatus = CAST ('F' AS VARCHAR(1))"),
p.tableScan(
ordersTableHandle,
Expand All @@ -143,7 +140,7 @@ public void ruleAddedTableLayoutToFilterTableScan()
@Test
public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint()
{
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("orderstatus = 'F'"),
p.tableScan(
ordersTableHandle,
Expand All @@ -160,7 +157,7 @@ public void ruleAddedNewTableLayoutIfTableScanHasEmptyConstraint()
public void ruleWithPushdownableToTableLayoutPredicate()
{
Type orderStatusType = createVarcharType(1);
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("orderstatus = 'O'"),
p.tableScan(
ordersTableHandle,
Expand All @@ -176,7 +173,7 @@ public void ruleWithPushdownableToTableLayoutPredicate()
public void nonDeterministicPredicate()
{
Type orderStatusType = createVarcharType(1);
tester().assertThat(pushPredicateIntoTableScan.pickTableLayoutForPredicate())
tester().assertThat(pushPredicateIntoTableScan)
.on(p -> p.filter(expression("orderstatus = 'O' AND rand() = 0"),
p.tableScan(
ordersTableHandle,
Expand Down

0 comments on commit 64f93a9

Please sign in to comment.