diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index db5b3d1ef9a1b..fca17d4ab0fa5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -217,7 +217,6 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; @@ -252,7 +251,6 @@ import static io.prestosql.spi.type.TypeUtils.writeNativeValue; import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -279,7 +277,6 @@ import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1170,7 +1167,6 @@ private PhysicalOperation visitScanFilterAndProject( // if source is a table scan we fold it directly into the filter and project // otherwise we plan it as a normal operator Map sourceLayout; - Map sourceTypes; List columns = null; PhysicalOperation source = null; if (sourceNode instanceof TableScanNode) { @@ -1178,7 +1174,6 @@ private PhysicalOperation visitScanFilterAndProject( // extract the column handles and channel to type mapping sourceLayout = new LinkedHashMap<>(); - sourceTypes = new LinkedHashMap<>(); columns = new ArrayList<>(); int channel = 0; for (Symbol symbol : tableScanNode.getOutputSymbols()) { @@ -1187,9 +1182,6 @@ private PhysicalOperation visitScanFilterAndProject( Integer input = channel; sourceLayout.put(symbol, input); - Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); - sourceTypes.put(input, type); - channel++; } } @@ -1209,7 +1201,6 @@ else if (sourceNode instanceof SampleNode) { // plan source source = sourceNode.accept(this, context); sourceLayout = source.getLayout(); - sourceTypes = getInputTypes(source.getLayout(), source.getTypes()); } // build output mapping @@ -1220,27 +1211,24 @@ else if (sourceNode instanceof SampleNode) { } Map outputMappings = outputMappingsBuilder.build(); - // compiler uses inputs instead of symbols, so rewrite the expressions first - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Optional rewrittenFilter = filterExpression.map(symbolToInputRewriter::rewrite); - - List rewrittenProjections = new ArrayList<>(); + List projections = new ArrayList<>(); for (Symbol symbol : outputSymbols) { - rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol))); + projections.add(assignments.get(symbol)); } - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( context.getSession(), metadata, sqlParser, - sourceTypes, - concat(rewrittenFilter.map(ImmutableList::of).orElse(ImmutableList.of()), rewrittenProjections), + context.getTypes(), + concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions()), emptyList(), - NOOP); + NOOP, + false); - Optional translatedFilter = rewrittenFilter.map(filter -> toRowExpression(filter, expressionTypes)); - List translatedProjections = rewrittenProjections.stream() - .map(expression -> toRowExpression(expression, expressionTypes)) + Optional translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout)); + List translatedProjections = projections.stream() + .map(expression -> toRowExpression(expression, expressionTypes, sourceLayout)) .collect(toImmutableList()); try { @@ -1256,7 +1244,7 @@ else if (sourceNode instanceof SampleNode) { cursorProcessor, pageProcessor, columns, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1269,7 +1257,7 @@ else if (sourceNode instanceof SampleNode) { context.getNextOperatorId(), planNodeId, pageProcessor, - getTypes(rewrittenProjections, expressionTypes), + getTypes(projections, expressionTypes), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -1281,19 +1269,9 @@ else if (sourceNode instanceof SampleNode) { } } - private RowExpression toRowExpression(Expression expression, Map, Type> types) - { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); - } - - private Map getInputTypes(Map layout, List types) + private RowExpression toRowExpression(Expression expression, Map, Type> types, Map layout) { - ImmutableMap.Builder inputTypes = ImmutableMap.builder(); - for (Integer input : ImmutableSet.copyOf(layout.values())) { - Type type = types.get(input); - inputTypes.put(input, type); - } - return inputTypes.build(); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true, layout); } @Override @@ -2058,20 +2036,17 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Map sourceTypes = joinSourcesLayout.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey()))); - - Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, - sourceTypes, - rewrittenFilter, + types, + filterExpression, emptyList(), /* parameters have already been replaced */ - NOOP); + NOOP, + false); - RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes); + RowExpression translatedFilter = toRowExpression(filterExpression, expressionTypes, joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } @@ -2603,7 +2578,7 @@ private AccumulatorFactory buildAccumulatorFactory( NOOP)) .build(); - LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes); + LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); Class lambdaProviderClass = compileLambdaProvider(lambda, metadata.getFunctionRegistry(), lambdaInterfaces.get(i)); try { lambdaProviders.add((LambdaProvider) constructorMethodHandle(lambdaProviderClass, ConnectorSession.class).invoke(session.toConnectorSession())); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index 81b519f6a3a5c..5a5066b5721a9 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -26,7 +26,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.CallExpression; import io.prestosql.sql.relational.ConstantExpression; @@ -58,7 +57,7 @@ import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.prestosql.spi.function.OperatorType.NOT_EQUAL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static java.lang.Integer.min; import static java.util.Collections.emptyList; @@ -80,15 +79,13 @@ public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) { Map symbolInput = new HashMap<>(); - Map inputTypes = new HashMap<>(); int inputId = 0; for (Entry entry : types.allTypes().entrySet()) { symbolInput.put(entry.getKey(), inputId); - inputTypes.put(inputId, entry.getValue()); inputId++; } - RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, inputTypes); - RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, inputTypes); + RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types); + RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types); RowExpression canonicalizedLeft = leftRowExpression.accept(CANONICALIZATION_VISITOR, null); RowExpression canonicalizedRight = rightRowExpression.accept(CANONICALIZATION_VISITOR, null); @@ -96,23 +93,20 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi return canonicalizedLeft.equals(canonicalizedRight); } - private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, Map inputTypes) + private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) { - // replace qualified names with input references since row expressions do not support these - Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression); - // determine the type of every expression - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, - inputTypes, - expressionWithInputReferences, + types, + expression, emptyList(), /* parameters have already been replaced */ WarningCollector.NOOP); // convert to row expression - return translate(expressionWithInputReferences, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + return translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, symbolInput); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java index 720182a33668c..fbf247765108f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/relational/SqlToRowExpressionTranslator.java @@ -31,6 +31,7 @@ import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; import io.prestosql.spi.type.VarcharType; +import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.relational.optimizer.ExpressionOptimizer; import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.ArithmeticUnaryExpression; @@ -144,12 +145,14 @@ public static RowExpression translate( FunctionRegistry functionRegistry, TypeManager typeManager, Session session, - boolean optimize) + boolean optimize, + Map layout) { Visitor visitor = new Visitor( functionKind, types, typeManager, + layout, session.getTimeZoneKey(), isLegacyRowFieldOrdinalAccessEnabled(session), SystemSessionProperties.isLegacyTimestamp(session)); @@ -171,6 +174,7 @@ private static class Visitor private final FunctionKind functionKind; private final Map, Type> types; private final TypeManager typeManager; + private final Map layout; private final TimeZoneKey timeZoneKey; private final boolean legacyRowFieldOrdinalAccess; @Deprecated @@ -180,6 +184,7 @@ private Visitor( FunctionKind functionKind, Map, Type> types, TypeManager typeManager, + Map layout, TimeZoneKey timeZoneKey, boolean legacyRowFieldOrdinalAccess, boolean isLegacyTimestamp) @@ -187,6 +192,7 @@ private Visitor( this.functionKind = functionKind; this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null")); this.typeManager = typeManager; + this.layout = layout; this.timeZoneKey = timeZoneKey; this.legacyRowFieldOrdinalAccess = legacyRowFieldOrdinalAccess; this.isLegacyTimestamp = isLegacyTimestamp; @@ -363,6 +369,11 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) @Override protected RowExpression visitSymbolReference(SymbolReference node, Void context) { + Integer field = layout.get(Symbol.from(node)); + if (field != null) { + return field(field, getType(node)); + } + return new VariableReferenceExpression(node.getName(), getType(node)); } diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java index 439402947f456..418b2730462b7 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkScanFilterAndProjectOperator.java @@ -33,7 +33,6 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -81,7 +80,7 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingSplit.createLocalSplit; import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; @@ -203,10 +202,10 @@ private List createInputPages(List types) private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -216,32 +215,25 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index 00d181faa4b62..2c81971cec695 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -66,7 +66,6 @@ import io.prestosql.sql.gen.ExpressionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -129,7 +128,7 @@ import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; @@ -168,17 +167,17 @@ public final class FunctionAssertions private static final Page ZERO_CHANNEL_PAGE = new Page(1); - private static final Map INPUT_TYPES = ImmutableMap.builder() - .put(0, BIGINT) - .put(1, VARCHAR) - .put(2, DOUBLE) - .put(3, BOOLEAN) - .put(4, BIGINT) - .put(5, VARCHAR) - .put(6, VARCHAR) - .put(7, TIMESTAMP_WITH_TIME_ZONE) - .put(8, VARBINARY) - .put(9, INTEGER) + private static final Map INPUT_TYPES = ImmutableMap.builder() + .put(new Symbol("bound_long"), BIGINT) + .put(new Symbol("bound_string"), VARCHAR) + .put(new Symbol("bound_double"), DOUBLE) + .put(new Symbol("bound_boolean"), BOOLEAN) + .put(new Symbol("bound_timestamp"), BIGINT) + .put(new Symbol("bound_pattern"), VARCHAR) + .put(new Symbol("bound_null_string"), VARCHAR) + .put(new Symbol("bound_timestamp_with_timezone"), TIMESTAMP_WITH_TIME_ZONE) + .put(new Symbol("bound_binary_literal"), VARBINARY) + .put(new Symbol("bound_integer"), INTEGER) .build(); private static final Map INPUT_MAPPING = ImmutableMap.builder() @@ -630,16 +629,15 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Session session, Expression projectionExpression) { - Expression translatedProjection = new SymbolToInputRewriter(INPUT_MAPPING).rewrite(projectionExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, SQL_PARSER, - INPUT_TYPES, - ImmutableList.of(translatedProjection), + TypeProvider.copyOf(INPUT_TYPES), + projectionExpression, ImmutableList.of(), WarningCollector.NOOP); - return toRowExpression(translatedProjection, expressionTypes); + return toRowExpression(projectionExpression, expressionTypes, INPUT_MAPPING); } private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session) @@ -955,9 +953,9 @@ private static SourceOperatorFactory compileScanFilterProject(Optional, Type> expressionTypes) + private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes, Map layout) { - return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false, layout); } private static Page getAtMostOnePage(Operator operator, Page sourcePage) diff --git a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java index b8071a11f17a0..349e453dd3375 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestSqlToRowExpressionTranslator.java @@ -92,7 +92,7 @@ private RowExpression translateAndOptimize(Expression expression) private RowExpression translateAndOptimize(Expression expression, Map, Type> types) { - return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, ImmutableMap.of()); } private Expression simplifyExpression(Expression expression) diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 2a556a9d5767f..597b82f961699 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -30,7 +30,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -66,7 +65,7 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.toList; @@ -151,10 +150,10 @@ public List> columnOriented() private RowExpression getFilter(Type type) { if (type == VARCHAR) { - return rowExpression("cast(varchar0 as bigint) % 2 = 0", VARCHAR); + return rowExpression("cast(varchar0 as bigint) % 2 = 0"); } if (type == BIGINT) { - return rowExpression("bigint0 % 2 = 0", BIGINT); + return rowExpression("bigint0 % 2 = 0"); } throw new IllegalArgumentException("filter not supported for type : " + type); } @@ -164,32 +163,25 @@ private List getProjections(Type type) ImmutableList.Builder builder = ImmutableList.builder(); if (type == BIGINT) { for (int i = 0; i < columnCount; i++) { - builder.add(rowExpression("bigint" + i + " + 5", type)); + builder.add(rowExpression("bigint" + i + " + 5")); } } else if (type == VARCHAR) { for (int i = 0; i < columnCount; i++) { // alternatively use identity expression rowExpression("varchar" + i, type) or // rowExpression("substr(varchar" + i + ", 1, 1)", type) - builder.add(rowExpression("concat(varchar" + i + ", 'foo')", type)); + builder.add(rowExpression("concat(varchar" + i + ", 'foo')")); } } return builder.build(); } - private RowExpression rowExpression(String expression, Type type) + private RowExpression rowExpression(String value) { - SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); - Expression inputReferenceExpression = symbolToInputRewriter.rewrite(createExpression(expression, METADATA, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (int i = 0; i < columnCount; i++) { - builder.put(i, type); - } - Map types = builder.build(); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true, sourceLayout); } private static Page createPage(List types, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index e6402826a62bc..7c25281bb5582 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -30,7 +30,6 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.SymbolToInputRewriter; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -71,14 +70,13 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; import static org.openjdk.jmh.annotations.Scope.Thread; @State(Scope.Thread) @@ -611,15 +609,12 @@ protected void setDoubleMaxValue(double doubleMaxValue) this.doubleMaxValue = doubleMaxValue; } - private RowExpression rowExpression(String expression) + private RowExpression rowExpression(String value) { - Expression inputReferenceExpression = new SymbolToInputRewriter(sourceLayout).rewrite(createExpression(expression, metadata, TypeProvider.copyOf(symbolTypes))); + Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); - Map types = sourceLayout.entrySet().stream() - .collect(toMap(Map.Entry::getValue, entry -> symbolTypes.get(entry.getKey()))); - - Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, metadata, SQL_PARSER, types, inputReferenceExpression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true, sourceLayout); } private Object generateRandomValue(Type type)