From 1d524975a763277cdad9d2659c0c65db7434712d Mon Sep 17 00:00:00 2001 From: Yi He Date: Tue, 11 Apr 2017 18:33:20 -0700 Subject: [PATCH 1/2] Optimize distinct from semijoin --- .../presto/sql/planner/PlanOptimizers.java | 2 + .../rule/RemoveDistinctFromSemiJoin.java | 85 +++++++++++++++++++ .../sql/planner/TestLogicalPlanner.java | 20 +++++ .../iterative/rule/test/PlanBuilder.java | 48 +++++++++-- .../iterative/rule/test/RuleAssert.java | 5 +- .../test/TestRemoveDistinctFromSemiJoin.java | 84 ++++++++++++++++++ 6 files changed, 238 insertions(+), 6 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 0dc55c07e592..ba443b323087 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -41,6 +41,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; +import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; @@ -208,6 +209,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea // add UnaliasSymbolReferences when it's ported new RemoveRedundantIdentityProjections(), new SwapAdjacentWindowsBySpecifications(), + new RemoveDistinctFromSemiJoin(), new MergeAdjacentWindows())), inlineProjections, new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java new file mode 100644 index 000000000000..86d6ad316dbb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveDistinctFromSemiJoin.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; + +/** + * Remove distinct node from following path: + *
+ *     SemiJoinNode
+ *        - Source
+ *        - FilteringSource
+ *          - ProjectNode
+ *              - AggregationNode
+ * 
+ */ +public class RemoveDistinctFromSemiJoin + implements Rule +{ + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (node instanceof SemiJoinNode) { + PlanNode filteringSource = lookup.resolve(((SemiJoinNode) node).getFilteringSource()); + Optional result = rewriteUpstream(filteringSource, lookup); + if (result.isPresent()) { + return Optional.of(node.replaceChildren(ImmutableList.of(lookup.resolve(((SemiJoinNode) node).getSource()), result.get()))); + } + } + return Optional.empty(); + } + + private Optional rewriteUpstream(PlanNode node, Lookup lookup) + { + if (node instanceof AggregationNode) { + AggregationNode aggregationNode = (AggregationNode) node; + if (isDistinctNode(aggregationNode)) { + return Optional.of(lookup.resolve(aggregationNode.getSource())); + } + } + else if (node instanceof ProjectNode) { + Optional result = rewriteUpstream(lookup.resolve(((ProjectNode) node).getSource()), lookup); + if (result.isPresent()) { + return Optional.of(node.replaceChildren(ImmutableList.of(result.get()))); + } + } + return Optional.empty(); + } + + private boolean isDistinctNode(AggregationNode aggregationNode) + { + List outputSymbols = aggregationNode.getOutputSymbols(); + List groupingKeys = aggregationNode.getGroupingKeys(); + return aggregationNode.getStep() == SINGLE + && outputSymbols.size() == 1 + && groupingKeys.size() == 1 + && outputSymbols.get(0).equals(groupingKeys.get(0)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 94e7fae7e0c7..ac318b7d8855 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -38,6 +38,7 @@ import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.apply; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedTableScan; @@ -321,4 +322,23 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() project(ImmutableMap.of("NON_NULL", expression("true")), node(ValuesNode.class))))))))))); } + + @Test + public void testRemoveDistinctFromSemiJoin() + { + assertPlan( + "SELECT orderkey FROM orders " + + "WHERE custkey " + + "IN (SELECT distinct custkey FROM customer)", + anyTree( + semiJoin("Source", "Filter", "Output", + anyTree(tableScan("orders", ImmutableMap.of("Source", "custkey"))), + anyNot(AggregationNode.class, + anyNot(AggregationNode.class, + tableScan("customer", ImmutableMap.of("Filter", "custkey")) + )) + ) + ) + ); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index d6af5bf0f880..e7fc3fb5d344 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule.test; +import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.QualifiedObjectName; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; @@ -42,6 +44,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -63,11 +66,13 @@ import java.util.Map; import java.util.Optional; import java.util.function.Consumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; @@ -75,12 +80,15 @@ public class PlanBuilder { private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; + private final Session session; + private final Map symbols = new HashMap<>(); - public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) + public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { this.idAllocator = idAllocator; this.metadata = metadata; + this.session = session; } public ValuesNode values(Symbol... columns) @@ -199,11 +207,23 @@ public ApplyNode apply(Assignments subqueryAssignments, List correlation return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation); } + public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol, Symbol semiJoinOutput) + { + return new SemiJoinNode(idAllocator.getNextId(), + source, + filteringSource, + sourceJoinSymbol, + filteringSourceJoinSymbol, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + public TableScanNode tableScan(List symbols, Map assignments) { Expression originalConstraint = null; - return new TableScanNode( - idAllocator.getNextId(), + return new TableScanNode(idAllocator.getNextId(), new TableHandle( new ConnectorId("testConnector"), new TestingTableHandle()), @@ -211,8 +231,26 @@ public TableScanNode tableScan(List symbols, Map a assignments, Optional.empty(), TupleDomain.all(), - originalConstraint - ); + originalConstraint); + } + + public TableScanNode tableScan(String tableName, Map symbolMap) + { + QualifiedObjectName name = new QualifiedObjectName(session.getCatalog().get(), session.getSchema().get(), tableName); + Optional tableHandle = metadata.getTableHandle(session, name); + verify(tableHandle.isPresent(), "Unknown table %s", name); + Map columns = metadata.getColumnHandles(session, tableHandle.get()); + Map assignments = symbolMap.entrySet().stream() + .filter(entry -> columns.containsKey(entry.getValue())) + .collect(Collectors.toMap(entry -> new Symbol(entry.getKey()), entry -> columns.get(entry.getValue()))); + List symbols = ImmutableList.copyOf(assignments.keySet()); + return new TableScanNode(idAllocator.getNextId(), + tableHandle.get(), + symbols, + assignments, + Optional.empty(), + TupleDomain.all(), + null); } public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 6b5890d70e9e..640855c30c5e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; +import com.facebook.presto.transaction.TransactionId; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -83,7 +84,9 @@ public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); - PlanBuilder builder = new PlanBuilder(idAllocator, metadata); + TransactionId transactionId = transactionManager.beginTransaction(TransactionManager.DEFAULT_ISOLATION, false, false); + this.session = session.beginTransactionId(transactionId, transactionManager, accessControl); + PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session); plan = planProvider.apply(builder); symbols = builder.getSymbols(); return this; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java new file mode 100644 index 000000000000..93467c47b67f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; + +public class TestRemoveDistinctFromSemiJoin +{ + private final RuleTester tester = new RuleTester(); + + @Test + public void test() + throws Exception + { + tester.assertThat(new RemoveDistinctFromSemiJoin()) + .on(p -> { + Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT); + Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); + Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); + return p.semiJoin( + p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.project(Assignments.of(filteringSourceKey, expression("x")), + p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) + .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) + .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .build()) + ), + sourceKey, filteringSourceKey, outputKey + ); + }) + .matches( + semiJoin("Source", "Filter", "Output", + tableScan("orders", ImmutableMap.of("Source", "custkey")), + project(tableScan("customer", ImmutableMap.of("Filter", "custkey"))) + ) + ); + } + + @Test + public void testDoesNotFire() + { + tester.assertThat(new RemoveDistinctFromSemiJoin()) + .on(p -> { + Symbol sourceKey = p.symbol("custkey", BigintType.BIGINT); + Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); + Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); + return p.semiJoin( + p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.project(Assignments.of(filteringSourceKey, expression("x")), + p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) + .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) + .addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT)) + .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .build()) + ), + sourceKey, filteringSourceKey, outputKey + ); + }).doesNotFire(); + } +} From 28fb6ba76a8c3b316956cf7bba4272f8e47ec569 Mon Sep 17 00:00:00 2001 From: Yi He Date: Fri, 19 May 2017 15:50:40 -0700 Subject: [PATCH 2/2] Update to use ruleAssert in transaction like #7979 --- .../transaction/TransactionManager.java | 6 ++++ .../iterative/rule/test/PlanBuilder.java | 34 +++++------------- .../iterative/rule/test/RuleAssert.java | 3 -- .../test/TestRemoveDistinctFromSemiJoin.java | 36 ++++++++++++++++--- 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java index 42b10921922c..d89807b5a956 100644 --- a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java @@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId); if (catalogMetadata == null) { Catalog catalog = catalogsByConnectorId.get(connectorId); + if (catalog == null) { + // For rule tester, getConnectorId has never been called because the plan was generated outside this transaction. + getConnectorId(connectorId.getCatalogName()); + catalog = catalogsByConnectorId.get(connectorId); + } + verify(catalog != null, "Unknown connectorId: %s", connectorId); ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index e7fc3fb5d344..f477adaa7de2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.QualifiedObjectName; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; @@ -66,13 +65,11 @@ import java.util.Map; import java.util.Optional; import java.util.function.Consumer; -import java.util.stream.Collectors; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; @@ -222,35 +219,22 @@ public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol s public TableScanNode tableScan(List symbols, Map assignments) { - Expression originalConstraint = null; - return new TableScanNode(idAllocator.getNextId(), - new TableHandle( - new ConnectorId("testConnector"), - new TestingTableHandle()), - symbols, - assignments, - Optional.empty(), - TupleDomain.all(), - originalConstraint); + TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); + return tableScan(tableHandle, symbols, assignments); } - public TableScanNode tableScan(String tableName, Map symbolMap) + public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) { - QualifiedObjectName name = new QualifiedObjectName(session.getCatalog().get(), session.getSchema().get(), tableName); - Optional tableHandle = metadata.getTableHandle(session, name); - verify(tableHandle.isPresent(), "Unknown table %s", name); - Map columns = metadata.getColumnHandles(session, tableHandle.get()); - Map assignments = symbolMap.entrySet().stream() - .filter(entry -> columns.containsKey(entry.getValue())) - .collect(Collectors.toMap(entry -> new Symbol(entry.getKey()), entry -> columns.get(entry.getValue()))); - List symbols = ImmutableList.copyOf(assignments.keySet()); - return new TableScanNode(idAllocator.getNextId(), - tableHandle.get(), + Expression originalConstraint = null; + return new TableScanNode( + idAllocator.getNextId(), + tableHandle, symbols, assignments, Optional.empty(), TupleDomain.all(), - null); + originalConstraint + ); } public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 640855c30c5e..8d10877eeb0a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -30,7 +30,6 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; -import com.facebook.presto.transaction.TransactionId; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -84,8 +83,6 @@ public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); - TransactionId transactionId = transactionManager.beginTransaction(TransactionManager.DEFAULT_ISOLATION, false, false); - this.session = session.beginTransactionId(transactionId, transactionManager, accessControl); PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session); plan = planProvider.apply(builder); symbols = builder.getSymbols(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java index 93467c47b67f..86a34baad1e5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java @@ -13,11 +13,15 @@ */ package com.facebook.presto.sql.planner.iterative.rule.test; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -27,6 +31,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; public class TestRemoveDistinctFromSemiJoin { @@ -42,11 +47,23 @@ public void test() Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); return p.semiJoin( - p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), p.project(Assignments.of(filteringSourceKey, expression("x")), p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) - .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .source( + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) .build()) ), sourceKey, filteringSourceKey, outputKey @@ -69,12 +86,23 @@ public void testDoesNotFire() Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); return p.semiJoin( - p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), p.project(Assignments.of(filteringSourceKey, expression("x")), p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) .addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT)) - .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .source(p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) .build()) ), sourceKey, filteringSourceKey, outputKey