Skip to content

Commit

Permalink
Make aggregation statement compilation robust
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <ltjin@amazon.com>
  • Loading branch information
LantaoJin committed Jul 4, 2024
1 parent 81ae699 commit 1c8a5ae
Show file tree
Hide file tree
Showing 25 changed files with 1,758 additions and 65 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ spotless {
removeUnusedImports()
trimTrailingWhitespace()
endWithNewline()
toggleOffOn()
googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format')
}
}
Expand Down
4 changes: 3 additions & 1 deletion core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ spotless {
removeUnusedImports()
trimTrailingWhitespace()
endWithNewline()
toggleOffOn()
googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format')
}
}
Expand Down Expand Up @@ -112,7 +113,8 @@ jacocoTestCoverageVerification {
'org.opensearch.sql.utils.Constants',
'org.opensearch.sql.datasource.model.DataSource',
'org.opensearch.sql.datasource.model.DataSourceStatus',
'org.opensearch.sql.datasource.model.DataSourceType'
'org.opensearch.sql.datasource.model.DataSourceType',
'org.opensearch.sql.QueryCompilationError'
]
limit {
counter = 'LINE'
Expand Down
66 changes: 66 additions & 0 deletions core/src/main/java/org/opensearch/sql/QueryCompilationError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql;

import static org.opensearch.sql.common.utils.StringUtils.format;

import lombok.experimental.UtilityClass;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;

/**
 * Central factory for the {@link SemanticCheckException}s raised while a query is being compiled.
 *
 * <p>Each method only builds the exception with its message and returns it; throwing it is left to
 * the caller.
 */
@UtilityClass
public class QueryCompilationError {

  /**
   * Builds the error for a selected field that is neither grouped nor aggregated.
   *
   * @param name name reported inside the brackets of the message
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException fieldNotInGroupByClauseError(String name) {
    final String message =
        format(
            "Field [%s] must appear in the GROUP BY clause or be used in an aggregate function",
            name);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for an aggregate function found in a GROUP BY clause.
   *
   * @param functionName name of the offending aggregate function
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException aggregateFunctionNotAllowedInGroupByError(
      String functionName) {
    final String message =
        format(
            "Aggregate function is not allowed in a GROUP BY clause, but found [%s]",
            functionName);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for a FILTER or HAVING expression whose type is not boolean.
   *
   * @param type actual type of the expression; its {@code typeName()} is put into the message
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException nonBooleanExpressionInFilterOrHavingError(ExprType type) {
    final String foundType = type.typeName();
    final String message =
        format("FILTER or HAVING expression must be type boolean, but found [%s]", foundType);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for an aggregate function found in a FILTER.
   *
   * @param functionName name of the offending aggregate function
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException aggregateFunctionNotAllowedInFilterError(
      String functionName) {
    final String message =
        format("Aggregate function is not allowed in a FILTER, but found [%s]", functionName);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for a window function used in WHERE or HAVING.
   *
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException windowFunctionNotAllowedError() {
    final String message = "Window functions are not allowed in WHERE or HAVING";
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for an aggregation function that is not supported.
   *
   * @param functionName name of the unsupported function
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException unsupportedAggregateFunctionError(String functionName) {
    final String message = format("Unsupported aggregation function %s", functionName);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for an ordinal that falls outside the select item list.
   *
   * @param ordinal the ordinal reported inside the brackets of the message
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException ordinalRefersOutOfBounds(int ordinal) {
    final String message = format("Ordinal [%d] is out of bound of select item list", ordinal);
    return new SemanticCheckException(message);
  }

  /**
   * Builds the error for an expression with a non-aggregated column when no explicit GROUP BY
   * clause is present.
   *
   * @param expr the unresolved expression; it is rendered into the message via {@code %s}
   * @return the exception to be thrown by the caller
   */
  public static SemanticCheckException groupByClauseIsMissingError(UnresolvedExpression expr) {
    final String message =
        format(
            "Explicit GROUP BY clause is required because expression [%s] contains non-aggregated"
                + " column",
            expr);
    return new SemanticCheckException(message);
  }
}
140 changes: 122 additions & 18 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.DataSourceSchemaName;
import org.opensearch.sql.QueryCompilationError;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand All @@ -47,6 +52,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -81,6 +87,7 @@
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.window.WindowFunctionExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalCloseCursor;
Expand All @@ -102,6 +109,7 @@
import org.opensearch.sql.planner.logical.LogicalValues;
import org.opensearch.sql.planner.physical.datasource.DataSourceTable;
import org.opensearch.sql.storage.Table;
import org.opensearch.sql.utils.ExpressionUtils;
import org.opensearch.sql.utils.ParseUtils;

/**
Expand Down Expand Up @@ -235,32 +243,51 @@ public LogicalPlan visitLimit(Limit node, AnalysisContext context) {
public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
  // The input plan is analyzed before the condition, so the condition is resolved against the
  // context as left by the child.
  final LogicalPlan input = node.getChild().get(0).accept(this, context);
  final Expression predicate = expressionAnalyzer.analyze(node.getCondition(), context);
  // Reject conditions that are not boolean predicates or that contain window functions.
  verifyCondition(predicate);

  final Expression rewritten =
      new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), input)
          .optimize(predicate, context);
  return new LogicalFilter(input, rewritten);
}

/**
 * Validates a WHERE/HAVING condition.
 *
 * <p>Throws the non-boolean error from {@link QueryCompilationError} when the condition is not of
 * type BOOLEAN (a top-level {@code nested} call is exempt), and the window-function error when
 * any sub-expression is a {@link WindowFunctionExpression}.
 *
 * @param condition the analyzed condition expression
 */
private void verifyCondition(Expression condition) {
  // TODO Remove this when adding support for syntax - nested(path, condition)
  // Current WHERE nested(path, condition) is not a valid boolean condition.
  boolean nestedCall = false;
  if (condition instanceof FunctionExpression) {
    FunctionName invoked = ((FunctionExpression) condition).getFunctionName();
    nestedCall = invoked.equals(FunctionName.of("nested"));
  }

  // The condition must be a valid predicate unless it is the nested(...) exemption above.
  boolean booleanTyped = condition.type() == ExprCoreType.BOOLEAN;
  if (!booleanTyped && !nestedCall) {
    throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(condition.type());
  }

  // Window functions are not allowed anywhere inside the condition.
  boolean containsWindowFunction =
      !ExpressionUtils.findSubExpressions(condition, WindowFunctionExpression.class::isInstance)
          .isEmpty();
  if (containsWindowFunction) {
    throw QueryCompilationError.windowFunctionNotAllowedError();
  }
}

/**
* Ensure NESTED function is not used in GROUP BY, and HAVING clauses. Fallback to legacy engine.
* Can remove when support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING
* clauses.
*
* @param condition : Filter condition
* Ensure NESTED function is not used in GROUP BY. Fallback to legacy engine. Can remove when
* support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING clauses. Ensure
* Aggregate function is not used in GROUP BY.
*/
private void verifySupportsCondition(Expression condition) {
if (condition instanceof FunctionExpression) {
if (((FunctionExpression) condition)
private void verifySupportsGroupBy(Expression groupBy) {
if (groupBy instanceof FunctionExpression) {
if (((FunctionExpression) groupBy)
.getFunctionName()
.getFunctionName()
.equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) {
throw new SyntaxCheckException(
"Falling back to legacy engine. Nested function is not supported in WHERE,"
+ " GROUP BY, and HAVING clauses.");
}
((FunctionExpression) condition)
.getArguments().stream().forEach(e -> verifySupportsCondition(e));
((FunctionExpression) groupBy).getArguments().stream().forEach(e -> verifySupportsGroupBy(e));
} else if (groupBy instanceof Aggregator) {
throw QueryCompilationError.aggregateFunctionNotAllowedInGroupByError(
((Aggregator<?>) groupBy).getFunctionName().getFunctionName());
}
}

Expand Down Expand Up @@ -295,13 +322,7 @@ public LogicalPlan visitRename(Rename node, AnalysisContext context) {
@Override
public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
final LogicalPlan child = node.getChild().get(0).accept(this, context);
ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}

// resolve group-by list
ImmutableList.Builder<NamedExpression> groupbyBuilder = new ImmutableList.Builder<>();
// Span should be first expression if exist.
if (node.getSpan() != null) {
Expand All @@ -310,12 +331,62 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {

for (UnresolvedExpression expr : node.getGroupExprList()) {
NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context);
verifySupportsCondition(resolvedExpr.getDelegated());
verifySupportsGroupBy(resolvedExpr.getDelegated());
groupbyBuilder.add(resolvedExpr);
}
ImmutableList<NamedExpression> groupBys = groupbyBuilder.build();

// spotless:off
// Verify group-by could work with select expressions.
// The following table shows the examples to explain the purpose:
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// | Case | Query | IsValid | Field Missed In GroupBy |
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// | 1 | SELECT a FROM table GROUP BY b | No | a |
// | 2 | SELECT a as c FROM table GROUP BY b | No | a |
// | 3 | SELECT a FROM table GROUP BY a * 3 | No | a |
// | 4 | SELECT a * 3 FROM table GROUP BY a | Yes | N/A |
// | 5 | SELECT a * 3 FROM table GROUP BY b | No | a |
// | 6 | SELECT a FROM table GROUP BY upper(a) | No | a |
// | 7 | SELECT upper(a) FROM table GROUP BY a | Yes | N/A |
// | 8 | SELECT upper(a) FROM table GROUP BY upper(a) | Yes | N/A |
// | 9 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY b | No | a |
// | 10 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY upper(b) | No | a |
// | 11 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat(upper(a), upper(b)) | Yes | N/A |
// | 12 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat_ws(',', upper(a), upper(b)) | No | a |
// | 13 | SELECT concat(a, b) FROM table group by upper(a), upper(b) | No | a |
// | 14 | SELECT concat(a, b) FROM table group by a | No | b |
// | 15 | SELECT concat(a, b) FROM table group by a, upper(b) | No | b |
// | 16 | SELECT concat(a, b), upper(b) FROM table group by a, upper(b) | No | b |
// | 17 | SELECT concat(a, b), upper(b) FROM table group by a, b | Yes | N/A |
// | 18 | SELECT upper(concat(a, b)) FROM table group by concat(a, b) | Yes | N/A |
// | 19 | SELECT concat(concat(a, b), c) FROM table group by concat(a, b) | No | c |
// | 20 | SELECT 1, 2, 3 FROM table group by a | Yes | N/A |
// | 21 | SELECT 1, 2, b FROM table group by a | No | b |
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// spotless:on
for (UnresolvedExpression expr : node.getAliasFreeSelectExprList()) {
Expression resolvedSelectItemExpr = expressionAnalyzer.analyze(expr, context);
Predicate<Expression> notExists =
e ->
(e instanceof ReferenceExpression || e instanceof FunctionExpression)
&& groupBys.stream().noneMatch(g -> e.equals(g.getDelegated()));
Consumer<String> action =
name -> {
throw QueryCompilationError.fieldNotInGroupByClauseError(name);
};
ExpressionUtils.actionOnCheck(resolvedSelectItemExpr, notExists, action);
}

// resolve aggregators
ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}
ImmutableList<NamedAggregator> aggregators = aggregatorBuilder.build();

// new context
context.push();
TypeEnvironment newEnv = context.peek();
Expand All @@ -329,6 +400,39 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
return new LogicalAggregation(child, aggregators, groupBys);
}

/**
 * Resolves a Having clause by folding its aggregators into the child {@link LogicalAggregation}.
 *
 * <p>The child is analyzed first and cast to {@link LogicalAggregation}. The Having condition is
 * then analyzed and verified, and the aggregators carried by the Having node are analyzed and
 * registered as field symbols in a newly pushed type environment. The returned plan is a new
 * {@link LogicalAggregation} over the child's own input, keeping the child's group-by list and
 * the child's aggregators followed by any Having aggregators not already present.
 */
@Override
public LogicalPlan visitHaving(Having node, AnalysisContext context) {
  LogicalAggregation childAggregation =
      (LogicalAggregation) node.getChild().get(0).accept(this, context);
  // Only the validation of the condition is needed here; the analyzed expression is not reused.
  verifyCondition(expressionAnalyzer.analyze(node.getCondition(), context));

  // Extract aggregators from the Having clause.
  List<NamedAggregator> havingAggregators = new ArrayList<>();
  for (UnresolvedExpression unresolved : node.getAggregators()) {
    NamedExpression named = namedExpressionAnalyzer.analyze(unresolved, context);
    havingAggregators.add(
        new NamedAggregator(named.getNameOrAlias(), (Aggregator) named.getDelegated()));
  }

  // Register the Having aggregators in a new context.
  context.push();
  TypeEnvironment havingEnv = context.peek();
  for (NamedAggregator aggregator : havingAggregators) {
    havingEnv.define(new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type());
  }

  // Merge in insertion order: the child's aggregators first, then the new ones from Having;
  // the set drops duplicates.
  Set<NamedAggregator> merged = new LinkedHashSet<>(childAggregation.getAggregatorList());
  merged.addAll(havingAggregators);
  return new LogicalAggregation(
      childAggregation.getChild().get(0),
      new ArrayList<>(merged),
      childAggregation.getGroupByList());
}

/** Build {@link LogicalRareTopN}. */
@Override
public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Getter;
import org.opensearch.sql.QueryCompilationError;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand Down Expand Up @@ -70,6 +71,8 @@
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
import org.opensearch.sql.expression.window.ranking.RankingWindowFunction;
import org.opensearch.sql.utils.ExpressionUtils;

/**
* Analyze the {@link UnresolvedExpression} in the {@link AnalysisContext} to construct the {@link
Expand Down Expand Up @@ -169,11 +172,23 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
builder.build());
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
// Check if the filter condition is a valid predicate.
Expression predicate = node.condition().accept(this, context);
if (predicate.type() != ExprCoreType.BOOLEAN) {
throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(predicate.type());
}
// Check if any aggregate function in filter
List<Expression> results =
ExpressionUtils.findSubExpressions(predicate, Aggregator.class::isInstance);
if (!results.isEmpty()) {
throw QueryCompilationError.aggregateFunctionNotAllowedInFilterError(
((Aggregator) results.get(0)).getFunctionName().getFunctionName());
}
aggregator.condition(predicate);
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
throw QueryCompilationError.unsupportedAggregateFunctionError(node.getFuncName());
}
}

Expand Down Expand Up @@ -203,6 +218,10 @@ public Expression visitFunction(Function node, AnalysisContext context) {
repository.compile(context.getFunctionProperties(), functionName, arguments);
}

/**
 * TODO: throw a SemanticCheckException once a configuration can be set for it, in order to avoid
 * a breaking change. An order is required if the function expression is a {@link
 * RankingWindowFunction}.
 */
@SuppressWarnings("unchecked")
@Override
public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -312,4 +313,8 @@ public T visitFetchCursor(FetchCursor cursor, C context) {
public T visitCloseCursor(CloseCursor closeCursor, C context) {
return visitChildren(closeCursor, context);
}

/**
 * Visits a {@link Having} node. The default implementation delegates to {@code visitChildren}
 * and returns its result.
 */
public T visitHaving(Having having, C context) {
return visitChildren(having, context);
}
}
Loading

0 comments on commit 1c8a5ae

Please sign in to comment.