From f04adeb8118bf9e44ab40d2e12beddd8f42a0c43 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Wed, 6 Mar 2019 17:36:47 -0800 Subject: [PATCH] Encapsulate expression type analysis in planner The new class is to facilitate obtaining the type of an expression and its subexpressions during planning (i.e., when interacting with IR expression) and to remove spurious dependencies on the SQL parser. It will eventually get removed when we split the AST from the IR and we encode the type directly into IR expressions. --- .../benchmark/AbstractOperatorBenchmark.java | 14 ++--- .../TestExtractSpatialInnerJoin.java | 2 +- .../TestExtractSpatialLeftJoin.java | 2 +- .../execution/SqlQueryExecution.java | 3 +- .../io/prestosql/server/ServerMainModule.java | 2 + .../sql/analyzer/ExpressionAnalyzer.java | 38 ------------ .../sql/analyzer/QueryExplainer.java | 3 +- .../sql/analyzer/StatementAnalyzer.java | 9 +-- .../planner/DesugarAtTimeZoneRewriter.java | 10 +-- .../sql/planner/DomainTranslator.java | 19 ++---- .../sql/planner/LocalExecutionPlanner.java | 50 +++------------ .../prestosql/sql/planner/LogicalPlanner.java | 15 +++-- .../prestosql/sql/planner/PlanOptimizers.java | 25 ++++---- .../prestosql/sql/planner/TypeAnalyzer.java | 61 +++++++++++++++++++ ...wPartialAggregationOverGroupIdRuleSet.java | 10 +-- .../iterative/rule/DesugarAtTimeZone.java | 12 ++-- .../iterative/rule/ExtractSpatialJoins.java | 54 +++++++--------- .../rule/PushPredicateIntoTableScan.java | 25 +++----- .../iterative/rule/SimplifyExpressions.java | 21 +++---- .../planner/optimizations/AddExchanges.java | 14 ++--- .../optimizations/AddLocalExchanges.java | 14 ++--- .../optimizations/ExpressionEquivalence.java | 31 ++++------ .../optimizations/PredicatePushDown.java | 46 ++++---------- .../optimizations/PropertyDerivations.java | 27 ++++---- .../StreamPropertyDerivations.java | 18 +++--- .../sanity/NoDuplicatePlanNodeIdsChecker.java | 4 +- .../sanity/NoIdentifierLeftChecker.java | 4 +- .../NoSubqueryExpressionLeftChecker.java | 4 +- .../sql/planner/sanity/PlanSanityChecker.java | 12 ++-- .../sql/planner/sanity/TypeValidator.java | 21 +++---- ...ValidateAggregationsWithDefaultValues.java | 16 ++--- .../sanity/ValidateDependenciesChecker.java | 4 +- .../sanity/ValidateStreamingAggregations.java | 14 ++--- .../sanity/VerifyNoFilteredAggregations.java | 4 +- .../sanity/VerifyOnlyOneOutputNode.java | 4 +- .../prestosql/testing/LocalQueryRunner.java | 7 ++- .../io/prestosql/execution/TaskTestUtils.java | 3 +- ...BenchmarkScanFilterAndProjectOperator.java | 18 +++--- .../operator/scalar/FunctionAssertions.java | 17 ++---- .../sql/TestExpressionInterpreter.java | 9 ++- .../sql/gen/PageProcessorBenchmark.java | 8 +-- .../sql/planner/TestTypeValidator.java | 4 +- .../rule/TestPushPredicateIntoTableScan.java | 3 +- .../rule/TestSimplifyExpressions.java | 3 +- .../iterative/rule/test/RuleTester.java | 13 ++-- .../optimizations/TestEliminateSorts.java | 3 +- .../TestExpressionEquivalence.java | 3 +- .../optimizations/TestReorderWindows.java | 3 +- ...ValidateAggregationsWithDefaultValues.java | 5 +- .../TestValidateStreamingAggregations.java | 8 +-- .../type/BenchmarkDecimalOperators.java | 17 ++++-- .../tests/AbstractTestQueryFramework.java | 3 +- 52 files changed, 330 insertions(+), 409 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index 3b3171c37318d..568a30387ceeb 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -23,7 +23,6 @@ import io.prestosql.execution.Lifespan; import io.prestosql.execution.TaskId; import io.prestosql.execution.TaskStateMachine; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.memory.MemoryPool; import io.prestosql.memory.QueryContext; import io.prestosql.metadata.Metadata; @@ -52,6 +51,7 @@ import io.prestosql.split.SplitSource; import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -83,7 +83,6 @@ import static io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; @@ -229,14 +228,9 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo Optional hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet())); verify(hashExpression.isPresent()); - Map, Type> expressionTypes = getExpressionTypes( - session, - localQueryRunner.getMetadata(), - localQueryRunner.getSqlParser(), - TypeProvider.copyOf(symbolTypes.build()), - hashExpression.get(), - ImmutableList.of(), - WarningCollector.NOOP); + Map, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata()) + .getTypes(session, TypeProvider.copyOf(symbolTypes.build()), hashExpression.get()); + RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionRegistry(), localQueryRunner.getTypeManager(), session, false); PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0); diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java index f25c31450df4d..b16089288821a 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.java @@ -389,6 +389,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java index d5eac82517f09..44e8374fc25e2 100644 --- a/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java +++ b/presto-geospatial/src/test/java/io/prestosql/plugin/geospatial/TestExtractSpatialLeftJoin.java @@ -258,6 +258,6 @@ public void testPushDownAnd() private RuleAssert assertRuleApplication() { RuleTester tester = tester(); - return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser())); + return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer())); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index a46e38ff80176..52ba8216f3acc 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -62,6 +62,7 @@ import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.StageExecutionPlan; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.Explain; import io.prestosql.transaction.TransactionManager; @@ -414,7 +415,7 @@ private PlanRoot doAnalyzeQuery() // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, stateMachine.getWarningCollector()); + LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, new TypeAnalyzer(sqlParser, metadata), statsCalculator, costCalculator, stateMachine.getWarningCollector()); Plan plan = logicalPlanner.plan(analysis); queryPlan.set(plan); diff --git a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java index c8b15755063d4..af58b6b1fe7db 100644 --- a/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java +++ b/presto-main/src/main/java/io/prestosql/server/ServerMainModule.java @@ -127,6 +127,7 @@ import io.prestosql.sql.planner.CompilerConfig; import io.prestosql.sql.planner.LocalExecutionPlanner; import io.prestosql.sql.planner.NodePartitioningManager; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.transaction.TransactionManagerConfig; @@ -354,6 +355,7 @@ protected void setup(Binder binder) binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); // type + binder.bind(TypeAnalyzer.class).in(Scopes.SINGLETON); binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index 1e04eb9700d32..8747d7fd3b53e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -1421,44 +1421,6 @@ public static Signature resolveFunction(FunctionCall node, List, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector) - { - return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, warningCollector, false); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Expression expression, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector, isDescribe); - } - - public static Map, Type> getExpressionTypes( - Session session, - Metadata metadata, - SqlParser sqlParser, - TypeProvider types, - Iterable expressions, - List parameters, - WarningCollector warningCollector, - boolean isDescribe) - { - return analyzeExpressions(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes(); - } - public static ExpressionAnalysis analyzeExpressions( Session session, Metadata metadata, diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java index 834b9a313ee89..d0fd2416acec0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/QueryExplainer.java @@ -29,6 +29,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.planPrinter.IoPlanPrinter; import io.prestosql.sql.planner.planPrinter.PlanPrinter; @@ -175,7 +176,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List s throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } - Map, Type> expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions( session, metadata, sqlParser, TypeProvider.empty(), - relation.getSamplePercentage(), + ImmutableList.of(relation.getSamplePercentage()), analysis.getParameters(), WarningCollector.NOOP, - analysis.isDescribe()); + analysis.isDescribe()) + .getExpressionTypes(); + ExpressionInterpreter samplePercentageEval = expressionOptimizer(relation.getSamplePercentage(), metadata, session, expressionTypes); Object samplePercentageObject = samplePercentageEval.optimize(symbol -> { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java index e9020862a05c2..0c6a2375443fc 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DesugarAtTimeZoneRewriter.java @@ -16,10 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AtTimeZone; import io.prestosql.sql.tree.Cast; import io.prestosql.sql.tree.Expression; @@ -36,8 +34,6 @@ import static io.prestosql.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class DesugarAtTimeZoneRewriter @@ -49,15 +45,15 @@ public static Expression rewrite(Expression expression, Map, private DesugarAtTimeZoneRewriter() {} - public static Expression rewrite(Expression expression, Session session, Metadata metadata, SqlParser sqlParser, SymbolAllocator symbolAllocator) + public static Expression rewrite(Expression expression, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return rewrite(expression, expressionTypes); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index cfa67ca0d1356..61eb2a8697cc0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.block.Block; @@ -34,7 +33,6 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.InterpretedFunctionInvoker; -import io.prestosql.sql.analyzer.ExpressionAnalyzer; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.BetweenPredicate; @@ -78,7 +76,6 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; -import static java.util.Collections.emptyList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; @@ -277,7 +274,7 @@ public static ExtractionResult fromPredicate( Expression predicate, TypeProvider types) { - return new Visitor(metadata, session, types).process(predicate, false); + return new Visitor(metadata, session, types, new TypeAnalyzer(new SqlParser(), metadata)).process(predicate, false); } private static class Visitor @@ -288,14 +285,16 @@ private static class Visitor private final Session session; private final TypeProvider types; private final InterpretedFunctionInvoker functionInvoker; + private final TypeAnalyzer typeAnalyzer; - private Visitor(Metadata metadata, Session session, TypeProvider types) + private Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); this.functionInvoker = new InterpretedFunctionInvoker(metadata.getFunctionRegistry()); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } private Type checkedTypeLookup(Symbol symbol) @@ -424,7 +423,7 @@ else if (symbolExpression instanceof Cast) { return super.visitComparisonExpression(node, complement); } - Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value + Type castSourceType = typeAnalyzer.getType(session, types, castExpression.getExpression()); // type of expression which is then cast to type of value // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side Optional coercedExpression = coerceComparisonWithRounding( @@ -489,7 +488,7 @@ private boolean isImplicitCoercion(Cast cast) private Map, Type> analyzeExpression(Expression expression) { - return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); + return typeAnalyzer.getTypes(session, types, expression); } private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) @@ -757,12 +756,6 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement } } - private static Type typeOf(Expression expression, Session session, Metadata metadata, TypeProvider types) - { - Map, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private static class NormalizedSimpleComparison { private final Expression symbolExpression; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 800c88eb75637..47aefe8fec2e2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -133,7 +133,6 @@ import io.prestosql.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import io.prestosql.sql.gen.OrderingCompiler; import io.prestosql.sql.gen.PageFunctionCompiler; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -230,7 +229,6 @@ import static io.prestosql.SystemSessionProperties.isSpillEnabled; import static io.prestosql.SystemSessionProperties.isSpillOrderBy; import static io.prestosql.SystemSessionProperties.isSpillWindowOperator; -import static io.prestosql.execution.warnings.WarningCollector.NOOP; import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static io.prestosql.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory; @@ -249,7 +247,6 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.TypeUtils.writeNativeValue; import static io.prestosql.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -276,14 +273,13 @@ import static io.prestosql.util.SpatialJoinUtils.ST_WITHIN; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; public class LocalExecutionPlanner { private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Optional explainAnalyzeContext; private final PageSourceProvider pageSourceProvider; private final IndexManager indexManager; @@ -310,7 +306,7 @@ public class LocalExecutionPlanner @Inject public LocalExecutionPlanner( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Optional explainAnalyzeContext, PageSourceProvider pageSourceProvider, IndexManager indexManager, @@ -337,7 +333,7 @@ public LocalExecutionPlanner( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null"); @@ -1215,15 +1211,10 @@ else if (sourceNode instanceof SampleNode) { projections.add(assignments.get(symbol)); } - Map, Type> expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = typeAnalyzer.getTypes( context.getSession(), - metadata, - sqlParser, context.getTypes(), - concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions()), - emptyList(), - NOOP, - false); + concat(filterExpression.map(ImmutableList::of).orElse(ImmutableList.of()), assignments.getExpressions())); Optional translatedFilter = filterExpression.map(filter -> toRowExpression(filter, expressionTypes, sourceLayout)); List translatedProjections = projections.stream() @@ -1300,15 +1291,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext PageBuilder pageBuilder = new PageBuilder(node.getRows().size(), outputTypes); for (List row : node.getRows()) { pageBuilder.declarePosition(); - Map, Type> expressionTypes = getExpressionTypes( - context.getSession(), - metadata, - sqlParser, - TypeProvider.empty(), - ImmutableList.copyOf(row), - emptyList(), - NOOP, - false); + Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), TypeProvider.empty(), ImmutableList.copyOf(row)); for (int i = 0; i < row.size(); i++) { // evaluate the literal value Object result = ExpressionInterpreter.expressionInterpreter(row.get(i), metadata, context.getSession(), expressionTypes).evaluate(); @@ -2033,17 +2016,7 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( { Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - types, - filterExpression, - emptyList(), /* parameters have already been replaced */ - NOOP, - false); - - RowExpression translatedFilter = toRowExpression(filterExpression, expressionTypes, joinSourcesLayout); + RowExpression translatedFilter = toRowExpression(filterExpression, typeAnalyzer.getTypes(session, types, filterExpression), joinSourcesLayout); return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); } @@ -2554,14 +2527,7 @@ private AccumulatorFactory buildAccumulatorFactory( // expressions from lambda arguments .putAll(lambdaArgumentExpressionTypes) // expressions from lambda body - .putAll(getExpressionTypes( - session, - metadata, - sqlParser, - TypeProvider.copyOf(lambdaArgumentSymbolTypes), - lambdaExpression.getBody(), - emptyList(), - NOOP)) + .putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody())) .build(); LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) toRowExpression(lambdaExpression, expressionTypes, ImmutableMap.of()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java index eb453f53dd03b..a6df3075bc2d4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java @@ -42,7 +42,6 @@ import io.prestosql.sql.analyzer.RelationId; import io.prestosql.sql.analyzer.RelationType; import io.prestosql.sql.analyzer.Scope; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.AggregationNode; @@ -115,7 +114,7 @@ public enum Stage private final PlanSanityChecker planSanityChecker; private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final StatisticsAggregationPlanner statisticsAggregationPlanner; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; @@ -125,12 +124,12 @@ public LogicalPlanner(Session session, List planOptimizers, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) { - this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, warningCollector); + this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, metadata, typeAnalyzer, statsCalculator, costCalculator, warningCollector); } public LogicalPlanner(Session session, @@ -138,7 +137,7 @@ public LogicalPlanner(Session session, PlanSanityChecker planSanityChecker, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector) @@ -148,7 +147,7 @@ public LogicalPlanner(Session session, this.planSanityChecker = requireNonNull(planSanityChecker, "planSanityChecker is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); @@ -164,7 +163,7 @@ public Plan plan(Analysis analysis, Stage stage) { PlanNode root = planStatement(analysis, analysis.getStatement()); - planSanityChecker.validateIntermediatePlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateIntermediatePlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); if (stage.ordinal() >= Stage.OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { @@ -175,7 +174,7 @@ public Plan plan(Analysis analysis, Stage stage) if (stage.ordinal() >= Stage.OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. This is mainly to catch programming errors - planSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes(), warningCollector); + planSanityChecker.validateFinalPlan(root, session, metadata, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); } TypeProvider types = symbolAllocator.getTypes(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 336665d027600..27b54ec2de971 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -27,7 +27,6 @@ import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; import io.prestosql.sql.analyzer.FeaturesConfig; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.iterative.IterativeOptimizer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; @@ -146,7 +145,7 @@ public class PlanOptimizers @Inject public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager nodeManager, @@ -161,7 +160,7 @@ public PlanOptimizers( TaskCountEstimator taskCountEstimator) { this(metadata, - sqlParser, + typeAnalyzer, featuresConfig, taskManagerConfig, false, @@ -191,7 +190,7 @@ public void destroy() public PlanOptimizers( Metadata metadata, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, FeaturesConfig featuresConfig, TaskManagerConfig taskManagerConfig, boolean forceSingleNode, @@ -250,9 +249,9 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - new SimplifyExpressions(metadata, sqlParser).rules()); + new SimplifyExpressions(metadata, typeAnalyzer).rules()); - PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); + PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer)); builder.add( // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers @@ -262,7 +261,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .addAll(new DesugarLambdaExpression().rules()) - .addAll(new DesugarAtTimeZone(metadata, sqlParser).rules()) + .addAll(new DesugarAtTimeZone(metadata, typeAnalyzer).rules()) .addAll(new DesugarCurrentUser().rules()) .addAll(new DesugarCurrentPath().rules()) .addAll(new DesugarTryExpression().rules()) @@ -359,7 +358,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), new PruneUnreferencedOutputs(), new IterativeOptimizer( ruleStats, @@ -409,7 +408,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new PushPredicateIntoTableScan(metadata, sqlParser))), + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), projectionPushDown, new PruneUnreferencedOutputs(), new IterativeOptimizer( @@ -442,7 +441,7 @@ public PlanOptimizers( costCalculator, ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) - .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, sqlParser).rules()) + .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, typeAnalyzer).rules()) .add(new InlineProjections()) .build())); @@ -463,7 +462,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser))); + builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, typeAnalyzer))); } //noinspection UnusedAssignment estimatedExchangesCostCalculator = null; // Prevent accidental use after AddExchanges @@ -493,7 +492,7 @@ public PlanOptimizers( .build())); // Optimizers above this don't understand local exchanges, so be careful moving this. - builder.add(new AddLocalExchanges(metadata, sqlParser)); + builder.add(new AddLocalExchanges(metadata, typeAnalyzer)); // Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers @@ -509,7 +508,7 @@ public PlanOptimizers( ruleStats, statsCalculator, costCalculator, - new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, sqlParser, taskCountEstimator, taskManagerConfig).rules())); + new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(metadata, typeAnalyzer, taskCountEstimator, taskManagerConfig).rules())); builder.add(new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java new file mode 100644 index 0000000000000..c29300f0bc359 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/TypeAnalyzer.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.analyzer.ExpressionAnalyzer; +import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.NodeRef; + +import javax.inject.Inject; + +import java.util.Map; + +/** + * This class is to facilitate obtaining the type of an expression and its subexpressions + * during planning (i.e., when interacting with IR expression). It will eventually get + * removed when we split the AST from the IR and we encode the type directly into IR expressions. + */ +public class TypeAnalyzer +{ + private final SqlParser parser; + private final Metadata metadata; + + @Inject + public TypeAnalyzer(SqlParser parser, Metadata metadata) + { + this.parser = parser; + this.metadata = metadata; + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Iterable expressions) + { + return ExpressionAnalyzer.analyzeExpressions(session, metadata, parser, inputTypes, expressions, ImmutableList.of(), WarningCollector.NOOP, false).getExpressionTypes(); + } + + public Map, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, ImmutableList.of(expression)); + } + + public Type getType(Session session, TypeProvider inputTypes, Expression expression) + { + return getTypes(session, inputTypes, expression).get(NodeRef.of(expression)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 6e6639dab739d..010b8b9adcf9c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -26,10 +26,10 @@ import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.optimizations.StreamPreferredProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -128,18 +128,18 @@ public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet private static final double ANTI_SKEWNESS_MARGIN = 3; private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final TaskCountEstimator taskCountEstimator; private final DataSize maxPartialAggregationMemoryUsage; public AddExchangesBelowPartialAggregationOverGroupIdRuleSet( Metadata metadata, - SqlParser parser, + TypeAnalyzer typeAnalyzer, TaskCountEstimator taskCountEstimator, TaskManagerConfig taskManagerConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null"); this.maxPartialAggregationMemoryUsage = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage(); } @@ -342,7 +342,7 @@ private StreamProperties derivePropertiesRecursively(PlanNode node, Context cont List inputProperties = resolvedPlanNode.getSources().stream() .map(source -> derivePropertiesRecursively(source, context)) .collect(toImmutableList()); - return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), parser); + return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), context.getSymbolAllocator().getTypes(), typeAnalyzer); } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java index c2ebb7b3679ff..1ee472b74a5af 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DesugarAtTimeZone.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DesugarAtTimeZoneRewriter; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import java.util.Set; @@ -26,9 +26,9 @@ public class DesugarAtTimeZone extends ExpressionRewriteRuleSet { - public DesugarAtTimeZone(Metadata metadata, SqlParser sqlParser) + public DesugarAtTimeZone(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -42,11 +42,11 @@ public Set> rules() valuesExpressionRewrite()); } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); - return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, sqlParser, context.getSymbolAllocator()); + return (expression, context) -> DesugarAtTimeZoneRewriter.rewrite(expression, context.getSession(), metadata, typeAnalyzer, context.getSymbolAllocator()); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index b0c118ace6dec..170ca6294796b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -21,7 +21,6 @@ import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.execution.Lifespan; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.geospatial.KdbTree; import io.prestosql.geospatial.KdbTreeUtils; import io.prestosql.matching.Capture; @@ -42,8 +41,8 @@ import io.prestosql.split.SplitManager; import io.prestosql.split.SplitSource; import io.prestosql.split.SplitSource.SplitBatch; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.iterative.Rule.Context; import io.prestosql.sql.planner.iterative.Rule.Result; @@ -59,7 +58,6 @@ import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.SymbolReference; @@ -85,7 +83,6 @@ import static io.prestosql.spi.type.IntegerType.INTEGER; import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; @@ -98,7 +95,6 @@ import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.prestosql.util.SpatialJoinUtils.extractSupportedSpatialFunctions; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -158,21 +154,21 @@ public class ExtractSpatialJoins private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public Set> rules() { return ImmutableSet.of( - new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, sqlParser), - new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, sqlParser)); + new ExtractSpatialInnerJoin(metadata, splitManager, pageSourceManager, typeAnalyzer), + new ExtractSpatialLeftJoin(metadata, splitManager, pageSourceManager, typeAnalyzer)); } @VisibleForTesting @@ -186,14 +182,14 @@ public static final class ExtractSpatialInnerJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialInnerJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -215,7 +211,7 @@ public Result apply(FilterNode node, Captures captures, Context context) Expression filter = node.getPredicate(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -223,7 +219,7 @@ public Result apply(FilterNode node, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, node.getId(), node.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -242,14 +238,14 @@ public static final class ExtractSpatialLeftJoin private final Metadata metadata; private final SplitManager splitManager; private final PageSourceManager pageSourceManager; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, SqlParser sqlParser) + public ExtractSpatialLeftJoin(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -270,7 +266,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) Expression filter = joinNode.getFilter().get(); List spatialFunctions = extractSupportedSpatialFunctions(filter); for (FunctionCall spatialFunction : spatialFunctions) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialFunction, Optional.empty(), metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -278,7 +274,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) List spatialComparisons = extractSupportedSpatialComparisons(filter); for (ComparisonExpression spatialComparison : spatialComparisons) { - Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, sqlParser); + Result result = tryCreateSpatialJoin(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), spatialComparison, metadata, splitManager, pageSourceManager, typeAnalyzer); if (!result.isEmpty()) { return result; } @@ -298,7 +294,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -350,7 +346,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDistributionType(), joinNode.isSpillable()); - return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, sqlParser); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), metadata, splitManager, pageSourceManager, typeAnalyzer); } private static Result tryCreateSpatialJoin( @@ -364,7 +360,7 @@ private static Result tryCreateSpatialJoin( Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, - SqlParser sqlParser) + TypeAnalyzer typeAnalyzer) { // TODO Add support for distributed left spatial joins Optional spatialPartitioningTableName = joinNode.getType() == INNER ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty(); @@ -377,8 +373,8 @@ private static Result tryCreateSpatialJoin( Expression secondArgument = arguments.get(1); Type sphericalGeographyType = metadata.getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE); - if (getExpressionType(firstArgument, context, metadata, sqlParser).equals(sphericalGeographyType) - || getExpressionType(secondArgument, context, metadata, sqlParser).equals(sphericalGeographyType)) { + if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) + || typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) { return Result.empty(); } @@ -446,14 +442,6 @@ else if (alignment < 0) { kdbTree.map(KdbTreeUtils::toJson))); } - private static Type getExpressionType(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) - { - Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getSymbolAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP) - .get(NodeRef.of(expression)); - verify(type != null); - return type; - } - private static KdbTree loadKdbTree(String tableName, Session session, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) { QualifiedObjectName name = toQualifiedObjectName(tableName, session.getCatalog().get(), session.getSchema().get()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index 244200de417d6..07625f27ee587 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.matching.Capture; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; @@ -27,8 +26,6 @@ import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.LiteralEncoder; @@ -36,6 +33,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.plan.FilterNode; @@ -43,7 +41,6 @@ import io.prestosql.sql.planner.plan.TableScanNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NullLiteral; import java.util.Map; @@ -59,12 +56,10 @@ import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; import static io.prestosql.sql.ExpressionUtils.filterNonDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.plan.Patterns.filter; import static io.prestosql.sql.planner.plan.Patterns.source; import static io.prestosql.sql.planner.plan.Patterns.tableScan; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -80,13 +75,13 @@ public class PushPredicateIntoTableScan tableScan().capturedAs(TABLE_SCAN))); private final Metadata metadata; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final DomainTranslator domainTranslator; - public PushPredicateIntoTableScan(Metadata metadata, SqlParser parser) + public PushPredicateIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); } @@ -115,7 +110,7 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) context.getSymbolAllocator().getTypes(), context.getIdAllocator(), metadata, - parser, + typeAnalyzer, domainTranslator); if (arePlansSame(filterNode, tableScan, rewritten)) { @@ -154,7 +149,7 @@ public static PlanNode pushFilterIntoTableScan( TypeProvider types, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser parser, + TypeAnalyzer typeAnalyzer, DomainTranslator domainTranslator) { // don't include non-deterministic predicates @@ -176,7 +171,7 @@ public static PlanNode pushFilterIntoTableScan( if (pruneWithPredicateExpression) { LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator( metadata, - parser, + typeAnalyzer, session, types, node.getAssignments(), @@ -239,13 +234,11 @@ private static class LayoutConstraintEvaluator private final ExpressionInterpreter evaluator; private final Set arguments; - public LayoutConstraintEvaluator(Metadata metadata, SqlParser parser, Session session, TypeProvider types, Map assignments, Expression expression) + public LayoutConstraintEvaluator(Metadata metadata, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types, Map assignments, Expression expression) { this.assignments = assignments; - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); - - evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); + evaluator = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, typeAnalyzer.getTypes(session, types, expression)); arguments = SymbolsExtractor.extractUnique(expression).stream() .map(assignments::get) .collect(toImmutableSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java index 0ccf965931ee2..95c39cd649582 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/SimplifyExpressions.java @@ -16,14 +16,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.NodeRef; @@ -32,33 +31,31 @@ import java.util.Map; import java.util.Set; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates; import static io.prestosql.sql.planner.iterative.rule.PushDownNegationsExpressionRewriter.pushDownNegations; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class SimplifyExpressions extends ExpressionRewriteRuleSet { @VisibleForTesting - static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, SqlParser sqlParser) + static Expression rewrite(Expression expression, Session session, SymbolAllocator symbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); if (expression instanceof SymbolReference) { return expression; } expression = pushDownNegations(expression); expression = extractCommonPredicates(expression); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } - public SimplifyExpressions(Metadata metadata, SqlParser sqlParser) + public SimplifyExpressions(Metadata metadata, TypeAnalyzer typeAnalyzer) { - super(createRewrite(metadata, sqlParser)); + super(createRewrite(metadata, typeAnalyzer)); } @Override @@ -71,12 +68,12 @@ public Set> rules() valuesExpressionRewrite()); // ApplyNode and AggregationNode are not supported, because ExpressionInterpreter doesn't support them } - private static ExpressionRewriter createRewrite(Metadata metadata, SqlParser sqlParser) + private static ExpressionRewriter createRewrite(Metadata metadata, TypeAnalyzer typeAnalyzer) { requireNonNull(metadata, "metadata is null"); - requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); LiteralEncoder literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); - return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, sqlParser); + return (expression, context) -> rewrite(expression, context.getSession(), context.getSymbolAllocator(), metadata, literalEncoder, typeAnalyzer); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java index e4a97ef1298bb..57f98b9d80640 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddExchanges.java @@ -26,7 +26,6 @@ import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; import io.prestosql.spi.connector.SortingProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.LiteralEncoder; import io.prestosql.sql.planner.Partitioning; @@ -34,6 +33,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; import io.prestosql.sql.planner.plan.AggregationNode; @@ -112,15 +112,15 @@ public class AddExchanges implements PlanOptimizer { - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; private final Metadata metadata; private final DomainTranslator domainTranslator; - public AddExchanges(Metadata metadata, SqlParser parser) + public AddExchanges(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.domainTranslator = new DomainTranslator(new LiteralEncoder(metadata.getBlockEncodingSerde())); - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -532,7 +532,7 @@ else if (redistributeWrites) { private PlanWithProperties planTableScan(TableScanNode node, Expression predicate) { - PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, parser, domainTranslator); + PlanNode plan = PushPredicateIntoTableScan.pushFilterIntoTableScan(node, predicate, true, session, types, idAllocator, metadata, typeAnalyzer, domainTranslator); return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } @@ -1190,7 +1190,7 @@ private ActualProperties deriveProperties(PlanNode result, ActualProperties inpu private ActualProperties deriveProperties(PlanNode result, List inputProperties) { // TODO: move this logic to PlanSanityChecker once PropertyDerivations.deriveProperties fully supports local exchanges - ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser); + ActualProperties outputProperties = PropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer); verify(result instanceof SemiJoinNode || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated) || outputProperties.isNullsAndAnyReplicated(), "SemiJoinNode is the only node that can strip null replication"); return outputProperties; @@ -1198,7 +1198,7 @@ private ActualProperties deriveProperties(PlanNode result, List inputProperties) { - return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, parser)); + return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, types, typeAnalyzer)); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java index d051b49f5941e..efae3082d79e5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/ExpressionEquivalence.java @@ -20,12 +20,11 @@ import com.google.common.collect.Ordering; import io.airlift.slice.Slice; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.CallExpression; import io.prestosql.sql.relational.ConstantExpression; @@ -35,7 +34,6 @@ import io.prestosql.sql.relational.RowExpressionVisitor; import io.prestosql.sql.relational.VariableReferenceExpression; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import java.util.Comparator; import java.util.HashMap; @@ -57,10 +55,8 @@ import static io.prestosql.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.prestosql.spi.function.OperatorType.NOT_EQUAL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static java.lang.Integer.min; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class ExpressionEquivalence @@ -68,12 +64,12 @@ public class ExpressionEquivalence private static final Ordering ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator()); private static final CanonicalizationVisitor CANONICALIZATION_VISITOR = new CanonicalizationVisitor(); private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) + public ExpressionEquivalence(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types) @@ -95,18 +91,15 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, TypeProvider types) { - // determine the type of every expression - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - types, + return translate( expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); - - // convert to row expression - return translate(expression, SCALAR, expressionTypes, symbolInput, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); + SCALAR, + typeAnalyzer.getTypes(session, types, expression), + symbolInput, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + session, + false); } private static class CanonicalizationVisitor diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java index 27f7e0a558cab..246616248cecb 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java @@ -20,7 +20,6 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DeterminismEvaluator; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.EffectivePredicateExtractor; @@ -32,6 +31,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AssignUniqueId; @@ -84,7 +84,6 @@ import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.extractConjuncts; import static io.prestosql.sql.ExpressionUtils.filterDeterministicConjuncts; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.DeterminismEvaluator.isDeterministic; import static io.prestosql.sql.planner.EqualityInference.createEqualityInference; import static io.prestosql.sql.planner.ExpressionSymbolInliner.inlineSymbols; @@ -93,7 +92,6 @@ import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; public class PredicatePushDown @@ -102,14 +100,14 @@ public class PredicatePushDown private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; - public PredicatePushDown(Metadata metadata, SqlParser sqlParser) + public PredicatePushDown(Metadata metadata, TypeAnalyzer typeAnalyzer) { this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde()); this.effectivePredicateExtractor = new EffectivePredicateExtractor(new DomainTranslator(literalEncoder)); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); } @Override @@ -121,7 +119,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith( - new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, sqlParser, session, types), + new Rewriter(symbolAllocator, idAllocator, metadata, literalEncoder, effectivePredicateExtractor, typeAnalyzer, session, types), plan, TRUE_LITERAL); } @@ -134,7 +132,7 @@ private static class Rewriter private final Metadata metadata; private final LiteralEncoder literalEncoder; private final EffectivePredicateExtractor effectivePredicateExtractor; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final Session session; private final TypeProvider types; private final ExpressionEquivalence expressionEquivalence; @@ -145,7 +143,7 @@ private Rewriter( Metadata metadata, LiteralEncoder literalEncoder, EffectivePredicateExtractor effectivePredicateExtractor, - SqlParser sqlParser, + TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) { @@ -154,10 +152,10 @@ private Rewriter( this.metadata = requireNonNull(metadata, "metadata is null"); this.literalEncoder = requireNonNull(literalEncoder, "literalEncoder is null"); this.effectivePredicateExtractor = requireNonNull(effectivePredicateExtractor, "effectivePredicateExtractor is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); - this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser); + this.expressionEquivalence = new ExpressionEquivalence(metadata, typeAnalyzer); } @Override @@ -638,7 +636,7 @@ private Symbol symbolForExpression(Expression expression) return Symbol.from(expression); } - return symbolAllocator.newSymbol(expression, extractType(expression)); + return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression)); } private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerSymbols) @@ -891,12 +889,6 @@ private static Expression extractJoinPredicate(JoinNode joinNode) return combineConjuncts(builder.build()); } - private Type extractType(Expression expression) - { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), /* parameters have already been replaced */WarningCollector.NOOP); - return expressionTypes.get(NodeRef.of(expression)); - } - private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) { checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType()); @@ -948,14 +940,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses private Expression simplifyExpression(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); return literalEncoder.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } @@ -970,14 +955,7 @@ private boolean areExpressionsEquivalent(Expression leftExpression, Expression r */ private Object nullInputEvaluator(final Collection nullSymbols, Expression expression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - sqlParser, - symbolAllocator.getTypes(), - expression, - emptyList(), /* parameters have already been replaced */ - WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression); return ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes) .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java index 8657f7f3e31d3..b0de8196db49e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java @@ -21,7 +21,6 @@ import com.google.common.collect.Sets; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableLayout; import io.prestosql.metadata.TableLayout.TablePartitioning; @@ -32,13 +31,13 @@ import io.prestosql.spi.connector.SortingProperty; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.type.Type; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.DomainTranslator; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.NoOpSymbolResolver; import io.prestosql.sql.planner.OrderingScheme; import io.prestosql.sql.planner.Partitioning; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties.Global; import io.prestosql.sql.planner.plan.AggregationNode; @@ -95,7 +94,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.SystemSessionProperties.planWithTableNodePartitioning; import static io.prestosql.spi.predicate.TupleDomain.extractFixedValues; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.coordinatorSingleStreamPartition; @@ -104,7 +102,6 @@ import static io.prestosql.sql.planner.optimizations.ActualProperties.Global.streamPartitionedOn; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.REMOTE; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -112,17 +109,17 @@ public class PropertyDerivations { private PropertyDerivations() {} - public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return deriveProperties(node, inputProperties, metadata, session, types, parser); + return deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - ActualProperties output = node.accept(new Visitor(metadata, session, types, parser), inputProperties); + ActualProperties output = node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); output.getNodePartitioning().ifPresent(partitioning -> verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); @@ -137,9 +134,9 @@ public static ActualProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static ActualProperties streamBackdoorDeriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return node.accept(new Visitor(metadata, session, types, parser), inputProperties); + return node.accept(new Visitor(metadata, session, types, typeAnalyzer), inputProperties); } private static class Visitor @@ -148,14 +145,14 @@ private static class Visitor private final Metadata metadata; private final Session session; private final TypeProvider types; - private final SqlParser parser; + private final TypeAnalyzer typeAnalyzer; - public Visitor(Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { this.metadata = metadata; this.session = session; this.types = types; - this.parser = parser; + this.typeAnalyzer = typeAnalyzer; } @Override @@ -636,7 +633,7 @@ public ActualProperties visitProject(ProjectNode node, List in for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression); Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); // TODO: diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java index d7f6202632384..ed1c0ff221d63 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java @@ -23,9 +23,9 @@ import io.prestosql.metadata.TableLayout; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Partitioning.ArgumentBinding; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.ApplyNode; @@ -96,27 +96,27 @@ public final class StreamPropertyDerivations { private StreamPropertyDerivations() {} - public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session, types, parser)) + .map(source -> derivePropertiesRecursively(source, metadata, session, types, typeAnalyzer)) .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, parser); + return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { - return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, parser); + return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, types, typeAnalyzer); } - public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, SqlParser parser) + public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer) { requireNonNull(node, "node is null"); requireNonNull(inputProperties, "inputProperties is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); - requireNonNull(parser, "parser is null"); + requireNonNull(typeAnalyzer, "typeAnalyzer is null"); // properties.otherActualProperties will never be null here because the only way // an external caller should obtain StreamProperties is from this method, and the @@ -129,7 +129,7 @@ public static StreamProperties deriveProperties(PlanNode node, List planNodeIds = new HashMap<>(); searchFrom(planNode) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java index ddb9e03c65dc8..a7d0f133cac06 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoIdentifierLeftChecker.java @@ -17,8 +17,8 @@ import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.sql.analyzer.ExpressionTreeUtils; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.Identifier; @@ -29,7 +29,7 @@ public final class NoIdentifierLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { List identifiers = ExpressionTreeUtils.extractExpressions(ExpressionExtractor.extractExpressions(plan), Identifier.class); if (!identifiers.isEmpty()) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java index 916235efc1b96..88c4e94c63b8c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java @@ -16,8 +16,8 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.tree.DefaultTraversalVisitor; @@ -30,7 +30,7 @@ public final class NoSubqueryExpressionLeftChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { for (Expression expression : ExpressionExtractor.extractExpressions(plan)) { new DefaultTraversalVisitor() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java index cc7a468aad48c..6c6cb14d1de4d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/PlanSanityChecker.java @@ -18,7 +18,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNode; @@ -56,19 +56,19 @@ public PlanSanityChecker(boolean forceSingleNode) .build(); } - public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } - public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types, warningCollector)); + checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, typeAnalyzer, types, warningCollector)); } public interface Checker { - void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector); + void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector); } private enum Stage diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java index daf7934aa459f..30c239dc29fda 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/TypeValidator.java @@ -21,9 +21,9 @@ import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeManager; import io.prestosql.spi.type.TypeSignature; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.SimplePlanVisitor; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -33,16 +33,13 @@ import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; -import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.SymbolReference; import java.util.List; import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.type.UnknownType.UNKNOWN; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -54,9 +51,9 @@ public final class TypeValidator public TypeValidator() {} @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - plan.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + plan.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static class Visitor @@ -64,15 +61,15 @@ private static class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - public Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } @@ -119,8 +116,7 @@ public Void visitProject(ProjectNode node, Void context) verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); + Type actualType = typeAnalyzer.getType(session, types, entry.getValue()); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } @@ -165,8 +161,7 @@ private void checkSignature(Symbol symbol, Signature signature) private void checkCall(Symbol symbol, FunctionCall call) { Type expectedType = types.get(symbol); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(call)); + Type actualType = typeAnalyzer.getType(session, types, call); verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java index a6f760773aaff..22152e412d349 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateAggregationsWithDefaultValues.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.ActualProperties; import io.prestosql.sql.planner.optimizations.PropertyDerivations; @@ -60,9 +60,9 @@ public ValidateAggregationsWithDefaultValues(boolean forceSingleNode) } @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types), null); } private class Visitor @@ -70,14 +70,14 @@ private class Visitor { final Session session; final Metadata metadata; - final SqlParser parser; + final TypeAnalyzer typeAnalyzer; final TypeProvider types; - Visitor(Session session, Metadata metadata, SqlParser parser, TypeProvider types) + Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); - this.parser = requireNonNull(parser, "parser is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.types = requireNonNull(types, "types is null"); } @@ -115,14 +115,14 @@ public Optional visitAggregation(AggregationNode node, Void conte // No remote repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed on a single node. - ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + ActualProperties globalProperties = PropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(forceSingleNode || globalProperties.isSingleNode(), "Final aggregation with default value not separated from partial aggregation by remote hash exchange"); if (!seenExchanges.localRepartitionExchange) { // No local repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed by single thread. - StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, parser); + StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, types, typeAnalyzer); checkArgument(localProperties.isSingleStream(), "Final aggregation with default value not separated from partial aggregation by local hash exchange"); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index 56a8175897b21..ea0949bbc717c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -18,9 +18,9 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; @@ -85,7 +85,7 @@ public final class ValidateDependenciesChecker implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { validate(plan); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java index 3d2c37400bac0..bf0b963104130 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateStreamingAggregations.java @@ -20,8 +20,8 @@ import io.prestosql.metadata.Metadata; import io.prestosql.spi.connector.GroupingProperty; import io.prestosql.spi.connector.LocalProperty; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.optimizations.LocalProperties; import io.prestosql.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; @@ -44,9 +44,9 @@ public class ValidateStreamingAggregations implements Checker { @Override - public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata, sqlParser, types, warningCollector), null); + planNode.accept(new Visitor(session, metadata, typeAnalyzer, types, warningCollector), null); } private static final class Visitor @@ -54,15 +54,15 @@ private static final class Visitor { private final Session session; private final Metadata metadata; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; private final TypeProvider types; private final WarningCollector warningCollector; - private Visitor(Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + private Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { this.session = session; this.metadata = metadata; - this.sqlParser = sqlParser; + this.typeAnalyzer = typeAnalyzer; this.types = types; this.warningCollector = warningCollector; } @@ -81,7 +81,7 @@ public Void visitAggregation(AggregationNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, sqlParser); + StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, types, typeAnalyzer); List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedSymbols())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java index 80999b7024dd6..479e11487d8c4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyNoFilteredAggregations.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -27,7 +27,7 @@ public final class VerifyNoFilteredAggregations implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { searchFrom(plan) .where(AggregationNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java index db860491ae864..9c81a94d61a0e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/VerifyOnlyOneOutputNode.java @@ -16,7 +16,7 @@ import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; @@ -28,7 +28,7 @@ public final class VerifyOnlyOneOutputNode implements PlanSanityChecker.Checker { @Override - public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) + public void validate(PlanNode plan, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) { int outputPlanNodesCount = searchFrom(plan) .where(OutputNode.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index 2111a41c6b8d9..451add17a7dd7 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -136,6 +136,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.PlanOptimizers; import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; @@ -695,7 +696,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), Optional.empty(), pageSourceManager, indexManager, @@ -809,7 +810,7 @@ public List getPlanOptimizers(boolean forceSingleNode) { return new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, taskManagerConfig, forceSingleNode, @@ -847,7 +848,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final int TOTAL_POSITIONS = 1_000_000; private static final DataSize FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = new DataSize(500, KILOBYTE); @@ -232,8 +229,15 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + METADATA.getFunctionRegistry(), + METADATA.getTypeManager(), + TEST_SESSION, + true); } private static Page createPage(List types, int positions, boolean dictionary) diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java index d4b4b2ee136e6..f536c574d131f 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/FunctionAssertions.java @@ -65,6 +65,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.relational.RowExpression; @@ -127,14 +128,12 @@ import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; import static io.prestosql.sql.analyzer.ExpressionAnalyzer.analyzeExpressions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression; import static io.prestosql.sql.relational.Expressions.constant; import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static io.prestosql.type.UnknownType.UNKNOWN; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -211,6 +210,7 @@ public final class FunctionAssertions private final Session session; private final LocalQueryRunner runner; private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; private final ExpressionCompiler compiler; public FunctionAssertions() @@ -229,6 +229,7 @@ public FunctionAssertions(Session session, FeaturesConfig featuresConfig) runner = new LocalQueryRunner(session, featuresConfig); metadata = runner.getMetadata(); compiler = runner.getExpressionCompiler(); + typeAnalyzer = new TypeAnalyzer(SQL_PARSER, metadata); } public TypeRegistry getTypeRegistry() @@ -627,15 +628,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Session session, Expression projectionExpression) { - Map, Type> expressionTypes = getExpressionTypes( - session, - metadata, - SQL_PARSER, - TypeProvider.copyOf(INPUT_TYPES), - projectionExpression, - ImmutableList.of(), - WarningCollector.NOOP); - return toRowExpression(projectionExpression, expressionTypes, INPUT_MAPPING); + return toRowExpression(projectionExpression, typeAnalyzer.getTypes(session, TypeProvider.copyOf(INPUT_TYPES), projectionExpression), INPUT_MAPPING); } private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session) @@ -870,7 +863,7 @@ protected Void visitSymbolReference(SymbolReference node, Void context) private Object interpret(Expression expression, Type expectedType, Session session) { - Map, Type> expressionTypes = getExpressionTypes(session, metadata, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = typeAnalyzer.getTypes(session, SYMBOL_TYPES, expression); ExpressionInterpreter evaluator = ExpressionInterpreter.expressionInterpreter(expression, metadata, session, expressionTypes); Object result = evaluator.evaluate(symbol -> { diff --git a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java index 3a2810049625c..203d2900537af 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestExpressionInterpreter.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.scalar.FunctionAssertions; @@ -31,6 +30,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.ExpressionInterpreter; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; @@ -69,13 +69,11 @@ import static io.prestosql.sql.ExpressionFormatter.formatExpression; import static io.prestosql.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static io.prestosql.sql.ParsingUtil.createParsingOptions; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionInterpreter; import static io.prestosql.sql.planner.ExpressionInterpreter.expressionOptimizer; import static io.prestosql.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static io.prestosql.util.DateTimeZoneIndex.getDateTimeZone; import static java.lang.String.format; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; @@ -116,6 +114,7 @@ public class TestExpressionInterpreter private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(SQL_PARSER, METADATA); @Test public void testAnd() @@ -1454,7 +1453,7 @@ private static Object optimize(@Language("SQL") String expression) Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, parsedExpression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, parsedExpression); ExpressionInterpreter interpreter = expressionOptimizer(parsedExpression, METADATA, TEST_SESSION, expressionTypes); return interpreter.optimize(symbol -> { switch (symbol.getName().toLowerCase(ENGLISH)) { @@ -1511,7 +1510,7 @@ private static void assertRoundTrip(String expression) private static Object evaluate(Expression expression) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, SYMBOL_TYPES, expression); ExpressionInterpreter interpreter = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes); return interpreter.evaluate(); diff --git a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java index 459f335a3f35e..7c98ae1e8693d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/io/prestosql/sql/gen/PageProcessorBenchmark.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.SequencePageBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; @@ -30,6 +29,7 @@ import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; @@ -65,8 +65,6 @@ import static io.prestosql.operator.scalar.FunctionAssertions.createExpression; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.stream.Collectors.toList; @@ -80,8 +78,8 @@ public class PageProcessorBenchmark { private static final Map TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR); - private static final SqlParser SQL_PARSER = new SqlParser(); private static final Metadata METADATA = createTestMetadataManager(); + private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA); private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); private static final int POSITIONS = 1024; @@ -180,7 +178,7 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); + Map, Type> expressionTypes = TYPE_ANALYZER.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression); return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java index e3f7f760c4316..2d26b5d0b9532 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestTypeValidator.java @@ -21,6 +21,7 @@ import io.prestosql.connector.ConnectorId; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.FunctionKind; +import io.prestosql.metadata.MetadataManager; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; import io.prestosql.spi.connector.ColumnHandle; @@ -393,7 +394,8 @@ public void testInvalidUnion() private void assertTypesValid(PlanNode node) { - TYPE_VALIDATOR.validate(node, TEST_SESSION, createTestMetadataManager(), SQL_PARSER, symbolAllocator.getTypes(), WarningCollector.NOOP); + MetadataManager metadata = createTestMetadataManager(); + TYPE_VALIDATOR.validate(node, TEST_SESSION, metadata, new TypeAnalyzer(SQL_PARSER, metadata), symbolAllocator.getTypes(), WarningCollector.NOOP); } private static PlanNodeId newId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index b10f0add15bc3..0e12314ff29dd 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -26,6 +26,7 @@ import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -53,7 +54,7 @@ public class TestPushPredicateIntoTableScan @BeforeClass public void setUpBeforeClass() { - pushPredicateIntoTableScan = new PushPredicateIntoTableScan(tester().getMetadata(), new SqlParser()); + pushPredicateIntoTableScan = new PushPredicateIntoTableScan(tester().getMetadata(), new TypeAnalyzer(new SqlParser(), tester().getMetadata())); connectorId = tester().getCurrentConnectorId(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java index ae0cbaa33e3e2..28eb9868829db 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -20,6 +20,7 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; @@ -118,7 +119,7 @@ private static void assertSimplifies(String expression, String expected) { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected)); - Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, SQL_PARSER); + Expression rewritten = rewrite(actualExpression, TEST_SESSION, new SymbolAllocator(booleanSymbolTypeMapFor(actualExpression)), METADATA, LITERAL_ENCODER, new TypeAnalyzer(SQL_PARSER, METADATA)); assertEquals( normalize(rewritten), normalize(expectedExpression)); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java index 7b6c8f7212fce..b7cd186dc8c0c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java @@ -22,7 +22,7 @@ import io.prestosql.spi.Plugin; import io.prestosql.split.PageSourceManager; import io.prestosql.split.SplitManager; -import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.testing.LocalQueryRunner; import io.prestosql.transaction.TransactionManager; @@ -48,7 +48,7 @@ public class RuleTester private final SplitManager splitManager; private final PageSourceManager pageSourceManager; private final AccessControl accessControl; - private final SqlParser sqlParser; + private final TypeAnalyzer typeAnalyzer; public RuleTester() { @@ -91,7 +91,7 @@ public RuleTester(List plugins, Map sessionProperties, O this.splitManager = queryRunner.getSplitManager(); this.pageSourceManager = queryRunner.getPageSourceManager(); this.accessControl = queryRunner.getAccessControl(); - this.sqlParser = queryRunner.getSqlParser(); + this.typeAnalyzer = new TypeAnalyzer(queryRunner.getSqlParser(), metadata); } public RuleAssert assertThat(Rule rule) @@ -120,12 +120,9 @@ public PageSourceManager getPageSourceManager() return pageSourceManager; } - // TODO: this is only being used by rules that need to get the type of an expression - // In the short term, it should be encapsulated into something that knows how to provide types - // Rules should *not* need to use the parser otherwise. - public SqlParser getSqlParser() + public TypeAnalyzer getTypeAnalyzer() { - return sqlParser; + return typeAnalyzer; } public ConnectorId getCurrentConnectorId() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java index e659209b6b422..6cdc7591fbe68 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateSorts.java @@ -19,6 +19,7 @@ import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -89,7 +90,7 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new AddExchanges(getQueryRunner().getMetadata(), new SqlParser()), + new AddExchanges(getQueryRunner().getMetadata(), new TypeAnalyzer(new SqlParser(), getQueryRunner().getMetadata())), new PruneUnreferencedOutputs(), new IterativeOptimizer( new RuleStatsRecorder(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java index 1c9d617c91a17..a61aad2080c61 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestExpressionEquivalence.java @@ -21,6 +21,7 @@ import io.prestosql.sql.parser.ParsingOptions; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.tree.Expression; import org.intellij.lang.annotations.Language; @@ -41,7 +42,7 @@ public class TestExpressionEquivalence { private static final SqlParser SQL_PARSER = new SqlParser(); private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); - private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, SQL_PARSER); + private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, new TypeAnalyzer(SQL_PARSER, METADATA)); @Test public void testEquivalent() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java index dfaa88c1eb4d6..c4a46ae629d2e 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestReorderWindows.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.spi.block.SortOrder; import io.prestosql.sql.planner.RuleStatsRecorder; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.assertions.ExpectedValueProvider; import io.prestosql.sql.planner.assertions.PlanMatchPattern; @@ -322,7 +323,7 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter { List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new PredicatePushDown(getQueryRunner().getMetadata(), getQueryRunner().getSqlParser()), + new PredicatePushDown(getQueryRunner().getMetadata(), new TypeAnalyzer(getQueryRunner().getSqlParser(), getQueryRunner().getMetadata())), new IterativeOptimizer( new RuleStatsRecorder(), getQueryRunner().getStatsCalculator(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 0184499637f9f..2bfffc331b8e2 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -25,6 +25,7 @@ import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -48,8 +49,6 @@ public class TestValidateAggregationsWithDefaultValues extends BasePlanTest { - private static final SqlParser SQL_PARSER = new SqlParser(); - private Metadata metadata; private PlanBuilder builder; private Symbol symbol; @@ -192,7 +191,7 @@ private void validatePlan(PlanNode root, boolean forceSingleNode) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, SQL_PARSER, TypeProvider.empty(), WarningCollector.NOOP); + new ValidateAggregationsWithDefaultValues(forceSingleNode).validate(root, session, metadata, new TypeAnalyzer(new SqlParser(), metadata), TypeProvider.empty(), WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java index ae2576516a7a0..23e5e1684c3ae 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -22,8 +22,8 @@ import io.prestosql.plugin.tpch.TpchColumnHandle; import io.prestosql.plugin.tpch.TpchTableHandle; import io.prestosql.plugin.tpch.TpchTransactionHandle; -import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.planner.assertions.BasePlanTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -41,7 +41,7 @@ public class TestValidateStreamingAggregations extends BasePlanTest { private Metadata metadata; - private SqlParser sqlParser; + private TypeAnalyzer typeAnalyzer; private PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private TableHandle nationTableHandle; @@ -49,7 +49,7 @@ public class TestValidateStreamingAggregations public void setup() { metadata = getQueryRunner().getMetadata(); - sqlParser = getQueryRunner().getSqlParser(); + typeAnalyzer = new TypeAnalyzer(getQueryRunner().getSqlParser(), metadata); ConnectorId connectorId = getCurrentConnectorId(); nationTableHandle = new TableHandle( @@ -109,7 +109,7 @@ private void validatePlan(Function planProvider) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateStreamingAggregations().validate(planNode, session, metadata, sqlParser, types, WarningCollector.NOOP); + new ValidateStreamingAggregations().validate(planNode, session, metadata, typeAnalyzer, types, WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java index cac7e0d97a230..046898b61722e 100644 --- a/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/io/prestosql/type/BenchmarkDecimalOperators.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.prestosql.RowPagesBuilder; import io.prestosql.Session; -import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.DriverYieldSignal; import io.prestosql.operator.project.PageProcessor; @@ -30,11 +29,11 @@ import io.prestosql.sql.gen.PageFunctionCompiler; import io.prestosql.sql.parser.SqlParser; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.TypeProvider; import io.prestosql.sql.relational.RowExpression; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.NodeRef; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Measurement; @@ -70,12 +69,10 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DecimalType.createDecimalType; import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.math.BigInteger.ONE; import static java.math.BigInteger.ZERO; -import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; import static org.openjdk.jmh.annotations.Scope.Thread; @@ -546,6 +543,7 @@ private Object execute(BaseState state) private static class BaseState { private final MetadataManager metadata = createTestMetadataManager(); + private final TypeAnalyzer typeAnalyzer = new TypeAnalyzer(new SqlParser(), metadata); private final Session session = testSessionBuilder().build(); private final Random random = new Random(); @@ -613,8 +611,15 @@ private RowExpression rowExpression(String value) { Expression expression = createExpression(value, metadata, TypeProvider.copyOf(symbolTypes)); - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, metadata, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyList(), WarningCollector.NOOP); - return SqlToRowExpressionTranslator.translate(expression, SCALAR, expressionTypes, sourceLayout, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); + return SqlToRowExpressionTranslator.translate( + expression, + SCALAR, + typeAnalyzer.getTypes(TEST_SESSION, TypeProvider.copyOf(symbolTypes), expression), + sourceLayout, + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + TEST_SESSION, + true); } private Object generateRandomValue(Type type) diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java index a11c3c7abedc4..45d55e5bb99a2 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueryFramework.java @@ -34,6 +34,7 @@ import io.prestosql.sql.planner.Plan; import io.prestosql.sql.planner.PlanFragmenter; import io.prestosql.sql.planner.PlanOptimizers; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.tree.ExplainType; import io.prestosql.testing.MaterializedResult; @@ -345,7 +346,7 @@ private QueryExplainer getQueryExplainer() CostCalculator costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); List optimizers = new PlanOptimizers( metadata, - sqlParser, + new TypeAnalyzer(sqlParser, metadata), featuresConfig, new TaskManagerConfig(), forceSingleNode,