Skip to content

Commit

Permalink
Update to use ruleAssert in transaction like #7979
Browse files Browse the repository at this point in the history
  • Loading branch information
hellium01 committed May 19, 2017
1 parent 1d52497 commit 28fb6ba
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,12 @@ private synchronized CatalogMetadata getTransactionCatalogMetadata(ConnectorId c
CatalogMetadata catalogMetadata = this.catalogMetadata.get(connectorId);
if (catalogMetadata == null) {
Catalog catalog = catalogsByConnectorId.get(connectorId);
if (catalog == null) {
// For rule tester, getConnectorId has never been called because the plan was generated outside this transaction.
getConnectorId(connectorId.getCatalogName());
catalog = catalogsByConnectorId.get(connectorId);
}

verify(catalog != null, "Unknown connectorId: %s", connectorId);

ConnectorTransactionMetadata metadata = createConnectorTransactionMetadata(catalog.getConnectorId(), catalog);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.facebook.presto.Session;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.QualifiedObjectName;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.ColumnHandle;
Expand Down Expand Up @@ -66,13 +65,11 @@
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;

Expand Down Expand Up @@ -222,35 +219,22 @@ public SemiJoinNode semiJoin(PlanNode source, PlanNode filteringSource, Symbol s

public TableScanNode tableScan(List<Symbol> symbols, Map<Symbol, ColumnHandle> assignments)
{
Expression originalConstraint = null;
return new TableScanNode(idAllocator.getNextId(),
new TableHandle(
new ConnectorId("testConnector"),
new TestingTableHandle()),
symbols,
assignments,
Optional.empty(),
TupleDomain.all(),
originalConstraint);
TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle());
return tableScan(tableHandle, symbols, assignments);
}

public TableScanNode tableScan(String tableName, Map<String, String> symbolMap)
public TableScanNode tableScan(TableHandle tableHandle, List<Symbol> symbols, Map<Symbol, ColumnHandle> assignments)
{
QualifiedObjectName name = new QualifiedObjectName(session.getCatalog().get(), session.getSchema().get(), tableName);
Optional<TableHandle> tableHandle = metadata.getTableHandle(session, name);
verify(tableHandle.isPresent(), "Unknown table %s", name);
Map<String, ColumnHandle> columns = metadata.getColumnHandles(session, tableHandle.get());
Map<Symbol, ColumnHandle> assignments = symbolMap.entrySet().stream()
.filter(entry -> columns.containsKey(entry.getValue()))
.collect(Collectors.toMap(entry -> new Symbol(entry.getKey()), entry -> columns.get(entry.getValue())));
List<Symbol> symbols = ImmutableList.copyOf(assignments.keySet());
return new TableScanNode(idAllocator.getNextId(),
tableHandle.get(),
Expression originalConstraint = null;
return new TableScanNode(
idAllocator.getNextId(),
tableHandle,
symbols,
assignments,
Optional.empty(),
TupleDomain.all(),
null);
originalConstraint
);
}

public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
import com.facebook.presto.transaction.TransactionId;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.collect.ImmutableSet;

Expand Down Expand Up @@ -84,8 +83,6 @@ public RuleAssert on(Function<PlanBuilder, PlanNode> planProvider)
{
checkArgument(plan == null, "plan has already been set");

TransactionId transactionId = transactionManager.beginTransaction(TransactionManager.DEFAULT_ISOLATION, false, false);
this.session = session.beginTransactionId(transactionId, transactionManager, accessControl);
PlanBuilder builder = new PlanBuilder(idAllocator, metadata, session);
plan = planProvider.apply(builder);
symbols = builder.getSymbols();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
*/
package com.facebook.presto.sql.planner.iterative.rule.test;

import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.rule.RemoveDistinctFromSemiJoin;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.tpch.TpchColumnHandle;
import com.facebook.presto.tpch.TpchTableHandle;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
Expand All @@ -27,6 +31,7 @@
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR;

public class TestRemoveDistinctFromSemiJoin
{
Expand All @@ -42,11 +47,23 @@ public void test()
Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
return p.semiJoin(
p.tableScan("orders", ImmutableMap.of("custkey", "custkey")),
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
ImmutableList.of(sourceKey),
ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
p.project(Assignments.of(filteringSourceKey, expression("x")),
p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
.groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
.source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey")))
.source(
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
ImmutableList.of(filteringSourceKey),
ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
)
.build())
),
sourceKey, filteringSourceKey, outputKey
Expand All @@ -69,12 +86,23 @@ public void testDoesNotFire()
Symbol filteringSourceKey = p.symbol("custkey_1", BigintType.BIGINT);
Symbol outputKey = p.symbol("orderkey", BigintType.BIGINT);
return p.semiJoin(
p.tableScan("orders", ImmutableMap.of("custkey", "custkey")),
p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)),
ImmutableList.of(sourceKey),
ImmutableMap.of(sourceKey, new TpchColumnHandle("custkey", BIGINT))),
p.project(Assignments.of(filteringSourceKey, expression("x")),
p.aggregation(ab -> ab.step(AggregationNode.Step.SINGLE)
.groupingSets(ImmutableList.of(ImmutableList.of(filteringSourceKey)))
.addAggregation(p.symbol("max", BigintType.BIGINT), expression("max(custkey_1)"), ImmutableList.of(BIGINT))
.source(p.tableScan("customer", ImmutableMap.of("custkey_1", "custkey")))
.source(p.tableScan(
new TableHandle(
new ConnectorId("local"),
new TpchTableHandle("local", "customer", TINY_SCALE_FACTOR)),
ImmutableList.of(filteringSourceKey),
ImmutableMap.of(filteringSourceKey, new TpchColumnHandle("custkey", BIGINT)))
)
.build())
),
sourceKey, filteringSourceKey, outputKey
Expand Down

0 comments on commit 28fb6ba

Please sign in to comment.