Skip to content

Commit

Permalink
Encapsulate expression type analysis in planner
Browse files Browse the repository at this point in the history
The new class is to facilitate obtaining the type of an expression and its subexpressions
during planning (i.e., when interacting with IR expression) and to remove spurious
dependencies on the SQL parser. It will eventually get removed when we split the AST
from the IR and we encode the type directly into IR expressions.
  • Loading branch information
martint committed Mar 11, 2019
1 parent 35dd636 commit f04adeb
Show file tree
Hide file tree
Showing 52 changed files with 330 additions and 409 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.prestosql.execution.Lifespan;
import io.prestosql.execution.TaskId;
import io.prestosql.execution.TaskStateMachine;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.memory.MemoryPool;
import io.prestosql.memory.QueryContext;
import io.prestosql.metadata.Metadata;
Expand Down Expand Up @@ -52,6 +51,7 @@
import io.prestosql.split.SplitSource;
import io.prestosql.sql.gen.PageFunctionCompiler;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer;
import io.prestosql.sql.planner.plan.PlanNodeId;
Expand Down Expand Up @@ -83,7 +83,6 @@
import static io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING;
import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.sql.relational.SqlToRowExpressionTranslator.translate;
import static io.prestosql.testing.TestingSession.testSessionBuilder;
import static java.lang.String.format;
Expand Down Expand Up @@ -229,14 +228,9 @@ protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNo
Optional<Expression> hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet()));
verify(hashExpression.isPresent());

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
session,
localQueryRunner.getMetadata(),
localQueryRunner.getSqlParser(),
TypeProvider.copyOf(symbolTypes.build()),
hashExpression.get(),
ImmutableList.of(),
WarningCollector.NOOP);
Map<NodeRef<Expression>, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata())
.getTypes(session, TypeProvider.copyOf(symbolTypes.build()), hashExpression.get());

RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionRegistry(), localQueryRunner.getTypeManager(), session, false);

PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,6 @@ public void testPushDownAnd()
private RuleAssert assertRuleApplication()
{
RuleTester tester = tester();
return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser()));
return tester.assertThat(new ExtractSpatialInnerJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,6 @@ public void testPushDownAnd()
private RuleAssert assertRuleApplication()
{
RuleTester tester = tester();
return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getSqlParser()));
return tester().assertThat(new ExtractSpatialLeftJoin(tester.getMetadata(), tester.getSplitManager(), tester.getPageSourceManager(), tester.getTypeAnalyzer()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import io.prestosql.sql.planner.PlanOptimizers;
import io.prestosql.sql.planner.StageExecutionPlan;
import io.prestosql.sql.planner.SubPlan;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.tree.Explain;
import io.prestosql.transaction.TransactionManager;
Expand Down Expand Up @@ -414,7 +415,7 @@ private PlanRoot doAnalyzeQuery()

// plan query
PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, stateMachine.getWarningCollector());
LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, new TypeAnalyzer(sqlParser, metadata), statsCalculator, costCalculator, stateMachine.getWarningCollector());
Plan plan = logicalPlanner.plan(analysis);
queryPlan.set(plan);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
import io.prestosql.sql.planner.CompilerConfig;
import io.prestosql.sql.planner.LocalExecutionPlanner;
import io.prestosql.sql.planner.NodePartitioningManager;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.transaction.TransactionManagerConfig;
Expand Down Expand Up @@ -354,6 +355,7 @@ protected void setup(Binder binder)
binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON);

// type
binder.bind(TypeAnalyzer.class).in(Scopes.SINGLETON);
binder.bind(TypeRegistry.class).in(Scopes.SINGLETON);
binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON);
jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1421,44 +1421,6 @@ public static Signature resolveFunction(FunctionCall node, List<TypeSignaturePro
}
}

public static Map<NodeRef<Expression>, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
TypeProvider types,
Expression expression,
List<Expression> parameters,
WarningCollector warningCollector)
{
return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, warningCollector, false);
}

public static Map<NodeRef<Expression>, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
TypeProvider types,
Expression expression,
List<Expression> parameters,
WarningCollector warningCollector,
boolean isDescribe)
{
return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, warningCollector, isDescribe);
}

public static Map<NodeRef<Expression>, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
TypeProvider types,
Iterable<Expression> expressions,
List<Expression> parameters,
WarningCollector warningCollector,
boolean isDescribe)
{
return analyzeExpressions(session, metadata, sqlParser, types, expressions, parameters, warningCollector, isDescribe).getExpressionTypes();
}

public static ExpressionAnalysis analyzeExpressions(
Session session,
Metadata metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.PlanOptimizers;
import io.prestosql.sql.planner.SubPlan;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.planPrinter.IoPlanPrinter;
import io.prestosql.sql.planner.planPrinter.PlanPrinter;
Expand Down Expand Up @@ -175,7 +176,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List<Expression
PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();

// plan statement
LogicalPlanner logicalPlanner = new LogicalPlanner(session, planOptimizers, idAllocator, metadata, sqlParser, statsCalculator, costCalculator, warningCollector);
LogicalPlanner logicalPlanner = new LogicalPlanner(session, planOptimizers, idAllocator, metadata, new TypeAnalyzer(sqlParser, metadata), statsCalculator, costCalculator, warningCollector);
return logicalPlanner.plan(analysis);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@
import static io.prestosql.sql.analyzer.AggregationAnalyzer.verifyOrderByAggregations;
import static io.prestosql.sql.analyzer.AggregationAnalyzer.verifySourceAggregations;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static io.prestosql.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions;
import static io.prestosql.sql.analyzer.ExpressionTreeUtils.extractExpressions;
import static io.prestosql.sql.analyzer.ExpressionTreeUtils.extractWindowFunctions;
Expand Down Expand Up @@ -953,15 +952,17 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional<Scope> s
throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references");
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions(
session,
metadata,
sqlParser,
TypeProvider.empty(),
relation.getSamplePercentage(),
ImmutableList.of(relation.getSamplePercentage()),
analysis.getParameters(),
WarningCollector.NOOP,
analysis.isDescribe());
analysis.isDescribe())
.getExpressionTypes();

ExpressionInterpreter samplePercentageEval = expressionOptimizer(relation.getSamplePercentage(), metadata, session, expressionTypes);

Object samplePercentageObject = samplePercentageEval.optimize(symbol -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.tree.AtTimeZone;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.Expression;
Expand All @@ -36,8 +34,6 @@
import static io.prestosql.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static io.prestosql.spi.type.TimestampType.TIMESTAMP;
import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

public class DesugarAtTimeZoneRewriter
Expand All @@ -49,15 +45,15 @@ public static Expression rewrite(Expression expression, Map<NodeRef<Expression>,

private DesugarAtTimeZoneRewriter() {}

public static Expression rewrite(Expression expression, Session session, Metadata metadata, SqlParser sqlParser, SymbolAllocator symbolAllocator)
public static Expression rewrite(Expression expression, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator)
{
requireNonNull(metadata, "metadata is null");
requireNonNull(sqlParser, "sqlParser is null");
requireNonNull(typeAnalyzer, "typeAnalyzer is null");

if (expression instanceof SymbolReference) {
return expression;
}
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP);
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression);

return rewrite(expression, expressionTypes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.PeekingIterator;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.block.Block;
Expand All @@ -34,7 +33,6 @@
import io.prestosql.spi.type.Type;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.analyzer.ExpressionAnalyzer;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.tree.AstVisitor;
import io.prestosql.sql.tree.BetweenPredicate;
Expand Down Expand Up @@ -78,7 +76,6 @@
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static java.util.Collections.emptyList;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
Expand Down Expand Up @@ -277,7 +274,7 @@ public static ExtractionResult fromPredicate(
Expression predicate,
TypeProvider types)
{
return new Visitor(metadata, session, types).process(predicate, false);
return new Visitor(metadata, session, types, new TypeAnalyzer(new SqlParser(), metadata)).process(predicate, false);
}

private static class Visitor
Expand All @@ -288,14 +285,16 @@ private static class Visitor
private final Session session;
private final TypeProvider types;
private final InterpretedFunctionInvoker functionInvoker;
private final TypeAnalyzer typeAnalyzer;

private Visitor(Metadata metadata, Session session, TypeProvider types)
private Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde());
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
this.functionInvoker = new InterpretedFunctionInvoker(metadata.getFunctionRegistry());
this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
}

private Type checkedTypeLookup(Symbol symbol)
Expand Down Expand Up @@ -424,7 +423,7 @@ else if (symbolExpression instanceof Cast) {
return super.visitComparisonExpression(node, complement);
}

Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value
Type castSourceType = typeAnalyzer.getType(session, types, castExpression.getExpression()); // type of expression which is then cast to type of value

// we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side
Optional<Expression> coercedExpression = coerceComparisonWithRounding(
Expand Down Expand Up @@ -489,7 +488,7 @@ private boolean isImplicitCoercion(Cast cast)

private Map<NodeRef<Expression>, Type> analyzeExpression(Expression expression)
{
return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP);
return typeAnalyzer.getTypes(session, types, expression);
}

private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement)
Expand Down Expand Up @@ -757,12 +756,6 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement
}
}

private static Type typeOf(Expression expression, Session session, Metadata metadata, TypeProvider types)
{
Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP);
return expressionTypes.get(NodeRef.of(expression));
}

private static class NormalizedSimpleComparison
{
private final Expression symbolExpression;
Expand Down
Loading

0 comments on commit f04adeb

Please sign in to comment.