diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index 0dc55c07e592..ba443b323087 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -41,6 +41,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange;
import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
+import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin;
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
@@ -208,6 +209,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea
// add UnaliasSymbolReferences when it's ported
new RemoveRedundantIdentityProjections(),
new SwapAdjacentWindowsBySpecifications(),
+ new RemoveDistinctFromSemiJoin(),
new MergeAdjacentWindows())),
inlineProjections,
new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java
new file mode 100644
index 000000000000..86d6ad316dbb
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
+
+/**
+ * Remove distinct node from following path:
+ *
+ * SemiJoinNode
+ * - Source
+ * - FilteringSource
+ * - ProjectNode
+ * - AggregationNode
+ *
+ */
+public class RemoveDistinctFromSemiJoin
+ implements Rule
+{
+ @Override
+ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+ {
+ if (node instanceof SemiJoinNode) {
+ PlanNode filteringSource = lookup.resolve(((SemiJoinNode) node).getFilteringSource());
+ Optional result = rewriteUpstream(filteringSource, lookup);
+ if (result.isPresent()) {
+ return Optional.of(node.replaceChildren(ImmutableList.of(lookup.resolve(((SemiJoinNode) node).getSource()), result.get())));
+ }
+ }
+ return Optional.empty();
+ }
+
+ private Optional rewriteUpstream(PlanNode node, Lookup lookup)
+ {
+ if (node instanceof AggregationNode) {
+ AggregationNode aggregationNode = (AggregationNode) node;
+ if (isDistinctNode(aggregationNode)) {
+ return Optional.of(lookup.resolve(aggregationNode.getSource()));
+ }
+ }
+ else if (node instanceof ProjectNode) {
+ Optional result = rewriteUpstream(lookup.resolve(((ProjectNode) node).getSource()), lookup);
+ if (result.isPresent()) {
+ return Optional.of(node.replaceChildren(ImmutableList.of(result.get())));
+ }
+ }
+ return Optional.empty();
+ }
+
+ private boolean isDistinctNode(AggregationNode aggregationNode)
+ {
+ List outputSymbols = aggregationNode.getOutputSymbols();
+ List groupingKeys = aggregationNode.getGroupingKeys();
+ return aggregationNode.getStep() == SINGLE
+ && outputSymbols.size() == 1
+ && groupingKeys.size() == 1
+ && outputSymbols.get(0).equals(groupingKeys.get(0));
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java
index 42b10921922c..d89807b5a956 100644
--- a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java
+++ b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java
@@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c
CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId);
if (catalogMetadata == null) {
Catalog catalog = catalogsByConnectorId.get(connectorId);
+ if (catalog == null) {
+ // For rule tester, getConnectorId has never been called because the plan was generated outside this transaction.
+ getConnectorId(connectorId.getCatalogName());
+ catalog = catalogsByConnectorId.get(connectorId);
+ }
+
verify(catalog != null, "Unknown connectorId: %s", connectorId);
ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog);
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java
index 94e7fae7e0c7..ac318b7d8855 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java
@@ -38,6 +38,7 @@
import static com.facebook.presto.spi.type.VarcharType.createVarcharType;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.apply;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedTableScan;
@@ -321,4 +322,23 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
project(ImmutableMap.of("NON_NULL", expression("true")),
node(ValuesNode.class)))))))))));
}
+
+ @Test
+ public void testRemoveDistinctFromSemiJoin()
+ {
+ assertPlan(
+ "SELECT orderkey FROM orders " +
+ "WHERE custkey " +
+ "IN (SELECT distinct custkey FROM customer)",
+ anyTree(
+ semiJoin("Source", "Filter", "Output",
+ anyTree(tableScan("orders", ImmutableMap.of("Source", "custkey"))),
+ anyNot(AggregationNode.class,
+ anyNot(AggregationNode.class,
+ tableScan("customer", ImmutableMap.of("Filter", "custkey"))
+ ))
+ )
+ )
+ );
+ }
}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
index d6af5bf0f880..f477adaa7de2 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java
@@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner.iterative.rule.test;
+import com.facebook.presto.Session;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
@@ -42,6 +43,7 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
@@ -75,12 +77,15 @@ public class PlanBuilder
{
private final PlanNodeIdAllocator idAllocator;
private final Metadata metadata;
+ private final Session session;
+
private final Map symbols = new HashMap<>();
- public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata)
+ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session)
{
this.idAllocator = idAllocator;
this.metadata = metadata;
+ this.session = session;
}
public ValuesNode values(Symbol... columns)
@@ -199,14 +204,31 @@ public ApplyNode apply(Assignments subqueryAssignments, List correlation
return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation);
}
+ public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol, Symbol semiJoinOutput)
+ {
+ return new SemiJoinNode(idAllocator.getNextId(),
+ source,
+ filteringSource,
+ sourceJoinSymbol,
+ filteringSourceJoinSymbol,
+ semiJoinOutput,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+ }
+
public TableScanNode tableScan(List symbols, Map assignments)
+ {
+ TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle());
+ return tableScan(tableHandle, symbols, assignments);
+ }
+
+ public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments)
{
Expression originalConstraint = null;
return new TableScanNode(
idAllocator.getNextId(),
- new TableHandle(
- new ConnectorId("testConnector"),
- new TestingTableHandle()),
+ tableHandle,
symbols,
assignments,
Optional.empty(),
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
index 6b5890d70e9e..8d10877eeb0a 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java
@@ -83,7 +83,7 @@ public RuleAssert on(Function planProvider)
{
checkArgument(plan == null, "plan has already been set");
- PlanBuilder builder = new PlanBuilder(idAllocator, metadata);
+ PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session);
plan = planProvider.apply(builder);
symbols = builder.getSymbols();
return this;
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java
new file mode 100644
index 000000000000..86a34baad1e5
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule.test;
+
+import com.facebook.presto.connector.ConnectorId;
+import com.facebook.presto.metadata.TableHandle;
+import com.facebook.presto.spi.type.BigintType;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.Assignments;
+import com.facebook.presto.tpch.TpchColumnHandle;
+import com.facebook.presto.tpch.TpchTableHandle;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
+import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
+import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
+import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR;
+
+public class TestRemoveDistinctFromSemiJoin
+{
+ private final RuleTester tester = new RuleTester();
+
+ @Test
+ public void test()
+ throws Exception
+ {
+ tester.assertThat(new RemoveDistinctFromSemiJoin())
+ .on(p -> {
+ Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT);
+ Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
+ Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
+ return p.semiJoin(
+ p.tableScan(
+ new TableHandle(
+ new ConnectorId("local"),
+ new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
+ ImmutableList.of(sourceKey),
+ ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
+ p.project(Assignments.of(filteringSourceKey, expression("x")),
+ p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
+ .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
+ .source(
+ p.tableScan(
+ new TableHandle(
+ new ConnectorId("local"),
+ new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
+ ImmutableList.of(filteringSourceKey),
+ ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
+ )
+ .build())
+ ),
+ sourceKey, filteringSourceKey, outputKey
+ );
+ })
+ .matches(
+ semiJoin("Source", "Filter", "Output",
+ tableScan("orders", ImmutableMap.of("Source", "custkey")),
+ project(tableScan("customer", ImmutableMap.of("Filter", "custkey")))
+ )
+ );
+ }
+
+ @Test
+ public void testDoesNotFire()
+ {
+ tester.assertThat(new RemoveDistinctFromSemiJoin())
+ .on(p -> {
+ Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT);
+ Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
+ Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
+ return p.semiJoin(
+ p.tableScan(
+ new TableHandle(
+ new ConnectorId("local"),
+ new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
+ ImmutableList.of(sourceKey),
+ ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
+ p.project(Assignments.of(filteringSourceKey, expression("x")),
+ p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
+ .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
+ .addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT))
+ .source(p.tableScan(
+ new TableHandle(
+ new ConnectorId("local"),
+ new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
+ ImmutableList.of(filteringSourceKey),
+ ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
+ )
+ .build())
+ ),
+ sourceKey, filteringSourceKey, outputKey
+ );
+ }).doesNotFire();
+ }
+}