Skip to content

Commit

Permalink
Remove unnecessary expression translations
Browse files Browse the repository at this point in the history
In order to translate expression to row expressions, the
code was first replacing all symbol references with field
references for the corresponding ordinal inputs.

This is unnecessary, as the translation can be done on demand
as the expression is translated to a row expression.
  • Loading branch information
martint committed Mar 9, 2019
1 parent 64f93a9 commit 313d052
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.DiscreteDomain.integers;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Iterables.getOnlyElement;
Expand Down Expand Up @@ -252,7 +251,6 @@
import static io.prestosql.spi.type.TypeUtils.writeNativeValue;
import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider;
import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression;
import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
Expand All @@ -279,7 +277,6 @@
import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN;
import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons;
import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.IntStream.range;
Expand Down Expand Up @@ -1170,15 +1167,13 @@ private PhysicalOperation visitScanFilterAndProject(
// if source is a table scan we fold it directly into the filter and project
// otherwise we plan it as a normal operator
Map<Symbol, Integer> sourceLayout;
Map<Integer, Type> sourceTypes;
List<ColumnHandle> columns = null;
PhysicalOperation source = null;
if (sourceNode instanceof TableScanNode) {
TableScanNode tableScanNode = (TableScanNode) sourceNode;

// extract the column handles and channel to type mapping
sourceLayout = new LinkedHashMap<>();
sourceTypes = new LinkedHashMap<>();
columns = new ArrayList<>();
int channel = 0;
for (Symbol symbol : tableScanNode.getOutputSymbols()) {
Expand All @@ -1187,9 +1182,6 @@ private PhysicalOperation visitScanFilterAndProject(
Integer input = channel;
sourceLayout.put(symbol, input);

Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol));
sourceTypes.put(input, type);

channel++;
}
}
Expand All @@ -1209,7 +1201,6 @@ else if (sourceNode instanceof SampleNode) {
// plan source
source = sourceNode.accept(this, context);
sourceLayout = source.getLayout();
sourceTypes = getInputTypes(source.getLayout(), source.getTypes());
}

// build output mapping
Expand All @@ -1220,27 +1211,24 @@ else if (sourceNode instanceof SampleNode) {
}
Map<Symbol, Integer> outputMappings = outputMappingsBuilder.build();

// compiler uses inputs instead of symbols, so rewrite the expressions first
SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout);
Optional<Expression> rewrittenFilter = filterExpression.map(symbolToInputRewriter::rewrite);

List<Expression> rewrittenProjections = new ArrayList<>();
List<Expression> projections = new ArrayList<>();
for (Symbol symbol : outputSymbols) {
rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol)));
projections.add(assignments.get(symbol));
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypesFromInput(
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
context.getSession(),
metadata,
sqlParser,
sourceTypes,
concat(rewrittenFilter.map(ImmutableList::of).orElse(ImmutableList.of()), rewrittenProjections),
context.getTypes(),
concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions()),
emptyList(),
NOOP);
NOOP,
false);

Optional<RowExpression> translatedFilter = rewrittenFilter.map(filter -> toRowExpression(filter, expressionTypes));
List<RowExpression> translatedProjections = rewrittenProjections.stream()
.map(expression -> toRowExpression(expression, expressionTypes))
Optional<RowExpression> translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout));
List<RowExpression> translatedProjections = projections.stream()
.map(expression -> toRowExpression(expression, expressionTypes, sourceLayout))
.collect(toImmutableList());

try {
Expand All @@ -1256,7 +1244,7 @@ else if (sourceNode instanceof SampleNode) {
cursorProcessor,
pageProcessor,
columns,
getTypes(rewrittenProjections, expressionTypes),
getTypes(projections, expressionTypes),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

Expand All @@ -1269,7 +1257,7 @@ else if (sourceNode instanceof SampleNode) {
context.getNextOperatorId(),
planNodeId,
pageProcessor,
getTypes(rewrittenProjections, expressionTypes),
getTypes(projections, expressionTypes),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));

Expand All @@ -1281,19 +1269,9 @@ else if (sourceNode instanceof SampleNode) {
}
}

private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types)
{
return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true);
}

private Map<Integer, Type> getInputTypes(Map<Symbol, Integer> layout, List<Type> types)
private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout)
{
ImmutableMap.Builder<Integer, Type> inputTypes = ImmutableMap.builder();
for (Integer input : ImmutableSet.copyOf(layout.values())) {
Type type = types.get(input);
inputTypes.put(input, type);
}
return inputTypes.build();
return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true, layout);
}

@Override
Expand Down Expand Up @@ -2058,20 +2036,17 @@ private JoinFilterFunctionFactory compileJoinFilterFunction(
{
Map<Symbol, Integer> joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout);

Map<Integer, Type> sourceTypes = joinSourcesLayout.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey())));

Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression);
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypesFromInput(
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
session,
metadata,
sqlParser,
sourceTypes,
rewrittenFilter,
types,
filterExpression,
emptyList(), /* parameters have already been replaced */
NOOP);
NOOP,
false);

RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes);
RowExpression translatedFilter = toRowExpression(filterExpression, expressionTypes, joinSourcesLayout);
return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size());
}

Expand Down Expand Up @@ -2603,7 +2578,7 @@ private AccumulatorFactory buildAccumulatorFactory(
NOOP))
.build();

LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes);
LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of());
Class<? extends LambdaProvider> lambdaProviderClass = compileLambdaProvider(lambda, metadata.getFunctionRegistry(), lambdaInterfaces.get(i));
try {
lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.prestosql.spi.type.Type;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolToInputRewriter;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.relational.CallExpression;
import io.prestosql.sql.relational.ConstantExpression;
Expand Down Expand Up @@ -58,7 +57,7 @@
import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL;
import static io.prestosql.spi.function.OperatorType.NOT_EQUAL;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate;
import static java.lang.Integer.min;
import static java.util.Collections.emptyList;
Expand All @@ -80,39 +79,34 @@ public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser)
public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types)
{
Map<Symbol, Integer> symbolInput = new HashMap<>();
Map<Integer, Type> inputTypes = new HashMap<>();
int inputId = 0;
for (Entry<Symbol, Type> entry : types.allTypes().entrySet()) {
symbolInput.put(entry.getKey(), inputId);
inputTypes.put(inputId, entry.getValue());
inputId++;
}
RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, inputTypes);
RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, inputTypes);
RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types);
RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types);

RowExpression canonicalizedLeft = leftRowExpression.accept(CANONICALIZATION_VISITOR, null);
RowExpression canonicalizedRight = rightRowExpression.accept(CANONICALIZATION_VISITOR, null);

return canonicalizedLeft.equals(canonicalizedRight);
}

private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInput, Map<Integer, Type> inputTypes)
private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInput, TypeProvider types)
{
// replace qualified names with input references since row expressions do not support these
Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression);

// determine the type of every expression
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypesFromInput(
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
session,
metadata,
sqlParser,
inputTypes,
expressionWithInputReferences,
types,
expression,
emptyList(), /* parameters have already been replaced */
WarningCollector.NOOP);

// convert to row expression
return translate(expressionWithInputReferences, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false);
return translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, symbolInput);
}

private static class CanonicalizationVisitor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.prestosql.spi.type.TypeManager;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.relational.optimizer.ExpressionOptimizer;
import io.prestosql.sql.tree.ArithmeticBinaryExpression;
import io.prestosql.sql.tree.ArithmeticUnaryExpression;
Expand Down Expand Up @@ -144,12 +145,14 @@ public static RowExpression translate(
FunctionRegistry functionRegistry,
TypeManager typeManager,
Session session,
boolean optimize)
boolean optimize,
Map<Symbol, Integer> layout)
{
Visitor visitor = new Visitor(
functionKind,
types,
typeManager,
layout,
session.getTimeZoneKey(),
isLegacyRowFieldOrdinalAccessEnabled(session),
SystemSessionProperties.isLegacyTimestamp(session));
Expand All @@ -171,6 +174,7 @@ private static class Visitor
private final FunctionKind functionKind;
private final Map<NodeRef<Expression>, Type> types;
private final TypeManager typeManager;
private final Map<Symbol, Integer> layout;
private final TimeZoneKey timeZoneKey;
private final boolean legacyRowFieldOrdinalAccess;
@Deprecated
Expand All @@ -180,13 +184,15 @@ private Visitor(
FunctionKind functionKind,
Map<NodeRef<Expression>, Type> types,
TypeManager typeManager,
Map<Symbol, Integer> layout,
TimeZoneKey timeZoneKey,
boolean legacyRowFieldOrdinalAccess,
boolean isLegacyTimestamp)
{
this.functionKind = functionKind;
this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null"));
this.typeManager = typeManager;
this.layout = layout;
this.timeZoneKey = timeZoneKey;
this.legacyRowFieldOrdinalAccess = legacyRowFieldOrdinalAccess;
this.isLegacyTimestamp = isLegacyTimestamp;
Expand Down Expand Up @@ -363,6 +369,11 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context)
@Override
protected RowExpression visitSymbolReference(SymbolReference node, Void context)
{
Integer field = layout.get(Symbol.from(node));
if (field != null) {
return field(field, getType(node));
}

return new VariableReferenceExpression(node.getName(), getType(node));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.prestosql.sql.gen.PageFunctionCompiler;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolToInputRewriter;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.relational.RowExpression;
Expand Down Expand Up @@ -81,7 +80,7 @@
import static io.prestosql.operator.scalar.FunctionAssertions.createExpression;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.testing.TestingSplit.createLocalSplit;
import static java.util.Collections.emptyList;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -203,10 +202,10 @@ private List<Page> createInputPages(List<Type> types)
private RowExpression getFilter(Type type)
{
if (type == VARCHAR) {
return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR);
return rowExpression("cast(varchar0 as bigint) % 2 = 0");
}
if (type == BIGINT) {
return rowExpression("bigint0 % 2 = 0", BIGINT);
return rowExpression("bigint0 % 2 = 0");
}
throw new IllegalArgumentException("filter not supported for type : " + type);
}
Expand All @@ -216,32 +215,25 @@ private List<RowExpression> getProjections(Type type)
ImmutableList.Builder<RowExpression> builder = ImmutableList.builder();
if (type == BIGINT) {
for (int i = 0; i < columnCount; i++) {
builder.add(rowExpression("bigint" + i + " + 5", type));
builder.add(rowExpression("bigint" + i + " + 5"));
}
}
else if (type == VARCHAR) {
for (int i = 0; i < columnCount; i++) {
// alternatively use identity expression rowExpression("varchar" + i, type) or
// rowExpression("substr(varchar" + i + ", 1, 1)", type)
builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type));
builder.add(rowExpression("concat(varchar" + i + ", 'foo')"));
}
}
return builder.build();
}

private RowExpression rowExpression(String expression, Type type)
private RowExpression rowExpression(String value)
{
SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout);
Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes)));
Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes));

ImmutableMap.Builder<Integer, Type> builder = ImmutableMap.builder();
for (int i = 0; i < columnCount; i++) {
builder.put(i, type);
}
Map<Integer, Type> types = builder.build();

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP);
return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true);
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP);
return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout);
}

private static Page createPage(List<? extends Type> types, int positions, boolean dictionary)
Expand Down
Loading

0 comments on commit 313d052

Please sign in to comment.