Skip to content

Commit

Permalink
[CALCITE-6071] RexCall should carry source position information for r…
Browse files Browse the repository at this point in the history
…untime error reporting

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Sep 20, 2024
1 parent 204ae1f commit be044ff
Show file tree
Hide file tree
Showing 30 changed files with 838 additions and 396 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4781,7 +4781,7 @@ private static class QuantifyCollectionImplementor extends AbstractRexCallImplem
final ParameterExpression lambdaArg =
Expressions.parameter(translator.typeFactory.getJavaClass(rightComponentType), "el");
final RexCall binaryImplementorRexCall =
(RexCall) translator.builder.makeCall(binaryOperator, leftRex,
(RexCall) translator.builder.makeCall(call.getParserPosition(), binaryOperator, leftRex,
translator.builder.makeDynamicParam(rightComponentType, 0));
final List<RexToLixTranslator.Result> binaryImplementorArgs =
ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ private static RexNode canonizeNode(RexBuilder rexBuilder, RexNode condition) {
if (newOperands.size() < 2) {
return newOperands.values().iterator().next();
}
return rexBuilder.makeCall(call.getOperator(),
return rexBuilder.makeCall(call.getParserPosition(), call.getOperator(),
ImmutableList.copyOf(newOperands.values()));
}
case EQUALS:
Expand All @@ -362,7 +362,8 @@ private static RexNode canonizeNode(RexBuilder rexBuilder, RexNode condition) {
RexCall call = (RexCall) condition;
RexNode left = canonizeNode(rexBuilder, call.getOperands().get(0));
RexNode right = canonizeNode(rexBuilder, call.getOperands().get(1));
call = (RexCall) rexBuilder.makeCall(call.getOperator(), left, right);
call =
(RexCall) rexBuilder.makeCall(call.getParserPosition(), call.getOperator(), left, right);

if (left.toString().compareTo(right.toString()) <= 0) {
return call;
Expand All @@ -384,10 +385,11 @@ private static RexNode canonizeNode(RexBuilder rexBuilder, RexNode condition) {
RexNode right = canonizeNode(rexBuilder, call.getOperands().get(1));

if (left.toString().compareTo(right.toString()) <= 0) {
return rexBuilder.makeCall(call.getOperator(), left, right);
return rexBuilder.makeCall(call.getParserPosition(), call.getOperator(), left, right);
}

RexNode newCall = rexBuilder.makeCall(call.getOperator(), right, left);
RexNode newCall =
rexBuilder.makeCall(call.getParserPosition(), call.getOperator(), right, left);
// new call should not be used if its inferred type is not same as old
if (!newCall.getType().equals(call.getType())) {
return call;
Expand Down Expand Up @@ -1973,7 +1975,8 @@ public static MutableAggregate permute(MutableAggregate aggregate,
final SqlAggFunction aggFunction = aggregateCall.getAggregation().getRollup();
if (aggFunction != null) {
newAggCall =
AggregateCall.create(aggFunction, aggregateCall.isDistinct(),
AggregateCall.create(aggregateCall.getParserPosition(),
aggFunction, aggregateCall.isDistinct(),
aggregateCall.isApproximate(), aggregateCall.ignoreNulls(),
aggregateCall.rexList,
ImmutableList.of(target.groupSet.cardinality() + i), -1,
Expand Down Expand Up @@ -2044,7 +2047,7 @@ public static MutableAggregate permute(MutableAggregate aggregate,
if (!isAllowBuild) {
return null;
}
return AggregateCall.create(aggregation,
return AggregateCall.create(queryAggCall.getParserPosition(), aggregation,
queryAggCall.isDistinct(), queryAggCall.isApproximate(),
queryAggCall.ignoreNulls(), queryAggCall.rexList,
newArgList, -1, queryAggCall.distinctKeys,
Expand Down
75 changes: 57 additions & 18 deletions core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Optionality;
Expand All @@ -50,6 +51,14 @@
public class AggregateCall {
//~ Instance fields --------------------------------------------------------

/**
* Some aggregate calls may produce runtime errors. For these
* we need to keep around the original source position information
* so that the runtime can produce error messages pointing to
* the offending source operation. For "safe" aggregations
* this field may be ZERO.
*/
private final SqlParserPos pos;
private final SqlAggFunction aggFunction;

private final boolean distinct;
Expand Down Expand Up @@ -84,14 +93,17 @@ public AggregateCall(
List<Integer> argList,
RelDataType type,
String name) {
this(aggFunction, distinct, false, false,
this(SqlParserPos.ZERO, aggFunction, distinct, false, false,
ImmutableList.of(), argList, -1, null,
RelCollations.EMPTY, type, name);
}

/**
* Creates an AggregateCall.
*
* @param pos Source position for this aggregate.
* Ideally it should only be ZERO when the aggregate
* can never fail at runtime.
* @param aggFunction Aggregate function
* @param distinct Whether distinct
* @param approximate Whether approximate
Expand All @@ -106,11 +118,12 @@ public AggregateCall(
* @param type Result type
* @param name Name (may be null)
*/
private AggregateCall(SqlAggFunction aggFunction, boolean distinct,
private AggregateCall(SqlParserPos pos, SqlAggFunction aggFunction, boolean distinct,
boolean approximate, boolean ignoreNulls,
List<RexNode> rexList, List<Integer> argList,
int filterArg, @Nullable ImmutableBitSet distinctKeys,
RelCollation collation, RelDataType type, @Nullable String name) {
this.pos = pos;
this.type = requireNonNull(type, "type");
this.name = name;
this.aggFunction = requireNonNull(aggFunction, "aggFunction");
Expand Down Expand Up @@ -187,6 +200,17 @@ public static AggregateCall create(SqlAggFunction aggFunction,
@Nullable ImmutableBitSet distinctKeys, RelCollation collation,
int groupCount,
RelNode input, @Nullable RelDataType type, @Nullable String name) {
return create(SqlParserPos.ZERO, aggFunction, distinct, approximate,
ignoreNulls, rexList, argList, filterArg, distinctKeys, collation, groupCount,
input, type, name);
}

public static AggregateCall create(SqlParserPos pos, SqlAggFunction aggFunction,
boolean distinct, boolean approximate, boolean ignoreNulls,
List<RexNode> rexList, List<Integer> argList, int filterArg,
@Nullable ImmutableBitSet distinctKeys, RelCollation collation,
int groupCount,
RelNode input, @Nullable RelDataType type, @Nullable String name) {
if (type == null) {
final RelDataTypeFactory typeFactory =
input.getCluster().getTypeFactory();
Expand All @@ -198,7 +222,7 @@ public static AggregateCall create(SqlAggFunction aggFunction,
types, groupCount, filterArg >= 0);
type = aggFunction.inferReturnType(callBinding);
}
return create(aggFunction, distinct, approximate, ignoreNulls,
return create(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand Down Expand Up @@ -245,9 +269,18 @@ public static AggregateCall create(SqlAggFunction aggFunction,
List<RexNode> rexList, List<Integer> argList, int filterArg,
@Nullable ImmutableBitSet distinctKeys, RelCollation collation,
RelDataType type, @Nullable String name) {
return create(SqlParserPos.ZERO, aggFunction, distinct, approximate,
ignoreNulls, rexList, argList, filterArg, distinctKeys, collation, type, name);
}

public static AggregateCall create(SqlParserPos pos, SqlAggFunction aggFunction,
boolean distinct, boolean approximate, boolean ignoreNulls,
List<RexNode> rexList, List<Integer> argList, int filterArg,
@Nullable ImmutableBitSet distinctKeys, RelCollation collation,
RelDataType type, @Nullable String name) {
final boolean distinct2 = distinct
&& (aggFunction.getDistinctOptionality() != Optionality.IGNORED);
return new AggregateCall(aggFunction, distinct2, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct2, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -264,7 +297,7 @@ public final boolean isDistinct() {
/** Withs {@link #isDistinct()}. */
public AggregateCall withDistinct(boolean distinct) {
return distinct == this.distinct ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -281,7 +314,7 @@ public final boolean isApproximate() {
/** Withs {@link #isApproximate()}. */
public AggregateCall withApproximate(boolean approximate) {
return approximate == this.approximate ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -297,7 +330,7 @@ public final boolean ignoreNulls() {
/** Withs {@link #ignoreNulls()}. */
public AggregateCall withIgnoreNulls(boolean ignoreNulls) {
return ignoreNulls == this.ignoreNulls ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -323,7 +356,7 @@ public RelCollation getCollation() {
/** Withs {@link #getCollation()}. */
public AggregateCall withCollation(RelCollation collation) {
return collation.equals(this.collation) ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -341,15 +374,15 @@ public final List<Integer> getArgList() {
/** Withs {@link #getArgList()}. */
public AggregateCall withArgList(List<Integer> argList) {
return argList.equals(this.argList) ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

/** Withs {@link #distinctKeys}. */
public AggregateCall withDistinctKeys(
@Nullable ImmutableBitSet distinctKeys) {
return Objects.equals(distinctKeys, this.distinctKeys) ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -374,7 +407,7 @@ public final RelDataType getType() {
/** Withs {@link #name}. */
public AggregateCall withName(@Nullable String name) {
return Objects.equals(name, this.name) ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand Down Expand Up @@ -435,11 +468,16 @@ public boolean hasFilter() {
/** Withs {@link #filterArg}. */
public AggregateCall withFilter(int filterArg) {
return filterArg == this.filterArg ? this
: new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
: new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

public SqlParserPos getParserPosition() {
return this.pos;
}

@Override public boolean equals(@Nullable Object o) {
// Intentionally ignore the position
return o == this
|| o instanceof AggregateCall
&& aggFunction.equals(((AggregateCall) o).aggFunction)
Expand All @@ -453,6 +491,7 @@ public AggregateCall withFilter(int filterArg) {
}

@Override public int hashCode() {
// Ignore the position!
return Objects.hash(aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation);
}
Expand Down Expand Up @@ -492,29 +531,29 @@ public Aggregate.AggCallBinding createBinding(
@Deprecated // to be removed before 2.0
public AggregateCall copy(List<Integer> argList, int filterArg,
@Nullable ImmutableBitSet distinctKeys, RelCollation collation) {
return new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

@Deprecated // to be removed before 2.0
public AggregateCall copy(List<Integer> argList, int filterArg,
RelCollation collation) {
// ignoring distinctKeys is error-prone
return new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

@Deprecated // to be removed before 2.0
public AggregateCall copy(List<Integer> argList, int filterArg) {
// ignoring distinctKeys, collation is error-prone
return new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

@Deprecated // to be removed before 2.0
public AggregateCall copy(List<Integer> argList) {
// ignoring filterArg, distinctKeys, collation is error-prone
return new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation, type, name);
}

Expand All @@ -539,15 +578,15 @@ public AggregateCall adaptTo(RelNode input, List<Integer> argList,
&& filterArg == this.filterArg
? type
: null;
return create(aggFunction, distinct, approximate, ignoreNulls,
return create(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, argList, filterArg, distinctKeys, collation,
newGroupKeyCount, input, newType, getName());
}

/** Creates a copy of this aggregate call, applying a mapping to its
* arguments. */
public AggregateCall transform(Mappings.TargetMapping mapping) {
return new AggregateCall(aggFunction, distinct, approximate, ignoreNulls,
return new AggregateCall(pos, aggFunction, distinct, approximate, ignoreNulls,
rexList, Mappings.apply2((Mapping) mapping, argList),
hasFilter() ? Mappings.apply(mapping, filterArg) : -1,
distinctKeys == null ? null : distinctKeys.permute(mapping),
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/rel/core/Window.java
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ public List<AggregateCall> getAggregateCalls(Window windowRel) {
@Override public AggregateCall get(int index) {
final RexWinAggCall aggCall = aggCalls.get(index);
final SqlAggFunction op = (SqlAggFunction) aggCall.getOperator();
return AggregateCall.create(op, aggCall.distinct, false,
return AggregateCall.create(aggCall.getParserPosition(), op, aggCall.distinct, false,
aggCall.ignoreNulls, ImmutableList.of(),
getProjectOrdinals(aggCall.getOperands()),
-1, null, RelCollations.EMPTY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,7 @@ private JoinContext(SqlDialect dialect, Context leftContext,
&& ((RexInputRef) op0).getIndex() >= leftContext.fieldCount) {
// Arguments were of form 'op1 = op0'
final SqlOperator op2 = requireNonNull(call.getOperator().reverse());
return (RexCall) rexBuilder.makeCall(op2, op1, op0);
return (RexCall) rexBuilder.makeCall(call.getParserPosition(), op2, op1, op0);
}
// fall through
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlPostfixOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
Expand Down Expand Up @@ -186,7 +187,8 @@ && isThreeArgCase(project.getProjects().get(singleArg))) {
&& RexLiteral.isNullLiteral(arg2)) {
newProjects.add(arg1);
newProjects.add(filter);
return AggregateCall.create(SqlStdOperatorTable.COUNT, true, false,
return AggregateCall.create(
call.getParserPosition(), SqlStdOperatorTable.COUNT, true, false,
false, call.rexList, ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1, null, RelCollations.EMPTY,
call.getType(), call.getName());
Expand All @@ -205,12 +207,13 @@ && isThreeArgCase(project.getProjects().get(singleArg))) {
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END)
// => COUNT() FILTER (x = 'foo')

final SqlParserPos pos = call.getParserPosition();
if (kind == SqlKind.COUNT // Case C
&& arg1.isA(SqlKind.LITERAL)
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
newProjects.add(filter);
return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
return AggregateCall.create(pos, SqlStdOperatorTable.COUNT, false, false,
false, call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
RelCollations.EMPTY, call.getType(),
call.getName());
Expand All @@ -223,7 +226,7 @@ && isIntLiteral(arg2, BigDecimal.ZERO)) {
final RelDataType dataType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.BIGINT), false);
return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false,
return AggregateCall.create(pos, SqlStdOperatorTable.COUNT, false, false,
false, call.rexList, ImmutableList.of(), newProjects.size() - 1, null,
RelCollations.EMPTY, dataType, call.getName());
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
Expand All @@ -232,7 +235,7 @@ && isIntLiteral(arg2, BigDecimal.ZERO)) {
&& isIntLiteral(arg2, BigDecimal.ZERO))) {
newProjects.add(arg1);
newProjects.add(filter);
return AggregateCall.create(call.getAggregation(), false,
return AggregateCall.create(pos, call.getAggregation(), false,
false, false, call.rexList, ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1, null, RelCollations.EMPTY,
call.getType(), call.getName());
Expand Down
Loading

0 comments on commit be044ff

Please sign in to comment.