Skip to content

Commit

Permalink
Require NodeLocation in syntax tree
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Oct 2, 2024
1 parent 2ce6829 commit f63ecb5
Show file tree
Hide file tree
Showing 220 changed files with 559 additions and 773 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ private void analyzeWindow(ResolvedWindow window, Context context, Node original

// validate frame start and end types
FrameBound.Type startType = frame.getStart().getType();
FrameBound.Type endType = frame.getEnd().orElse(new FrameBound(CURRENT_ROW)).getType();
FrameBound.Type endType = frame.getEnd().map(FrameBound::getType).orElse(CURRENT_ROW);
if (startType == UNBOUNDED_FOLLOWING) {
throw semanticException(INVALID_WINDOW_FRAME, frame, "Window frame start cannot be UNBOUNDED FOLLOWING");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1326,15 +1326,15 @@ private List<Property> processTableExecuteArguments(TableExecute node, TableProc
if (!names.add(name.getCanonicalValue())) {
throw semanticException(DUPLICATE_PROPERTY, argument, "Duplicate named argument: %s", name);
}
properties.add(new Property(argument.getLocation(), name, argument.getValue()));
properties.add(new Property(argument.getLocation().orElseThrow(), name, argument.getValue()));
}
}
else {
// all properties unnamed
int pos = 0;
for (CallArgument argument : arguments) {
Identifier name = new Identifier(procedureMetadata.getProperties().get(pos).getName());
properties.add(new Property(argument.getLocation(), name, argument.getValue()));
properties.add(new Property(argument.getLocation().orElseThrow(), name, argument.getValue()));
pos++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public Visitor(
protected Node visitExplainAnalyze(ExplainAnalyze node, Void context)
{
Statement statement = (Statement) process(node.getStatement(), context);
return new ExplainAnalyze(node.getLocation(), statement, node.isVerbose());
return new ExplainAnalyze(node.getLocation().orElseThrow(), statement, node.isVerbose());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ public Visitor(Session session)
protected Node visitExplain(Explain node, Void context)
{
Statement statement = (Statement) process(node.getStatement(), null);
return new Explain(node.getLocation(), statement, node.getOptions());
return new Explain(node.getLocation().orElseThrow(), statement, node.getOptions());
}

@Override
protected Node visitExplainAnalyze(ExplainAnalyze node, Void context)
{
Statement statement = (Statement) process(node.getStatement(), null);
return new ExplainAnalyze(node.getLocation(), statement, node.isVerbose());
return new ExplainAnalyze(node.getLocation().orElseThrow(), statement, node.isVerbose());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ public void testInsertWithColumnMasking()
.expression("clerk")
.build());
assertThat(assertions.query("INSERT INTO orders SELECT * FROM orders"))
.failure().hasMessage("Insert into table with column masks is not supported");
.failure().hasMessage("line 1:1: Insert into table with column masks is not supported");
}

@Test
Expand Down Expand Up @@ -906,7 +906,7 @@ public void testColumnMaskWithHiddenColumns()
.skippingTypesCheck()
.matches("VALUES 'POLAND'");
assertThat(assertions.query("INSERT INTO mock.tiny.nation_with_hidden_column SELECT * FROM mock.tiny.nation_with_hidden_column"))
.failure().hasMessage("Insert into table with column masks is not supported");
.failure().hasMessage("line 1:1: Insert into table with column masks is not supported");
assertThat(assertions.query("DELETE FROM mock.tiny.nation_with_hidden_column"))
.failure().hasMessage("line 1:1: Delete from table with column mask");
assertThat(assertions.query("UPDATE mock.tiny.nation_with_hidden_column SET name = 'X'"))
Expand Down
7 changes: 0 additions & 7 deletions core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.Relation;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.Select;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SingleColumn;
Expand All @@ -42,7 +41,6 @@
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableSubquery;
import io.trino.sql.tree.Values;
import io.trino.sql.tree.WhenClause;
import io.trino.sql.tree.WindowDefinition;

import java.util.List;
Expand Down Expand Up @@ -131,11 +129,6 @@ public static Expression equal(Expression left, Expression right)
return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, left, right);
}

public static Expression caseWhen(Expression operand, Expression result)
{
return new SearchedCaseExpression(ImmutableList.of(new WhenClause(operand, result)), Optional.empty());
}

public static Expression functionCall(String name, Expression... arguments)
{
return new FunctionCall(QualifiedName.of(name), ImmutableList.copyOf(arguments));
Expand Down
37 changes: 19 additions & 18 deletions core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ public Node visitInsertInto(SqlBaseParser.InsertIntoContext context)
}

return new Insert(
getLocation(context),
new Table(getQualifiedName(context.qualifiedName())),
columnAliases,
(Query) visit(context.rootQuery()));
Expand Down Expand Up @@ -1164,7 +1165,7 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context)
rowCount = new Parameter(getLocation(context.offset.QUESTION_MARK()), parameterPosition);
parameterPosition++;
}
offset = Optional.of(new Offset(Optional.of(getLocation(context.OFFSET())), rowCount));
offset = Optional.of(new Offset(getLocation(context.OFFSET()), rowCount));
}

Optional<Node> limit = Optional.empty();
Expand All @@ -1179,7 +1180,7 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context)
parameterPosition++;
}
}
limit = Optional.of(new FetchFirst(Optional.of(getLocation(context.FETCH())), rowCount, context.TIES() != null));
limit = Optional.of(new FetchFirst(getLocation(context.FETCH()), rowCount, context.TIES() != null));
}
else if (context.LIMIT() != null) {
if (context.limit == null) {
Expand All @@ -1197,7 +1198,7 @@ else if (context.limit.rowCount().INTEGER_VALUE() != null) {
parameterPosition++;
}

limit = Optional.of(new Limit(Optional.of(getLocation(context.LIMIT())), rowCount));
limit = Optional.of(new Limit(getLocation(context.LIMIT()), rowCount));
}

if (term instanceof QuerySpecification query) {
Expand Down Expand Up @@ -1397,7 +1398,7 @@ public Node visitExplain(SqlBaseParser.ExplainContext context)
@Override
public Node visitExplainAnalyze(SqlBaseParser.ExplainAnalyzeContext context)
{
return new ExplainAnalyze(getLocation(context), context.VERBOSE() != null, (Statement) visit(context.statement()));
return new ExplainAnalyze(getLocation(context), (Statement) visit(context.statement()), context.VERBOSE() != null);
}

@Override
Expand Down Expand Up @@ -1464,14 +1465,14 @@ public Node visitShowColumns(SqlBaseParser.ShowColumnsContext context)
@Override
public Node visitShowStats(SqlBaseParser.ShowStatsContext context)
{
return new ShowStats(Optional.of(getLocation(context)), new Table(getQualifiedName(context.qualifiedName())));
return new ShowStats(getLocation(context), new Table(getQualifiedName(context.qualifiedName())));
}

@Override
public Node visitShowStatsForQuery(SqlBaseParser.ShowStatsForQueryContext context)
{
Query query = (Query) visit(context.rootQuery());
return new ShowStats(Optional.of(getLocation(context)), new TableSubquery(query));
return new ShowStats(getLocation(context), new TableSubquery(query));
}

@Override
Expand Down Expand Up @@ -2058,7 +2059,7 @@ public Node visitTableArgument(SqlBaseParser.TableArgumentContext context)

Optional<OrderBy> orderBy = Optional.empty();
if (context.ORDER() != null) {
orderBy = Optional.of(new OrderBy(visit(context.sortItem(), SortItem.class)));
orderBy = Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class)));
}

Optional<EmptyTableTreatment> emptyTableTreatment = Optional.empty();
Expand Down Expand Up @@ -2444,7 +2445,7 @@ public Node visitExtract(SqlBaseParser.ExtractContext context)
public Node visitListagg(SqlBaseParser.ListaggContext context)
{
Optional<Window> window = Optional.empty();
OrderBy orderBy = new OrderBy(visit(context.sortItem(), SortItem.class));
OrderBy orderBy = new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class));
boolean distinct = isDistinct(context.setQuantifier());

Expression expression = (Expression) visit(context.expression());
Expand Down Expand Up @@ -2537,7 +2538,7 @@ else if (errorBehaviorContext.ERROR() != null) {
}

return new JsonExists(
Optional.of(getLocation(context)),
getLocation(context),
jsonPathInvocation,
errorBehavior);
}
Expand Down Expand Up @@ -2584,7 +2585,7 @@ else if (errorBehaviorContext.DEFAULT() != null) {
}

return new JsonValue(
Optional.of(getLocation(context)),
getLocation(context),
jsonPathInvocation,
returnedType,
emptyBehavior,
Expand Down Expand Up @@ -2662,7 +2663,7 @@ else if (errorBehaviorContext.OBJECT() != null) {
}

return new JsonQuery(
Optional.of(getLocation(context)),
getLocation(context),
jsonPathInvocation,
returnedType,
jsonOutputFormat,
Expand All @@ -2689,7 +2690,7 @@ public Node visitJsonPathInvocation(SqlBaseParser.JsonPathInvocationContext cont
Optional<Identifier> pathName = visitIfPresent(context.pathName, Identifier.class);
List<JsonPathParameter> pathParameters = visit(context.jsonArgument(), JsonPathParameter.class);

return new JsonPathInvocation(Optional.of(getLocation(context)), jsonInput, inputFormat, jsonPath, pathName, pathParameters);
return new JsonPathInvocation(getLocation(context), jsonInput, inputFormat, jsonPath, pathName, pathParameters);
}

private static JsonFormat getJsonFormat(SqlBaseParser.JsonRepresentationContext context)
Expand All @@ -2710,7 +2711,7 @@ private static JsonFormat getJsonFormat(SqlBaseParser.JsonRepresentationContext
public Node visitJsonArgument(SqlBaseParser.JsonArgumentContext context)
{
return new JsonPathParameter(
Optional.of(getLocation(context)),
getLocation(context),
(Identifier) visit(context.identifier()),
(Expression) visit(context.jsonValueExpression().expression()),
Optional.ofNullable(context.jsonValueExpression().jsonRepresentation())
Expand All @@ -2732,7 +2733,7 @@ public Node visitJsonObject(SqlBaseParser.JsonObjectContext context)
jsonOutputFormat = Optional.of(getJsonFormat(context.jsonRepresentation()));
}

return new JsonObject(Optional.of(getLocation(context)), members, nullOnNull, uniqueKeys, returnedType, jsonOutputFormat);
return new JsonObject(getLocation(context), members, nullOnNull, uniqueKeys, returnedType, jsonOutputFormat);
}

@Override
Expand Down Expand Up @@ -2760,7 +2761,7 @@ public Node visitJsonArray(SqlBaseParser.JsonArrayContext context)
jsonOutputFormat = Optional.of(getJsonFormat(context.jsonRepresentation()));
}

return new JsonArray(Optional.of(getLocation(context)), elements, nullOnNull, returnedType, jsonOutputFormat);
return new JsonArray(getLocation(context), elements, nullOnNull, returnedType, jsonOutputFormat);
}

@Override
Expand Down Expand Up @@ -2857,7 +2858,7 @@ public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context)

Optional<OrderBy> orderBy = Optional.empty();
if (context.ORDER() != null) {
orderBy = Optional.of(new OrderBy(visit(context.sortItem(), SortItem.class)));
orderBy = Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class)));
}

QualifiedName name = getQualifiedName(context.qualifiedName());
Expand Down Expand Up @@ -2982,7 +2983,7 @@ public Node visitMeasure(SqlBaseParser.MeasureContext context)
public Node visitLambda(SqlBaseParser.LambdaContext context)
{
List<LambdaArgumentDeclaration> arguments = visit(context.identifier(), Identifier.class).stream()
.map(LambdaArgumentDeclaration::new)
.map(argument -> new LambdaArgumentDeclaration(argument.getLocation().orElseThrow(), argument))
.collect(toList());

Expression body = (Expression) visit(context.expression());
Expand Down Expand Up @@ -3103,7 +3104,7 @@ public Node visitGroupingOperation(SqlBaseParser.GroupingOperationContext contex
.map(this::getQualifiedName)
.collect(toList());

return new GroupingOperation(Optional.of(getLocation(context)), arguments);
return new GroupingOperation(getLocation(context), arguments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;
Expand All @@ -32,7 +31,7 @@ public class AddColumn

public AddColumn(NodeLocation location, QualifiedName name, ColumnDefinition column, boolean tableExists, boolean columnNotExists)
{
super(Optional.of(location));
super(location);
this.name = requireNonNull(name, "name is null");
this.column = requireNonNull(column, "column is null");
this.tableExists = tableExists;
Expand Down
21 changes: 13 additions & 8 deletions core/trino-parser/src/main/java/io/trino/sql/tree/AllColumns.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,35 @@ public class AllColumns
private final List<Identifier> aliases;
private final Optional<Expression> target;

@Deprecated
public AllColumns()
{
this(Optional.empty(), Optional.empty(), ImmutableList.of());
}

public AllColumns(Expression target)
@Deprecated
public AllColumns(Expression target, List<Identifier> aliases)
{
this(Optional.empty(), Optional.of(target), ImmutableList.of());
this(Optional.empty(), Optional.of(target), aliases);
}

public AllColumns(Expression target, List<Identifier> aliases)
@Deprecated
public AllColumns(Optional<NodeLocation> location, Optional<Expression> target, List<Identifier> aliases)
{
this(Optional.empty(), Optional.of(target), aliases);
super(location);
this.aliases = ImmutableList.copyOf(requireNonNull(aliases, "aliases is null"));
this.target = requireNonNull(target, "target is null");
}

public AllColumns(NodeLocation location, Optional<Expression> target, List<Identifier> aliases)
public AllColumns(NodeLocation location)
{
this(Optional.of(location), target, aliases);
this(location, Optional.empty(), ImmutableList.of());
}

public AllColumns(Optional<NodeLocation> location, Optional<Expression> target, List<Identifier> aliases)
public AllColumns(NodeLocation location, Optional<Expression> target, List<Identifier> aliases)
{
super(location);
this.aliases = ImmutableList.copyOf(requireNonNull(aliases, "aliases is null"));
this.aliases = ImmutableList.copyOf(aliases);
this.target = requireNonNull(target, "target is null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@
public class AllRows
extends Expression
{
@Deprecated
public AllRows()
{
this(Optional.empty());
super(Optional.empty());
}

public AllRows(NodeLocation location)
{
this(Optional.of(location));
}

public AllRows(Optional<NodeLocation> location)
{
super(location);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;
Expand All @@ -30,7 +29,7 @@ public class Analyze

public Analyze(NodeLocation location, QualifiedName tableName, List<Property> properties)
{
super(Optional.of(location));
super(location);
this.tableName = requireNonNull(tableName, "tableName is null");
this.properties = ImmutableList.copyOf(requireNonNull(properties, "properties is null"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;
Expand All @@ -35,7 +34,7 @@ public enum Type

public AnchorPattern(NodeLocation location, Type type)
{
super(Optional.of(location));
super(location);
this.type = requireNonNull(type, "type is null");
}

Expand Down
Loading

0 comments on commit f63ecb5

Please sign in to comment.