diff --git a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java index 42b10921922c..d89807b5a956 100644 --- a/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/transaction/TransactionManager.java @@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId); if (catalogMetadata == null) { Catalog catalog = catalogsByConnectorId.get(connectorId); + if (catalog == null) { + // For rule tester, getConnectorId has never been called because the plan was generated outside this transaction. + getConnectorId(connectorId.getCatalogName()); + catalog = catalogsByConnectorId.get(connectorId); + } + verify(catalog != null, "Unknown connectorId: %s", connectorId); ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index e7fc3fb5d344..f477adaa7de2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.QualifiedObjectName; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; @@ -66,13 +65,11 @@ import java.util.Map; import java.util.Optional; import java.util.function.Consumer; -import java.util.stream.Collectors; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; @@ -222,35 +219,22 @@ public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol s public TableScanNode tableScan(List symbols, Map assignments) { - Expression originalConstraint = null; - return new TableScanNode(idAllocator.getNextId(), - new TableHandle( - new ConnectorId("testConnector"), - new TestingTableHandle()), - symbols, - assignments, - Optional.empty(), - TupleDomain.all(), - originalConstraint); + TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); + return tableScan(tableHandle, symbols, assignments); } - public TableScanNode tableScan(String tableName, Map symbolMap) + public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) { - QualifiedObjectName name = new QualifiedObjectName(session.getCatalog().get(), session.getSchema().get(), tableName); - Optional tableHandle = metadata.getTableHandle(session, name); - verify(tableHandle.isPresent(), "Unknown table %s", name); - Map columns = metadata.getColumnHandles(session, tableHandle.get()); - Map assignments = symbolMap.entrySet().stream() - .filter(entry -> columns.containsKey(entry.getValue())) - .collect(Collectors.toMap(entry -> new Symbol(entry.getKey()), entry -> columns.get(entry.getValue()))); - List symbols = ImmutableList.copyOf(assignments.keySet()); - return new TableScanNode(idAllocator.getNextId(), - tableHandle.get(), + Expression originalConstraint = null; + return new TableScanNode( + idAllocator.getNextId(), + tableHandle, symbols, assignments, Optional.empty(), TupleDomain.all(), - null); + originalConstraint + ); } public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 640855c30c5e..8d10877eeb0a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -30,7 +30,6 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; -import com.facebook.presto.transaction.TransactionId; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -84,8 +83,6 @@ public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); - TransactionId transactionId = transactionManager.beginTransaction(TransactionManager.DEFAULT_ISOLATION, false, false); - this.session = session.beginTransactionId(transactionId, transactionManager, accessControl); PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session); plan = planProvider.apply(builder); symbols = builder.getSymbols(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java index 93467c47b67f..86a34baad1e5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveDistinctFromSemiJoin.java @@ -13,11 +13,15 @@ */ package com.facebook.presto.sql.planner.iterative.rule.test; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -27,6 +31,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; public class TestRemoveDistinctFromSemiJoin { @@ -42,11 +47,23 @@ public void test() Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); return p.semiJoin( - p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), p.project(Assignments.of(filteringSourceKey, expression("x")), p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) - .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .source( + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) .build()) ), sourceKey, filteringSourceKey, outputKey @@ -69,12 +86,23 @@ public void testDoesNotFire() Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT); Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT); return p.semiJoin( - p.tableScan("orders", ImmutableMap.of("custkey", "custkey")), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(sourceKey), + ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))), p.project(Assignments.of(filteringSourceKey, expression("x")), p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE) .groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey))) .addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT)) - .source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey"))) + .source(p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)), + ImmutableList.of(filteringSourceKey), + ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT))) + ) .build()) ), sourceKey, filteringSourceKey, outputKey