Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize distinct from semi join #8092

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange;
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin;
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
Expand Down Expand Up @@ -208,6 +209,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea
// add UnaliasSymbolReferences when it's ported
new RemoveRedundantIdentityProjections(),
new SwapAdjacentWindowsBySpecifications(),
new RemoveDistinctFromSemiJoin(),
new MergeAdjacentWindows())),
inlineProjections,
new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;

/**
* Removes the distinct aggregation (a SINGLE-step AggregationNode whose only output is its single grouping key) from the following path:
* <pre>
* SemiJoinNode
* - Source
* - FilteringSource
* - ProjectNode
* - AggregationNode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see distinct node here, do you mean it is under the aggregation?

* </pre>
*/
public class RemoveDistinctFromSemiJoin
implements Rule
{
@Override
public Optional<PlanNode> apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
{
if (node instanceof SemiJoinNode) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implement the getPattern method. Also, instead of nested if blocks, use short-circuit ifs:

if (!cond1) {
  return Optional.empty()
}
if (!cond2) {
  return Optional.empty()
}
if (!cond3) {
  return Optional.empty()
}

return actualRewrite;

PlanNode filteringSource = lookup.resolve(((SemiJoinNode) node).getFilteringSource());
Optional<PlanNode> result = rewriteUpstream(filteringSource, lookup);
if (result.isPresent()) {
return Optional.of(node.replaceChildren(ImmutableList.of(lookup.resolve(((SemiJoinNode) node).getSource()), result.get())));
}
}
return Optional.empty();
}

private Optional<PlanNode> rewriteUpstream(PlanNode node, Lookup lookup)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use PlanNodeSearcher in order to do such plan rewrites

{
if (node instanceof AggregationNode) {
AggregationNode aggregationNode = (AggregationNode) node;
if (isDistinctNode(aggregationNode)) {
return Optional.of(lookup.resolve(aggregationNode.getSource()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not need to resolve this node if you are not trying to cast it to the actual node type.

}
}
else if (node instanceof ProjectNode) {
Optional<PlanNode> result = rewriteUpstream(lookup.resolve(((ProjectNode) node).getSource()), lookup);
if (result.isPresent()) {
return Optional.of(node.replaceChildren(ImmutableList.of(result.get())));
}
}
return Optional.empty();
}

/**
 * Checks whether the aggregation has the shape this rule treats as a plain DISTINCT:
 * a SINGLE-step aggregation with exactly one output symbol and exactly one grouping key,
 * where that output symbol equals the grouping key.
 *
 * @param aggregationNode the aggregation to inspect
 * @return {@code true} only when every condition above holds
 */
private boolean isDistinctNode(AggregationNode aggregationNode)
{
    List<Symbol> outputs = aggregationNode.getOutputSymbols();
    List<Symbol> keys = aggregationNode.getGroupingKeys();

    if (aggregationNode.getStep() != SINGLE) {
        return false;
    }
    if (outputs.size() != 1 || keys.size() != 1) {
        return false;
    }
    // With one output and one key, the node is a distinct only if they are the same symbol.
    return outputs.get(0).equals(keys.get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c
CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId);
if (catalogMetadata == null) {
Catalog catalog = catalogsByConnectorId.get(connectorId);
if (catalog == null) {
// For rule tester, getConnectorId has never been called because the plan was generated outside this transaction.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure we need to change the product code because of the testing infrastructure. Maybe we should change the testing code instead so that it does not have this issue?

getConnectorId(connectorId.getCatalogName());
catalog = catalogsByConnectorId.get(connectorId);
}

verify(catalog != null, "Unknown connectorId: %s", connectorId);

ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import static com.facebook.presto.spi.type.VarcharType.createVarcharType;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.apply;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedTableScan;
Expand Down Expand Up @@ -321,4 +322,23 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
project(ImmutableMap.of("NON_NULL", expression("true")),
node(ValuesNode.class)))))))))));
}

@Test
public void testRemoveDistinctFromSemiJoin()
{
assertPlan(
"SELECT orderkey FROM orders " +
"WHERE custkey " +
"IN (SELECT distinct custkey FROM customer)",
anyTree(
semiJoin("Source", "Filter", "Output",
anyTree(tableScan("orders", ImmutableMap.of("Source", "custkey"))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put tableScan on a separate line.

anyNot(AggregationNode.class,
anyNot(AggregationNode.class,
tableScan("customer", ImmutableMap.of("Filter", "custkey"))
))
)
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner.iterative.rule.test;

import com.facebook.presto.Session;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
Expand Down Expand Up @@ -42,6 +43,7 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
Expand Down Expand Up @@ -75,12 +77,15 @@ public class PlanBuilder
{
private final PlanNodeIdAllocator idAllocator;
private final Metadata metadata;
private final Session session;

private final Map<Symbol, Type> symbols = new HashMap<>();

public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata)
public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need to update all the callers? If so, please put such refactoring changes (adding a new session field to the plan builder) in a separate commit.

{
this.idAllocator = idAllocator;
this.metadata = metadata;
this.session = session;
}

public ValuesNode values(Symbol... columns)
Expand Down Expand Up @@ -199,14 +204,31 @@ public ApplyNode apply(Assignments subqueryAssignments, List<Symbol> correlation
return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation);
}

public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol, Symbol semiJoinOutput)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put symbols before plan nodes

{
return new SemiJoinNode(idAllocator.getNextId(),
source,
filteringSource,
sourceJoinSymbol,
filteringSourceJoinSymbol,
semiJoinOutput,
Optional.empty(),
Optional.empty(),
Optional.empty());
}

public TableScanNode tableScan(List<Symbol> symbols, Map<Symbol, ColumnHandle> assignments)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this method become deprecated?

{
TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle());
return tableScan(tableHandle, symbols, assignments);
}

public TableScanNode tableScan(TableHandle tableHandle, List<Symbol> symbols, Map<Symbol, ColumnHandle> assignments)
{
Expression originalConstraint = null;
return new TableScanNode(
idAllocator.getNextId(),
new TableHandle(
new ConnectorId("testConnector"),
new TestingTableHandle()),
tableHandle,
symbols,
assignments,
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public RuleAssert on(Function<PlanBuilder, PlanNode> planProvider)
{
checkArgument(plan == null, "plan has already been set");

PlanBuilder builder = new PlanBuilder(idAllocator, metadata);
PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session);
plan = planProvider.apply(builder);
symbols = builder.getSymbols();
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule.test;

import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.tpch.TpchColumnHandle;
import com.facebook.presto.tpch.TpchTableHandle;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR;

public class TestRemoveDistinctFromSemiJoin
{
private final RuleTester tester = new RuleTester();

@Test
public void test()
throws Exception
{
tester.assertThat(new RemoveDistinctFromSemiJoin())
.on(p -> {
Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT);
Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
return p.semiJoin(
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
ImmutableList.of(sourceKey),
ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
p.project(Assignments.of(filteringSourceKey, expression("x")),
p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
.groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
.source(
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
ImmutableList.of(filteringSourceKey),
ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
)
.build())
),
sourceKey, filteringSourceKey, outputKey
);
})
.matches(
semiJoin("Source", "Filter", "Output",
tableScan("orders", ImmutableMap.of("Source", "custkey")),
project(tableScan("customer", ImmutableMap.of("Filter", "custkey")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to keep each plan node matcher on a separate line.

)
);
}

@Test
public void testDoesNotFire()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put some information about why it does not fire into the test name.

{
tester.assertThat(new RemoveDistinctFromSemiJoin())
.on(p -> {
Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT);
Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
return p.semiJoin(
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
ImmutableList.of(sourceKey),
ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
p.project(Assignments.of(filteringSourceKey, expression("x")),
p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
.groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
.addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT))
.source(p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
ImmutableList.of(filteringSourceKey),
ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
)
.build())
),
sourceKey, filteringSourceKey, outputKey
);
}).doesNotFire();
}
}