diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 0dc55c07e592..ba443b323087 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -41,6 +41,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; +import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; @@ -208,6 +209,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea // add UnaliasSymbolReferences when it's ported new RemoveRedundantIdentityProjections(), new SwapAdjacentWindowsBySpecifications(), + new RemoveDistinctFromSemiJoin(), new MergeAdjacentWindows())), inlineProjections, new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java new file mode 100644 index 000000000000..86d6ad316dbb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; + +/** + * Remove distinct node from following path: + *
+ *     SemiJoinNode
+ *        - Source
+ *        - FilteringSource
+ *          - ProjectNode
+ *              - AggregationNode
+ * 
+ */ +public class RemoveDistinctFromSemiJoin + implements Rule +{ + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (node instanceof SemiJoinNode) { + PlanNode filteringSource = lookup.resolve(((SemiJoinNode) node).getFilteringSource()); + Optional result = rewriteUpstream(filteringSource, lookup); + if (result.isPresent()) { + return Optional.of(node.replaceChildren(ImmutableList.of(lookup.resolve(((SemiJoinNode) node).getSource()), result.get()))); + } + } + return Optional.empty(); + } + + private Optional rewriteUpstream(PlanNode node, Lookup lookup) + { + if (node instanceof AggregationNode) { + AggregationNode aggregationNode = (AggregationNode) node; + if (isDistinctNode(aggregationNode)) { + return Optional.of(lookup.resolve(aggregationNode.getSource())); + } + } + else if (node instanceof ProjectNode) { + Optional result = rewriteUpstream(lookup.resolve(((ProjectNode) node).getSource()), lookup); + if (result.isPresent()) { + return Optional.of(node.replaceChildren(ImmutableList.of(result.get()))); + } + } + return Optional.empty(); + } + + private boolean isDistinctNode(AggregationNode aggregationNode) + { + List outputSymbols = aggregationNode.getOutputSymbols(); + List groupingKeys = aggregationNode.getGroupingKeys(); + return aggregationNode.getStep() == SINGLE + && outputSymbols.size() == 1 + && groupingKeys.size() == 1 + && outputSymbols.get(0).equals(groupingKeys.get(0)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java index 42b10921922c..d89807b5a956 100644 --- a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java @@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId); if (catalogMetadata == null) { Catalog catalog = catalogsByConnectorId.get(connectorId); + if (catalog == null) { + // For rule tester, getConnectorId has never been called because the plan was generated outside this transaction. + getConnectorId(connectorId.getCatalogName()); + catalog = catalogsByConnectorId.get(connectorId); + } + verify(catalog != null, "Unknown connectorId: %s", connectorId); ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 94e7fae7e0c7..ac318b7d8855 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -38,6 +38,7 @@ import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.apply; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedTableScan; @@ -321,4 +322,23 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() project(ImmutableMap.of("NON_NULL", expression("true")), node(ValuesNode.class))))))))))); } + + @Test + public void testRemoveDistinctFromSemiJoin() + { + assertPlan( + "SELECT orderkey FROM orders " + + "WHERE custkey " + + "IN (SELECT distinct custkey FROM customer)", + anyTree( + semiJoin("Source", "Filter", "Output", + anyTree(tableScan("orders", ImmutableMap.of("Source", "custkey"))), + anyNot(AggregationNode.class, + anyNot(AggregationNode.class, + tableScan("customer", ImmutableMap.of("Filter", "custkey")) + )) + ) + ) + ); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index d6af5bf0f880..f477adaa7de2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule.test; +import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; @@ -42,6 +43,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -75,12 +77,15 @@ public class PlanBuilder { private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; + private final Session session; + private final Map symbols = new HashMap<>(); - public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) + public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { this.idAllocator = idAllocator; this.metadata = metadata; + this.session = session; } public ValuesNode values(Symbol... columns) @@ -199,14 +204,31 @@ public ApplyNode apply(Assignments subqueryAssignments, List correlation return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation); } + public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol, Symbol semiJoinOutput) + { + return new SemiJoinNode(idAllocator.getNextId(), + source, + filteringSource, + sourceJoinSymbol, + filteringSourceJoinSymbol, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + public TableScanNode tableScan(List symbols, Map assignments) + { + TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); + return tableScan(tableHandle, symbols, assignments); + } + + public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) { Expression originalConstraint = null; return new TableScanNode( idAllocator.getNextId(), - new TableHandle( - new ConnectorId("testConnector"), - new TestingTableHandle()), + tableHandle, symbols, assignments, Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 6b5890d70e9e..8d10877eeb0a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -83,7 +83,7 @@ public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); - PlanBuilder builder = new PlanBuilder(idAllocator, metadata); + PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session); plan = planProvider.apply(builder); symbols = builder.getSymbols(); return this; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java new file mode 100644 index 000000000000..86a34baad1e5 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; + +public class TestRemoveDistinctFromSemiJoin +{ + private final RuleTester tester = new RuleTester(); + + @Test + public void test() + throws Exception + { + tester.assertThat(new RemoveDistinctFromSemiJoin()) + .on(p -> { + Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT); + Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); + Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); + return p.semiJoin( + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), + p.project(Assignments.of(filteringSourceKey, expression("x")), + p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) + .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) + .source( + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) + .build()) + ), + sourceKey, filteringSourceKey, outputKey + ); + }) + .matches( + semiJoin("Source", "Filter", "Output", + tableScan("orders", ImmutableMap.of("Source", "custkey")), + project(tableScan("customer", ImmutableMap.of("Filter", "custkey"))) + ) + ); + } + + @Test + public void testDoesNotFire() + { + tester.assertThat(new RemoveDistinctFromSemiJoin()) + .on(p -> { + Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT); + Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); + Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); + return p.semiJoin( + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), + p.project(Assignments.of(filteringSourceKey, expression("x")), + p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) + .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) + .addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT)) + .source(p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) + .build()) + ), + sourceKey, filteringSourceKey, outputKey + ); + }).doesNotFire(); + } +}