diff --git a/build.gradle b/build.gradle index b3e09d7b50..4a33724c7a 100644 --- a/build.gradle +++ b/build.gradle @@ -97,6 +97,7 @@ spotless { removeUnusedImports() trimTrailingWhitespace() endWithNewline() + toggleOffOn() googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') } } diff --git a/core/build.gradle b/core/build.gradle index 655e7d92c2..460cf87de6 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -78,6 +78,7 @@ spotless { removeUnusedImports() trimTrailingWhitespace() endWithNewline() + toggleOffOn() googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') } } @@ -112,7 +113,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.utils.Constants', 'org.opensearch.sql.datasource.model.DataSource', 'org.opensearch.sql.datasource.model.DataSourceStatus', - 'org.opensearch.sql.datasource.model.DataSourceType' + 'org.opensearch.sql.datasource.model.DataSourceType', + 'org.opensearch.sql.QueryCompilationError' ] limit { counter = 'LINE' diff --git a/core/src/main/java/org/opensearch/sql/QueryCompilationError.java b/core/src/main/java/org/opensearch/sql/QueryCompilationError.java new file mode 100644 index 0000000000..7d2aaec33e --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/QueryCompilationError.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql; + +import static org.opensearch.sql.common.utils.StringUtils.format; + +import lombok.experimental.UtilityClass; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; + +/** Grouping error messages from {@link SemanticCheckException} thrown during query compilation. 
*/ +@UtilityClass +public class QueryCompilationError { + + public static SemanticCheckException fieldNotInGroupByClauseError(String name) { + return new SemanticCheckException( + format( + "Field [%s] must appear in the GROUP BY clause or be used in an aggregate function", + name)); + } + + public static SemanticCheckException aggregateFunctionNotAllowedInGroupByError( + String functionName) { + return new SemanticCheckException( + format( + "Aggregate function is not allowed in a GROUP BY clause, but found [%s]", + functionName)); + } + + public static SemanticCheckException nonBooleanExpressionInFilterOrHavingError(ExprType type) { + return new SemanticCheckException( + format( + "FILTER or HAVING expression must be type boolean, but found [%s]", type.typeName())); + } + + public static SemanticCheckException aggregateFunctionNotAllowedInFilterError( + String functionName) { + return new SemanticCheckException( + format("Aggregate function is not allowed in a FILTER, but found [%s]", functionName)); + } + + public static SemanticCheckException windowFunctionNotAllowedError() { + return new SemanticCheckException("Window functions are not allowed in WHERE or HAVING"); + } + + public static SemanticCheckException unsupportedAggregateFunctionError(String functionName) { + return new SemanticCheckException(format("Unsupported aggregation function %s", functionName)); + } + + public static SemanticCheckException ordinalRefersOutOfBounds(int ordinal) { + return new SemanticCheckException( + format("Ordinal [%d] is out of bound of select item list", ordinal)); + } + + public static SemanticCheckException groupByClauseIsMissingError(UnresolvedExpression expr) { + return new SemanticCheckException( + format( + "Explicit GROUP BY clause is required because expression [%s] contains non-aggregated" + + " column", + expr)); + } +} diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 
d5e8b93b13..03f1326c1a 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -22,13 +22,18 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.QueryCompilationError; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -47,6 +52,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Having; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; @@ -81,6 +87,7 @@ import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.expression.parse.ParseExpression; +import org.opensearch.sql.expression.window.WindowFunctionExpression; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalCloseCursor; @@ -102,6 +109,7 @@ import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.physical.datasource.DataSourceTable; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.utils.ExpressionUtils; import org.opensearch.sql.utils.ParseUtils; /** @@ -235,6 +243,7 @@ public 
LogicalPlan visitLimit(Limit node, AnalysisContext context) { public LogicalPlan visitFilter(Filter node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); + verifyCondition(condition); ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); @@ -242,16 +251,32 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) { return new LogicalFilter(child, optimized); } + private void verifyCondition(Expression condition) { + // TODO Remove this when adding support for syntax - nested(path, condition) + // Current WHERE nested(path, condition) is not a valid boolean condition. + boolean isNestedFunction = + condition instanceof FunctionExpression + && ((FunctionExpression) condition).getFunctionName().equals(FunctionName.of("nested")); + // Check if the filter condition is a valid predicate. + if (condition.type() != ExprCoreType.BOOLEAN && !isNestedFunction) { + throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(condition.type()); + } + // Check if any window functions in filter + List results = + ExpressionUtils.findSubExpressions(condition, WindowFunctionExpression.class::isInstance); + if (!results.isEmpty()) { + throw QueryCompilationError.windowFunctionNotAllowedError(); + } + } + /** - * Ensure NESTED function is not used in GROUP BY, and HAVING clauses. Fallback to legacy engine. - * Can remove when support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING - * clauses. - * - * @param condition : Filter condition + * Ensure NESTED function is not used in GROUP BY. Fallback to legacy engine. Can remove when + * support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING clauses. Ensure + * Aggregate function is not used in GROUP BY. 
*/ - private void verifySupportsCondition(Expression condition) { - if (condition instanceof FunctionExpression) { - if (((FunctionExpression) condition) + private void verifySupportsGroupBy(Expression groupBy) { + if (groupBy instanceof FunctionExpression) { + if (((FunctionExpression) groupBy) .getFunctionName() .getFunctionName() .equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { @@ -259,8 +284,10 @@ private void verifySupportsCondition(Expression condition) { "Falling back to legacy engine. Nested function is not supported in WHERE," + " GROUP BY, and HAVING clauses."); } - ((FunctionExpression) condition) - .getArguments().stream().forEach(e -> verifySupportsCondition(e)); + ((FunctionExpression) groupBy).getArguments().stream().forEach(e -> verifySupportsGroupBy(e)); + } else if (groupBy instanceof Aggregator) { + throw QueryCompilationError.aggregateFunctionNotAllowedInGroupByError( + ((Aggregator) groupBy).getFunctionName().getFunctionName()); } } @@ -295,13 +322,7 @@ public LogicalPlan visitRename(Rename node, AnalysisContext context) { @Override public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { final LogicalPlan child = node.getChild().get(0).accept(this, context); - ImmutableList.Builder aggregatorBuilder = new ImmutableList.Builder<>(); - for (UnresolvedExpression expr : node.getAggExprList()) { - NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); - aggregatorBuilder.add( - new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); - } - + // resolve group-by list ImmutableList.Builder groupbyBuilder = new ImmutableList.Builder<>(); // Span should be first expression if exist. 
if (node.getSpan() != null) { @@ -310,12 +331,62 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { for (UnresolvedExpression expr : node.getGroupExprList()) { NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context); - verifySupportsCondition(resolvedExpr.getDelegated()); + verifySupportsGroupBy(resolvedExpr.getDelegated()); groupbyBuilder.add(resolvedExpr); } ImmutableList groupBys = groupbyBuilder.build(); + // spotless:off + // Verify group-by could work with select expressions. + // The following table shows the examples to explain the purpose: + // +------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // | Case | Query | IsValid | Field Missed In GroupBy | + // +------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // | 1 | SELECT a FROM table GROUP BY b | No | a | + // | 2 | SELECT a as c FROM table GROUP BY b | No | a | + // | 3 | SELECT a FROM table GROUP BY a * 3 | No | a | + // | 4 | SELECT a * 3 FROM table GROUP BY a | Yes | N/A | + // | 5 | SELECT a * 3 FROM table GROUP BY b | No | a | + // | 6 | SELECT a FROM table GROUP BY upper(a) | No | a | + // | 7 | SELECT upper(a) FROM table GROUP BY a | Yes | N/A | + // | 8 | SELECT upper(a) FROM table GROUP BY upper(a) | Yes | N/A | + // | 9 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY b | No | a | + // | 10 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY upper(b) | No | a | + // | 11 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat(upper(a), upper(b)) | Yes | N/A | + // | 12 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat_ws(',', upper(a), upper(b)) | No | a | + // | 13 | SELECT concat(a, b) FROM table group by upper(a), upper(b) | No | a | + // | 14 | SELECT concat(a, b) FROM table group by a | No | b | + // | 15 | SELECT concat(a, b) FROM 
table group by a, upper(b) | No | b | + // | 16 | SELECT concat(a, b), upper(b) FROM table group by a, upper(b) | No | b | + // | 17 | SELECT concat(a, b), upper(b) FROM table group by a, b | Yes | N/A | + // | 18 | SELECT upper(concat(a, b)) FROM table group by concat(a, b) | Yes | N/A | + // | 19 | SELECT concat(concat(a, b), c) FROM table group by concat(a, b) | No | c | + // | 20 | SELECT 1, 2, 3 FROM table group by a | Yes | N/A | + // | 21 | SELECT 1, 2, b FROM table group by a | No | b | + // +------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // spotless:on + for (UnresolvedExpression expr : node.getAliasFreeSelectExprList()) { + Expression resolvedSelectItemExpr = expressionAnalyzer.analyze(expr, context); + Predicate notExists = + e -> + (e instanceof ReferenceExpression || e instanceof FunctionExpression) + && groupBys.stream().noneMatch(g -> e.equals(g.getDelegated())); + Consumer action = + name -> { + throw QueryCompilationError.fieldNotInGroupByClauseError(name); + }; + ExpressionUtils.actionOnCheck(resolvedSelectItemExpr, notExists, action); + } + + // resolve aggregators + ImmutableList.Builder aggregatorBuilder = new ImmutableList.Builder<>(); + for (UnresolvedExpression expr : node.getAggExprList()) { + NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); + aggregatorBuilder.add( + new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); + } ImmutableList aggregators = aggregatorBuilder.build(); + // new context context.push(); TypeEnvironment newEnv = context.peek(); @@ -329,6 +400,39 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { return new LogicalAggregation(child, aggregators, groupBys); } + /** Resolve Having clause to merge its aggregators to {@link LogicalAggregation}. 
*/ + @Override + public LogicalPlan visitHaving(Having node, AnalysisContext context) { + LogicalAggregation aggregation = + (LogicalAggregation) node.getChild().get(0).accept(this, context); + Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); + verifyCondition(condition); + + // Extract aggregator from Having clause + ImmutableList.Builder aggregatorBuilder = new ImmutableList.Builder<>(); + for (UnresolvedExpression expr : node.getAggregators()) { + NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); + aggregatorBuilder.add( + new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); + } + List aggregatorListFromHaving = aggregatorBuilder.build(); + // new context + context.push(); + TypeEnvironment newEnv = context.peek(); + aggregatorListFromHaving.forEach( + aggregator -> + newEnv.define( + new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type())); + + List aggregatorListFromChild = aggregation.getAggregatorList(); + // merge the aggregators from having to its child's + Set dedup = new LinkedHashSet<>(aggregatorListFromChild); + dedup.addAll(aggregatorListFromHaving); + List mergedAggregators = new ArrayList<>(dedup); + return new LogicalAggregation( + aggregation.getChild().get(0), mergedAggregators, aggregation.getGroupByList()); + } + /** Build {@link LogicalRareTopN}. 
*/ @Override public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) { diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 5a8d6fe976..bd29ba79cb 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -18,6 +18,7 @@ import java.util.Optional; import java.util.stream.Collectors; import lombok.Getter; +import org.opensearch.sql.QueryCompilationError; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -70,6 +71,8 @@ import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.expression.span.SpanExpression; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; +import org.opensearch.sql.expression.window.ranking.RankingWindowFunction; +import org.opensearch.sql.utils.ExpressionUtils; /** * Analyze the {@link UnresolvedExpression} in the {@link AnalysisContext} to construct the {@link @@ -169,11 +172,23 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext builder.build()); aggregator.distinct(node.getDistinct()); if (node.condition() != null) { - aggregator.condition(analyze(node.condition(), context)); + // Check if the filter condition is a valid predicate. 
+ Expression predicate = node.condition().accept(this, context); + if (predicate.type() != ExprCoreType.BOOLEAN) { + throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(predicate.type()); + } + // Check if any aggregate function in filter + List results = + ExpressionUtils.findSubExpressions(predicate, Aggregator.class::isInstance); + if (!results.isEmpty()) { + throw QueryCompilationError.aggregateFunctionNotAllowedInFilterError( + ((Aggregator) results.get(0)).getFunctionName().getFunctionName()); + } + aggregator.condition(predicate); } return aggregator; } else { - throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName()); + throw QueryCompilationError.unsupportedAggregateFunctionError(node.getFuncName()); } } @@ -203,6 +218,10 @@ public Expression visitFunction(Function node, AnalysisContext context) { repository.compile(context.getFunctionProperties(), functionName, arguments); } + /** + * Todo. throws SemanticCheckException when a configuration could be set in order to avoid + * breaking change. Order is required if function expression is {@link RankingWindowFunction}. 
+ */ @SuppressWarnings("unchecked") @Override public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 973b10310b..597d8dbe09 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -46,6 +46,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Having; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; @@ -312,4 +313,8 @@ public T visitFetchCursor(FetchCursor cursor, C context) { public T visitCloseCursor(CloseCursor closeCursor, C context) { return visitChildren(closeCursor, context); } + + public T visitHaving(Having having, C context) { + return visitChildren(having, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 4f3056b0f7..35edf28fd0 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -47,6 +47,7 @@ import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Having; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Parse; @@ -104,15 +105,36 @@ public static UnresolvedPlan projectWithArg( return new Project(Arrays.asList(projectList), argList).attach(input); } + /** + * Creates an aggregation with the specified aggregators, sort expressions, group-by expressions, + * and arguments. 
+ * + * @param input the child unresolved plan + * @param aggList the list of aggregator + * @param sortList the list of sort expressions + * @param groupList the list of group-by expressions + * @param argList the list of arguments + */ public static UnresolvedPlan agg( UnresolvedPlan input, List aggList, List sortList, List groupList, List argList) { - return new Aggregation(aggList, sortList, groupList, null, argList).attach(input); + return new Aggregation(aggList, sortList, groupList, null, argList, List.of()).attach(input); } + /** + * Creates an aggregation with the specified aggregators, sort expressions, group-by expressions, + * span expression, and arguments. + * + * @param input the child unresolved plan + * @param aggList the list of aggregators + * @param sortList the list of sort expressions + * @param groupList the list of group-by expressions + * @param span the span expression + * @param argList the list of arguments + */ public static UnresolvedPlan agg( UnresolvedPlan input, List aggList, @@ -120,7 +142,53 @@ public static UnresolvedPlan agg( List groupList, UnresolvedExpression span, List argList) { - return new Aggregation(aggList, sortList, groupList, span, argList).attach(input); + return new Aggregation(aggList, sortList, groupList, span, argList, List.of()).attach(input); + } + + /** + * Creates an aggregation with the specified aggregators, sort expressions, group-by expressions, + * arguments, and select expressions. 
+ * + * @param input the child unresolved plan + * @param aggList the list of aggregators + * @param sortList the list of sort expressions + * @param groupList the list of group-by expressions + * @param argList the list of arguments + * @param aliasFreeSelectExprList the list of alias free select expressions + */ + public static UnresolvedPlan agg( + UnresolvedPlan input, + List aggList, + List sortList, + List groupList, + List argList, + List aliasFreeSelectExprList) { + return new Aggregation(aggList, sortList, groupList, null, argList, aliasFreeSelectExprList) + .attach(input); + } + + /** + * Creates an aggregation with the specified aggregators, sort expressions, group-by expressions, + * span expression, arguments, and select expressions. + * + * @param input the child unresolved plan + * @param aggList the list of aggregators + * @param sortList the list of sort expressions + * @param groupList the list of group-by expressions + * @param span the span expression + * @param argList the list of arguments + * @param aliasFreeSelectExprList the list of alias free select expressions + */ + public static UnresolvedPlan agg( + UnresolvedPlan input, + List aggList, + List sortList, + List groupList, + UnresolvedExpression span, + List argList, + List aliasFreeSelectExprList) { + return new Aggregation(aggList, sortList, groupList, span, argList, aliasFreeSelectExprList) + .attach(input); } public static UnresolvedPlan rename(UnresolvedPlan input, Map... 
maps) { @@ -471,4 +539,11 @@ public static Parse parse( java.util.Map arguments) { return new Parse(parseMethod, sourceField, pattern, arguments, input); } + + public static UnresolvedPlan having( + UnresolvedPlan input, + List aggregators, + UnresolvedExpression condition) { + return new Having(aggregators, condition).attach(input); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 5208e39623..5b18eca79a 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -71,6 +71,6 @@ public R accept(AbstractNodeVisitor nodeVisitor, C context) { @Override public String toString() { - return StringUtils.format("%s(%s)", funcName, field); + return StringUtils.format("%s(%s%s)", funcName, distinct ? "DISTINCT " : "", field); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/Alias.java b/core/src/main/java/org/opensearch/sql/ast/expression/Alias.java index 7b3078629b..cc04fbf8e1 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/Alias.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/Alias.java @@ -5,12 +5,15 @@ package org.opensearch.sql.ast.expression; +import com.google.common.collect.ImmutableList; +import java.util.List; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; /** * Alias abstraction that associate an unnamed expression with a name and an optional alias. 
The @@ -38,4 +41,9 @@ public class Alias extends UnresolvedExpression { public T accept(AbstractNodeVisitor nodeVisitor, C context) { return nodeVisitor.visitAlias(this, context); } + + @Override + public List getChild() { + return ImmutableList.of(this.delegated); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java b/core/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java index f098d0ec53..0802d0c76e 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java @@ -27,28 +27,53 @@ public class Aggregation extends UnresolvedPlan { private List groupExprList; private UnresolvedExpression span; private List argExprList; + private List aliasFreeSelectExprList; private UnresolvedPlan child; /** Aggregation Constructor without span and argument. */ public Aggregation( List aggExprList, List sortExprList, - List groupExprList) { - this(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + List groupExprList, + List aliasFreeSelectExprList) { + this( + aggExprList, + sortExprList, + groupExprList, + null, + Collections.emptyList(), + aliasFreeSelectExprList); } - /** Aggregation Constructor. */ + /** Aggregation Constructor without select expressions, used in PPL. */ public Aggregation( List aggExprList, List sortExprList, List groupExprList, UnresolvedExpression span, List argExprList) { + this(aggExprList, sortExprList, groupExprList, span, argExprList, Collections.emptyList()); + } + + /** + * Aggregation Constructor. + * + * @param aliasFreeSelectExprList is used to verify that all fields in Select must appear in the + * GROUP BY clause or be used in an aggregate function. 
/**
 * Full Aggregation constructor; the other constructors delegate here.
 *
 * @param aggExprList the aggregate expressions
 * @param sortExprList the sort expressions
 * @param groupExprList the group-by expressions
 * @param span the span expression, may be {@code null}
 * @param argExprList the argument expressions
 * @param aliasFreeSelectExprList is used to verify that all fields in Select must appear in the
 *     GROUP BY clause or be used in an aggregate function.
 */
// NOTE(review): the List type arguments are not visible in this excerpt (they look stripped);
// keep whatever the real file declares.
public Aggregation(
    List aggExprList,
    List sortExprList,
    List groupExprList,
    UnresolvedExpression span,
    List argExprList,
    List aliasFreeSelectExprList) {
  // What to compute and how to group it.
  this.aggExprList = aggExprList;
  this.groupExprList = groupExprList;
  this.span = span;

  // Ordering and extra arguments.
  this.sortExprList = sortExprList;
  this.argExprList = argExprList;

  // Select items kept only for the GROUP BY validation.
  this.aliasFreeSelectExprList = aliasFreeSelectExprList;
}
Having without aggregation + * equals to {@link Filter} + */ +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +public class Having extends UnresolvedPlan { + private List aggregators; + private UnresolvedExpression condition; + private UnresolvedPlan child; + + public Having(List aggregators, UnresolvedExpression condition) { + this.aggregators = aggregators; + this.condition = condition; + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitHaving(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java index 9f14ba1e5d..2ab5330530 100644 --- a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java @@ -130,7 +130,7 @@ public Boolean visitField(Field node, Object context) { @Override public Boolean visitAlias(Alias node, Object context) { - return canPaginate(node, context) && node.getDelegated().accept(this, context); + return canPaginate(node, context); } @Override diff --git a/core/src/main/java/org/opensearch/sql/utils/ExpressionUtils.java b/core/src/main/java/org/opensearch/sql/utils/ExpressionUtils.java index f04bf3748f..653f4963ce 100644 --- a/core/src/main/java/org/opensearch/sql/utils/ExpressionUtils.java +++ b/core/src/main/java/org/opensearch/sql/utils/ExpressionUtils.java @@ -5,10 +5,15 @@ package org.opensearch.sql.utils; +import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; +import java.util.function.Predicate; import java.util.stream.Collectors; +import lombok.Generated; import 
lombok.experimental.UtilityClass; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; /** Utils for {@link Expression}. */ @UtilityClass @@ -20,4 +25,57 @@ public class ExpressionUtils { public static String format(List expressionList) { return expressionList.stream().map(Expression::toString).collect(Collectors.joining(",")); } + + /** + * Find the children expressions matching the given predicate from an {@link Expression}. This + * method could help you to traverse an expression and return all sub expressions under the + * condition you provided. + * + * @param expr the expression to traverse + * @param condition the condition to test + * @return all sub expressions matching the condition + */ + public static List findSubExpressions( + Expression expr, Predicate condition) { + List results = new ArrayList<>(); + findSubExpressionsHelper(expr, condition, results); + return results; + } + + private static void findSubExpressionsHelper( + Expression expr, Predicate condition, List results) { + if (condition.test(expr)) { + results.add((T) expr); + } + if (expr instanceof FunctionExpression) { + for (Expression child : ((FunctionExpression) expr).getArguments()) { + findSubExpressionsHelper(child, condition, results); + } + } + } + + /** + * Traverse an {@link Expression} to consume when the given predicate matched. + * + *

Add @Generated annotation since the jacoco has a bug to report test coverage for the Lambda + * {@code action.accept()}. + * + * @param expr the expression to traverse + * @param condition the condition to test + * @param action execute the action when the condition matched. + */ + @Generated + public static void actionOnCheck( + Expression expr, Predicate condition, Consumer action) { + if (expr instanceof FunctionExpression) { + if (!condition.test(expr)) { + return; + } + for (Expression child : ((FunctionExpression) expr).getArguments()) { + actionOnCheck(child, condition, action); + } + } else if (condition.test(expr)) { + action.accept(expr.toString()); + } + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 8d935b11d2..5122d77351 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -7,14 +7,18 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction; +import static org.opensearch.sql.ast.dsl.AstDSL.agg; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static 
org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; @@ -22,13 +26,16 @@ import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.having; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.nestedAllTupleFields; +import static org.opensearch.sql.ast.dsl.AstDSL.project; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.dsl.AstDSL.window; import static org.opensearch.sql.ast.tree.Sort.NullOrder; import static org.opensearch.sql.ast.tree.Sort.SortOption; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; @@ -93,9 +100,11 @@ import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; @@ -1767,4 +1776,1019 @@ public void visit_close_cursor() { () -> assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor())); } + + // spotless:off + // Complex cases should be checked: + 
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // | Case | Query | IsValid | Missed Field In GroupBy | + // +------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // | 1 | SELECT a FROM table GROUP BY b | No | a | + // | 2 | SELECT a as c FROM table GROUP BY b | No | a | + // | 3 | SELECT a FROM table GROUP BY a * 3 | No | a | + // | 4 | SELECT a * 3 FROM table GROUP BY a | Yes | N/A | + // | 5 | SELECT a * 3 FROM table GROUP BY b | No | a | + // | 6 | SELECT a FROM table GROUP BY upper(a) | No | a | + // | 7 | SELECT upper(a) FROM table GROUP BY a | Yes | N/A | + // | 8 | SELECT upper(a) FROM table GROUP BY upper(a) | Yes | N/A | + // | 9 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY b | No | a | + // | 10 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY upper(b) | No | a | + // | 11 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat(upper(a), upper(b)) | Yes | N/A | + // | 12 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat_ws(',', upper(a), upper(b)) | No | a | + // | 13 | SELECT concat(a, b) FROM table group by upper(a), upper(b) | No | a | + // | 14 | SELECT concat(a, b) FROM table group by a | No | b | + // | 15 | SELECT concat(a, b) FROM table group by a, upper(b) | No | b | + // | 16 | SELECT concat(a, b), upper(b) FROM table group by a, upper(b) | No | b | + // | 17 | SELECT concat(a, b), upper(b) FROM table group by a, b | Yes | N/A | + // | 18 | SELECT upper(concat(a, b)) FROM table group by concat(a, b) | Yes | N/A | + // | 19 | SELECT concat(concat(a, b), c) FROM table group by concat(a, b) | No | c | + // | 20 | SELECT 1, 2, 3 FROM table group by a | Yes | N/A | + // | 21 | SELECT 1, 2, b FROM table group by a | No | b | + // 
+------+------------------------------------------------------------------------------------------+---------+-------------------------+ + // spotless:on + /** case 1: SELECT integer_value FROM schema group by string_value; (invalid) */ + @Test + public void field_not_in_group_by_case1() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of( + alias("integer_value", qualifiedName("integer_value")))), + alias("integer_value", qualifiedName("integer_value"))))); + assertEquals( + "Field [integer_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 2: SELECT integer_value as int_value FROM schema group by string_value; (invalid) */ + @Test + public void field_not_in_group_by_case2() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of(alias("int_value", qualifiedName("integer_value")))), + alias("int_value", qualifiedName("integer_value"))))); + assertEquals( + "Field [integer_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 3: SELECT integer_value FROM table GROUP BY integer_value * 3; (invalid) */ + @Test + public void field_not_in_group_by_case3() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "*(integer_value, 3)", + function( + "*", + alias("integer_value", qualifiedName("integer_value")), + 
intLiteral(3)))), + emptyList(), + ImmutableList.of(qualifiedName("integer_value"))), + alias("integer_value", qualifiedName("integer_value"))))); + assertEquals( + "Field [integer_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 4: SELECT integer_value * 3 FROM table GROUP BY integer_value; (valid) */ + @Test + public void field_not_in_group_by_case4() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("integer_value", qualifiedName("integer_value"))), + emptyList(), + ImmutableList.of( + function( + "*", + alias("integer_value", qualifiedName("integer_value")), + intLiteral(3)))), + alias( + "*(integer_value, 3)", + function( + "*", + alias("integer_value", qualifiedName("integer_value")), + intLiteral(3)))))); + } + + /** case 5: SELECT integer_value * 3 FROM table GROUP BY string_value; (invalid) */ + @Test + public void field_not_in_group_by_case5() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of( + function( + "*", + alias("integer_value", qualifiedName("integer_value")), + intLiteral(3)))), + alias( + "*(integer_value, 3)", + function( + "*", + alias("integer_value", qualifiedName("integer_value")), + intLiteral(3)))))); + assertEquals( + "Field [integer_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 6: SELECT string_value FROM table GROUP BY upper(string_value); (invalid) */ + @Test + public void field_not_in_group_by_case6() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + 
emptyList(), + emptyList(), + ImmutableList.of( + alias( + "upper(string_value)", + function( + "upper", + alias("string_value", qualifiedName("string_value"))))), + emptyList(), + ImmutableList.of(qualifiedName("string_value"))), + alias("string_value", qualifiedName("string_value"))))); + assertEquals( + "Field [string_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 7: SELECT upper(string_value) FROM table GROUP BY string_value; (valid) */ + @Test + public void field_not_in_group_by_case7() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of( + function( + "upper", alias("string_value", qualifiedName("string_value"))))), + alias( + "upper(integer_value)", + function("upper", alias("string_value", qualifiedName("string_value"))))))); + } + + /** case 8: SELECT upper(string_value) FROM table GROUP BY upper(string_value); (valid) */ + @Test + public void field_not_in_group_by_case8() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "upper(string_value)", + function( + "upper", + alias("string_value", qualifiedName("string_value"))))), + emptyList(), + ImmutableList.of( + function( + "upper", alias("string_value", qualifiedName("string_value"))))), + alias( + "upper(integer_value)", + function("upper", alias("string_value", qualifiedName("string_value"))))))); + } + + /** + * case 9: SELECT concat(upper(field_value1), upper(field_value2)) FROM table GROUP BY + * field_value2; (invalid) + */ + @Test + public void field_not_in_group_by_case9() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + 
ImmutableList.of(alias("field_value2", qualifiedName("field_value2"))), + emptyList(), + ImmutableList.of( + function( + "concat", + function("upper", qualifiedName("field_value1")), + function("upper", qualifiedName("field_value2"))))), + alias( + "concat(upper(string_value), upper(field_value2))", + function( + "concat", + function( + "upper", alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", + alias("field_value2", qualifiedName("field_value2")))))))); + assertEquals( + "Field [field_value1] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** + * case 10: SELECT concat(upper(field_value1), upper(field_value2)) FROM table GROUP BY + * upper(field_value2); (invalid) + */ + @Test + public void field_not_in_group_by_case10() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "upper(field_value2)", + function( + "upper", + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "concat", + function("upper", qualifiedName("field_value1")), + function("upper", qualifiedName("field_value2"))))), + alias( + "concat(upper(field_value1), upper(field_value2))", + function( + "concat", + function( + "upper", alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", + alias("field_value2", qualifiedName("field_value2")))))))); + assertEquals( + "Field [field_value1] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** + * case 11: SELECT concat(upper(field_value1), upper(field_value2)) FROM table GROUP BY + * concat(upper(field_value1), upper(field_value2)); (valid) + */ + @Test + public void field_not_in_group_by_case11() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + 
relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "concat(upper(field_value1), upper(field_value2))", + function( + "concat", + function( + "upper", + alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", + alias("field_value2", qualifiedName("field_value2")))))), + emptyList(), + ImmutableList.of( + function( + "concat", + function("upper", qualifiedName("field_value1")), + function("upper", qualifiedName("field_value2"))))), + alias( + "concat(upper(field_value1), upper(field_value2))", + function( + "concat", + function("upper", alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", alias("field_value2", qualifiedName("field_value2")))))))); + } + + /** + * case 12: SELECT concat(upper(field_value1), upper(field_value2)) FROM table GROUP BY + * concat_ws(',', upper(field_value1), upper(field_value2)); (invalid) + */ + @Test + public void field_not_in_group_by_case12() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "concat_ws(upper(field_value1), upper(field_value2))", + function( + "concat_ws", + stringLiteral(","), + function( + "upper", + alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", + alias( + "field_value2", qualifiedName("field_value2")))))), + emptyList(), + ImmutableList.of( + function( + "concat", + function("upper", qualifiedName("field_value1")), + function("upper", qualifiedName("field_value2"))))), + alias( + "concat(upper(field_value1), upper(field_value2))", + function( + "concat", + function( + "upper", alias("field_value1", qualifiedName("field_value1"))), + function( + "upper", + alias("field_value2", qualifiedName("field_value2")))))))); + assertEquals( + "Field [field_value1] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + 
exception.getMessage()); + } + + /** + * case 13: SELECT concat(field_value1, field_value2) FROM schema group by upper(field_value1), + * upper(field_value2); (invalid) + */ + @Test + public void field_not_in_group_by_case13() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "upper(field_value1)", + function( + "upper", + alias("field_value1", qualifiedName("field_value1")))), + alias( + "upper(field_value2)", + function( + "upper", + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")))), + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))))); + assertEquals( + "Field [field_value1] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** + * case 14: SELECT concat(field_value1, field_value2) FROM schema group by field_value1; (invalid) + */ + @Test + public void field_not_in_group_by_case14() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("field_value1", qualifiedName("field_value1"))), + emptyList(), + ImmutableList.of( + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")))), + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))))); + assertEquals( + "Field [field_value2] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + 
exception.getMessage()); + } + + /** + * case 15: SELECT concat(field_value1, field_value2) FROM schema group by field_value1, + * upper(field_value2); (invalid) + */ + @Test + public void field_not_in_group_by_case15() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias("field_value1", qualifiedName("field_value1")), + alias( + "upper(field_value2)", + function( + "upper", + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")))), + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))))); + assertEquals( + "Field [field_value2] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** + * case 16: SELECT concat(field_value1, field_value2), upper(field_value2) FROM schema group by + * field_value1, upper(field_value2); (invalid) + */ + @Test + public void field_not_in_group_by_case16() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias("field_value1", qualifiedName("field_value1")), + alias( + "upper(field_value2)", + function( + "upper", + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")), + function( + "upper", + alias("field_value2", qualifiedName("field_value2"))))), + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", 
qualifiedName("field_value2")))), + alias( + "upper(field_value2)", + function( + "upper", alias("field_value2", qualifiedName("field_value2"))))))); + assertEquals( + "Field [field_value2] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** + * case 17: SELECT concat(field_value1, field_value2), upper(field_value2) FROM schema group by + * field_value1, field_value2; (valid) + */ + @Test + public void field_not_in_group_by_case17() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))), + emptyList(), + ImmutableList.of( + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")), + function( + "upper", alias("field_value2", qualifiedName("field_value2"))))), + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2")))), + alias( + "upper(field_value2)", + function("upper", alias("field_value2", qualifiedName("field_value2"))))))); + } + + /** + * case 18: SELECT upper(concat(field_value1, field_value2)) FROM schema group by + * concat(field_value1, field_value2); (valid) + */ + @Test + public void field_not_in_group_by_case18() { + assertDoesNotThrow( + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "upper", + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2"))))), + alias( + "upper(concat(field_value1, field_value2))", + function( + 
"upper", + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))))))); + } + + /** + * case 19: SELECT concat(concat(field_value1, field_value2), string_value) FROM schema group by + * concat(field_value1, field_value2); (invalid) + */ + @Test + public void field_not_in_group_by_case19() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2"))))), + emptyList(), + ImmutableList.of( + function( + "concat", + function( + "concat", + qualifiedName("field_value1"), + qualifiedName("field_value2")), + qualifiedName("string_value")))), + alias( + "concat(concat(field_value1, field_value2), string_value)", + function( + "concat", + alias( + "concat(field_value1, field_value2)", + function( + "concat", + alias("field_value1", qualifiedName("field_value1")), + alias("field_value2", qualifiedName("field_value2")))), + alias("string_value", qualifiedName("string_value"))))))); + assertEquals( + "Field [string_value] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** case 20: SELECT 1, 2, 3 FROM schema group by string_value; (valid) */ + @Test + public void field_not_in_group_by_case20() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.relation("schema", table), + emptyList(), + ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), + ImmutableList.of( + DSL.named("1", DSL.literal(1)), + DSL.named("2", DSL.literal(2)), + DSL.named("3", DSL.literal(3))), + emptyList()), + project( + agg( + 
relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of(intLiteral(1), intLiteral(2), intLiteral(3))), + alias("1", intLiteral(1)), + alias("2", intLiteral(2)), + alias("3", intLiteral(3)))); + } + + /** case 21: SELECT 1, 2, field_value2 FROM schema group by field_value1; (invalid) */ + @Test + public void field_not_in_group_by_case21() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("field_value1", qualifiedName("field_value1"))), + emptyList(), + ImmutableList.of( + intLiteral(1), intLiteral(2), qualifiedName("field_value2"))), + alias("1", intLiteral(1)), + alias("2", intLiteral(2)), + alias("field_value2", qualifiedName("field_value2"))))); + assertEquals( + "Field [field_value2] must appear in the GROUP BY clause or be used in an aggregate" + + " function", + exception.getMessage()); + } + + /** SELECT integer_value FROM schema GROUP BY avg(integer_value) */ + @Test + public void aggregate_function_not_allowed_in_group_by_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + agg( + relation("schema"), + ImmutableList.of( + alias("integer_value", qualifiedName("integer_value"))), + emptyList(), + ImmutableList.of( + alias( + "AVG(integer_value)", + aggregate("AVG", qualifiedName("integer_value")))), + emptyList()), + alias("integer_value", qualifiedName("integer_value"))))); + assertEquals( + "Aggregate function is not allowed in a GROUP BY clause, but found [avg]", + exception.getMessage()); + } + + /** SELECT avg(integer_value) FROM schema GROUP BY 1 */ + @Test + public void aggregate_function_from_ordinal_not_allowed_in_group_by() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( 
+ project( + agg( + relation("schema"), + ImmutableList.of( + alias("integer_value", qualifiedName("integer_value"))), + emptyList(), + ImmutableList.of( + alias( + "AVG(integer_value)", + aggregate("AVG", qualifiedName("integer_value")))), + emptyList()), + alias("integer_value", qualifiedName("integer_value"))))); + assertEquals( + "Aggregate function is not allowed in a GROUP BY clause, but found [avg]", + exception.getMessage()); + } + + /** SELECT integer_value FROM schema WHERE integer_value */ + @Test + public void non_boolean_expression_in_filter_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + filter(relation("schema"), qualifiedName("integer_value")), + alias("integer_value", qualifiedName("integer_value"))))); + assertEquals( + "FILTER or HAVING expression must be type boolean, but found [INTEGER]", + exception.getMessage()); + } + + /** SELECT integer_value FROM schema WHERE nested(integer_value, true) */ + @Test + public void non_boolean_expression_in_filter_error_except_in_nested() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("schema", table), + DSL.nested(DSL.ref("message", STRING), DSL.literal(true))), + ImmutableList.of(DSL.named("integer_value", DSL.ref("integer_value", INTEGER))), + emptyList()), + project( + filter( + relation("schema"), + function( + "nested", alias("message", qualifiedName("message")), booleanLiteral(true))), + alias("integer_value", qualifiedName("integer_value")))); + } + + /** SELECT string_value FROM schema GROUP BY string_value HAVING avg(integer_value) */ + @Test + public void non_boolean_expression_in_having_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + having( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias("string_value", qualifiedName("string_value"))), + emptyList(), 
+ ImmutableList.of( + alias("string_value", qualifiedName("string_value")))), + ImmutableList.of( + alias( + "AVG(integer_value)", + aggregate("AVG", qualifiedName("integer_value")))), + aggregate("AVG", qualifiedName("integer_value"))), + alias("string_value", qualifiedName("string_value"))))); + assertEquals( + "FILTER or HAVING expression must be type boolean, but found [DOUBLE]", + exception.getMessage()); + } + + /** SELECT count(string_value) filter(where integer_value) FROM schema */ + @Test + public void non_boolean_expression_in_filtered_aggregation_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.project( + AstDSL.agg( + AstDSL.relation("schema"), + ImmutableList.of( + alias( + "count(string_value) filter(where integer_value)", + filteredAggregate( + "count", + qualifiedName("string_value"), + qualifiedName("integer_value")))), + emptyList(), + emptyList(), + emptyList()), + AstDSL.alias( + "count(string_value) filter(where integer_value)", + filteredAggregate( + "count", + qualifiedName("string_value"), + qualifiedName("integer_value")))))); + assertEquals( + "FILTER or HAVING expression must be type boolean, but found [INTEGER]", + exception.getMessage()); + } + + /** SELECT count(string_value) filter(where 10 > (avg(integer_value) + 3) * 2) FROM schema */ + @Test + public void aggregate_function_in_filter_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.project( + AstDSL.agg( + AstDSL.relation("schema"), + ImmutableList.of( + alias( + "count(string_value) filter(where integer_value)", + filteredAggregate( + "count", + qualifiedName("string_value"), + function( + ">", + intLiteral(10), + function( + "*", + function( + "+", + aggregate( + "AVG", qualifiedName("integer_value")), + intLiteral(3)), + intLiteral(2)))))), + emptyList(), + emptyList(), + emptyList()), + AstDSL.alias( + "count(string_value) 
filter(where 10 > (max(integer_value) + 3) * 2)", + filteredAggregate( + "count", + qualifiedName("string_value"), + function( + ">", + intLiteral(10), + function( + "*", + function( + "+", + aggregate("AVG", qualifiedName("integer_value")), + intLiteral(3)), + intLiteral(2)))))))); + assertEquals( + "Aggregate function is not allowed in a FILTER, but found [avg]", exception.getMessage()); + } + + /** SELECT string_value FROM schema where ROW_NUMBER() OVER(ORDER BY string_value) > 0 */ + @Test + public void window_function_in_where_error() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + project( + filter( + relation("schema"), + function( + ">", + window( + function("row_number"), + ImmutableList.of(), + ImmutableList.of( + ImmutablePair.of( + DEFAULT_ASC, qualifiedName("string_value")))), + intLiteral(0)))))); + assertEquals("Window functions are not allowed in WHERE or HAVING", exception.getMessage()); + } + + /** + * SELECT string_value FROM schema GROUP BY string_value HAVING MIN(integer_value) > 1 AND + * MAX(integer_value) < 10 + */ + @Test + public void aggregators_in_having_should_be_merged_to_logical_aggregation() { + LogicalPlan actual = + analyze( + project( + having( + agg( + relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value"))), + emptyList(), + ImmutableList.of(alias("string_value", qualifiedName("string_value")))), + ImmutableList.of( + alias( + "MIN(integer_value)", aggregate("MIN", qualifiedName("integer_value"))), + alias( + "MAX(integer_value)", + aggregate("MAX", qualifiedName("integer_value")))), + and( + compare( + ">", aggregate("MIN", qualifiedName("integer_value")), intLiteral(1)), + compare( + "<", + aggregate("MAX", qualifiedName("integer_value")), + intLiteral(10)))), + alias("string_value", qualifiedName("string_value")))); + assertInstanceOf(LogicalAggregation.class, actual.getChild().get(0)); + List 
mergedAggregators = + ((LogicalAggregation) actual.getChild().get(0)).getAggregatorList(); + List expected = + ImmutableList.of( + DSL.named("MIN(integer_value)", DSL.min(DSL.ref("integer_value", INTEGER))), + DSL.named("MAX(integer_value)", DSL.max(DSL.ref("integer_value", INTEGER)))); + assertEquals(expected, mergedAggregators); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index b27b8348e2..661cdc73a0 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -16,6 +16,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; @@ -31,6 +32,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; @@ -174,7 +176,12 @@ public void case_with_default_result_type_different() { @Test public void scalar_window_function() { assertAnalyzeEqual( - DSL.rank(), AstDSL.window(AstDSL.function("rank"), emptyList(), emptyList())); + DSL.rank(), + AstDSL.window( + AstDSL.function("rank"), + emptyList(), + Collections.singletonList( + ImmutablePair.of(DEFAULT_ASC, AstDSL.qualifiedName("integer_value"))))); } @SuppressWarnings("unchecked") diff --git 
a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java index acb11f0b57..cf7ff70a00 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java @@ -68,6 +68,7 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed() analysisContext)); } + // TODO row_number window function requires window to be ordered. @Test void should_not_generate_sort_operator_if_no_partition_by_and_order_by_list() { assertEquals( @@ -83,6 +84,29 @@ void should_not_generate_sort_operator_if_no_partition_by_and_order_by_list() { analysisContext)); } + @Test + void can_analyze_without_partition_by() { + assertEquals( + LogicalPlanDSL.window( + LogicalPlanDSL.sort( + LogicalPlanDSL.relation("test", table), + ImmutablePair.of(DEFAULT_DESC, DSL.ref("integer_value", INTEGER))), + DSL.named("row_number", DSL.rowNumber()), + new WindowDefinition( + ImmutableList.of(), + ImmutableList.of( + ImmutablePair.of(DEFAULT_DESC, DSL.ref("integer_value", INTEGER))))), + analyzer.analyze( + AstDSL.alias( + "row_number", + AstDSL.window( + AstDSL.function("row_number"), + ImmutableList.of(), + ImmutableList.of( + ImmutablePair.of(DEFAULT_DESC, AstDSL.qualifiedName("integer_value"))))), + analysisContext)); + } + @Test void should_return_original_child_if_project_item_not_windowed() { assertEquals( diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java index 5f2ba86c2f..d13f067b38 100644 --- a/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java @@ -61,6 +61,7 @@ import 
org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.Relation; @@ -504,4 +505,14 @@ public void visitFilter() { project(filter(tableFunction(List.of("1", "2")), booleanLiteral(true))) .accept(visitor, null))); } + + /** + * The {@code getChild()} in {@link Alias} now returns its delegated member. Testing for non-child + * {@link Node}. + */ + @Test + public void visitFetchCursor() { + var plan = new FetchCursor("test"); + assertTrue(visitor.canPaginate(plan, null)); + } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstAggregationBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstAggregationBuilder.java index e46147b7a3..1059f76103 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstAggregationBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstAggregationBuilder.java @@ -13,6 +13,7 @@ import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.QueryCompilationError; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; @@ -20,8 +21,6 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import org.opensearch.sql.common.utils.StringUtils; -import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; import org.opensearch.sql.sql.parser.context.QuerySpecification; @@ -76,8 +75,13 @@ public UnresolvedPlan visit(ParseTree groupByClause) { } private UnresolvedPlan buildExplicitAggregation() { + List 
aliasFreeSelectItems = querySpec.getSelectItems(); List groupByItems = replaceGroupByItemIfAliasOrOrdinal(); - return new Aggregation(new ArrayList<>(querySpec.getAggregators()), emptyList(), groupByItems); + return new Aggregation( + new ArrayList<>(querySpec.getAggregators()), + emptyList(), + groupByItems, + aliasFreeSelectItems); } private UnresolvedPlan buildImplicitAggregation() { @@ -85,15 +89,14 @@ private UnresolvedPlan buildImplicitAggregation() { if (invalidSelectItem.isPresent()) { // Report semantic error to avoid fall back to old engine again - throw new SemanticCheckException( - StringUtils.format( - "Explicit GROUP BY clause is required because expression [%s] " - + "contains non-aggregated column", - invalidSelectItem.get())); + throw QueryCompilationError.groupByClauseIsMissingError(invalidSelectItem.get()); } return new Aggregation( - new ArrayList<>(querySpec.getAggregators()), emptyList(), querySpec.getGroupByItems()); + new ArrayList<>(querySpec.getAggregators()), + emptyList(), + querySpec.getGroupByItems(), + emptyList()); } private List replaceGroupByItemIfAliasOrOrdinal() { diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java index ab96f16263..48889777ad 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java @@ -145,8 +145,16 @@ public UnresolvedPlan visitFromClause(FromClauseContext ctx) { if (ctx.havingClause() != null) { UnresolvedPlan havingPlan = visit(ctx.havingClause()); - verifySupportsCondition(((Filter) havingPlan).getCondition()); - result = visit(ctx.havingClause()).attach(result); + UnresolvedExpression condition = + verifySupportsCondition(((Filter) havingPlan).getCondition()); + if (aggregation != null) { + UnresolvedPlan newHavingPlan = + new AstHavingAggregationBuilder(context.peek(), condition).visit(ctx.havingClause()); + // This having 
clause with aggregation, attach to new having plan + result = newHavingPlan.attach(result); + } + // attach to filter + result = havingPlan.attach(result); } if (ctx.orderByClause() != null) { @@ -162,15 +170,16 @@ public UnresolvedPlan visitFromClause(FromClauseContext ctx) { * * @param func : Function in HAVING clause */ - private void verifySupportsCondition(UnresolvedExpression func) { + private UnresolvedExpression verifySupportsCondition(UnresolvedExpression func) { if (func instanceof Function) { if (((Function) func).getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { throw new SyntaxCheckException( "Falling back to legacy engine. Nested function is not supported in the HAVING" + " clause."); } - ((Function) func).getFuncArgs().stream().forEach(e -> verifySupportsCondition(e)); + ((Function) func).getFuncArgs().stream().forEach(this::verifySupportsCondition); } + return func; } @Override diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstHavingAggregationBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstHavingAggregationBuilder.java new file mode 100644 index 0000000000..88d2921172 --- /dev/null +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstHavingAggregationBuilder.java @@ -0,0 +1,32 @@ +package org.opensearch.sql.sql.parser; + +import java.util.ArrayList; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.Having; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; +import org.opensearch.sql.sql.parser.context.QuerySpecification; + +/** + * AST Having Aggregation builder that creates a {@link Having} clause with condition expressions + * and all aggregators in Having. 
Those aggregators will be pushed down to the underlying {@link + * LogicalAggregation}. + */ +@RequiredArgsConstructor +public class AstHavingAggregationBuilder extends OpenSearchSQLParserBaseVisitor { + private final QuerySpecification querySpec; + private final UnresolvedExpression condition; + + @Override + public UnresolvedPlan visitHavingClause(OpenSearchSQLParser.HavingClauseContext ctx) { + return new Having(createAggregators(), condition); + } + + private List createAggregators() { + return new ArrayList<>(querySpec.getAggregatorsInHaving()); + } +} diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/ParserUtils.java b/sql/src/main/java/org/opensearch/sql/sql/parser/ParserUtils.java index 3c60d43733..dd655938d9 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/ParserUtils.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/ParserUtils.java @@ -10,10 +10,14 @@ import static org.opensearch.sql.ast.tree.Sort.SortOrder; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OrderByElementContext; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; import lombok.experimental.UtilityClass; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.TerminalNode; +import org.opensearch.sql.ast.Node; /** Parser Utils Class. */ @UtilityClass @@ -50,4 +54,23 @@ public static NullOrder createNullOrder(TerminalNode first, TerminalNode last) { return null; } } + + /** Find all the nodes in a tree that match the given predicate. 
*/ + public static List findNodes(Node node, Predicate condition) { + List results = new ArrayList<>(); + findNodesHelper(node, condition, results); + return results; + } + + private static void findNodesHelper( + Node node, Predicate condition, List results) { + if (condition.test(node)) { + results.add((T) node); + } + if (node.getChild() != null) { + for (Node child : node.getChild()) { + findNodesHelper(child, condition, results); + } + } + } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java b/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java index 5625371f05..9188b87744 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java @@ -25,7 +25,9 @@ import lombok.Getter; import lombok.ToString; import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.QueryCompilationError; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.QualifiedName; @@ -33,11 +35,13 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AggregateFunctionCallContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.QuerySpecificationContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SelectSpecContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; import org.opensearch.sql.sql.parser.AstExpressionBuilder; +import org.opensearch.sql.sql.parser.ParserUtils; /** * Query specification domain that collects 
basic info for a simple query. @@ -67,11 +71,14 @@ public class QuerySpecification { private final Map selectItemsByAlias = new HashMap<>(); /** - * Aggregate function calls that spreads in SELECT, HAVING clause. Since this is going to be - * pushed to aggregation operator, de-duplicate is necessary to avoid duplication. + * Aggregate function calls that spread in the SELECT clause. Since this is going to be pushed to + * aggregation operator, de-duplicate is necessary to avoid duplication. */ private final Set aggregators = new LinkedHashSet<>(); + /** Aggregate function calls that spread in the HAVING clause. */ + private final Set aggregatorsInHaving = new LinkedHashSet<>(); + /** * * @@ -85,6 +92,8 @@ public class QuerySpecification { */ private final List groupByItems = new ArrayList<>(); + private final List distinctItems = new ArrayList<>(); + /** Items in ORDER BY clause that may be different forms as above and its options. */ private final List orderByItems = new ArrayList<>(); @@ -131,8 +140,7 @@ private boolean isIntegerLiteral(UnresolvedExpression expr) { private UnresolvedExpression getSelectItemByOrdinal(UnresolvedExpression expr) { int ordinal = (Integer) ((Literal) expr).getValue(); if (ordinal <= 0 || ordinal > selectItems.size()) { - throw new SemanticCheckException( - StringUtils.format("Ordinal [%d] is out of bound of select item list", ordinal)); + throw QueryCompilationError.ordinalRefersOutOfBounds(ordinal); } return selectItems.get(ordinal - 1); } @@ -197,6 +205,9 @@ public Void visitSelectClause(SelectClauseContext ctx) { @Override public Void visitSelectElement(SelectElementContext ctx) { UnresolvedExpression expr = visitAstExpression(ctx.expression()); + if (expr instanceof AggregateFunction && ((AggregateFunction) expr).getDistinct()) { + distinctItems.add(((AggregateFunction) expr).getField()); + } selectItems.add(expr); if (ctx.alias() != null) { @@ -232,6 +243,17 @@ public Void 
visitFilteredAggregationFunctionCall(FilteredAggregationFunctionCall return super.visitFilteredAggregationFunctionCall(ctx); } + @Override + public Void visitHavingClause(OpenSearchSQLParser.HavingClauseContext ctx) { + UnresolvedExpression expression = visitAstExpression(ctx); + List aggregateFunctions = + ParserUtils.findNodes(expression, n -> n instanceof AggregateFunction); + for (AggregateFunction aggregateFunction : aggregateFunctions) { + aggregatorsInHaving.add(AstDSL.alias(aggregateFunction.toString(), aggregateFunction)); + } + return super.visitHavingClause(ctx); + } + private boolean isDistinct(SelectSpecContext ctx) { return (ctx != null) && (ctx.DISTINCT() != null); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 95188e20b6..0b58ab39d3 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -146,6 +146,17 @@ void can_build_distinct_aggregator() { alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName("name")))))); } + @Disabled + void distinct_is_an_equivalent_grouping_by() { + assertThat( + buildAggregation("SELECT COUNT(DISTINCT name) FROM test GROUP BY age"), + allOf( + hasGroupByItems( + alias("age", qualifiedName("age")), alias("name", qualifiedName("name"))), + hasAggregators( + alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName("name")))))); + } + @Test void should_build_nothing_if_no_group_by_and_no_aggregators_in_select() { assertNull(buildAggregation("SELECT name FROM test")); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 8ab314f695..18f84b1e34 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ 
b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -17,6 +17,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.having; import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.limit; @@ -176,7 +177,8 @@ public void can_build_group_by_field_name() { ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), emptyList(), ImmutableList.of(alias("name", qualifiedName("name"))), - emptyList()), + emptyList(), + ImmutableList.of(qualifiedName("name"), aggregate("AVG", qualifiedName("age")))), alias("name", qualifiedName("name")), alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), buildAST("SELECT name, AVG(age) FROM test GROUP BY name")); @@ -191,7 +193,10 @@ public void can_build_group_by_function() { ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), emptyList(), ImmutableList.of(alias("abs(name)", function("abs", qualifiedName("name")))), - emptyList()), + emptyList(), + ImmutableList.of( + function("abs", qualifiedName("name")), + aggregate("AVG", qualifiedName("age")))), alias("abs(name)", function("abs", qualifiedName("name"))), alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), buildAST("SELECT abs(name), AVG(age) FROM test GROUP BY abs(name)")); @@ -206,7 +211,10 @@ public void can_build_group_by_uppercase_function() { ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), emptyList(), ImmutableList.of(alias("ABS(name)", function("ABS", qualifiedName("name")))), - emptyList()), + emptyList(), + ImmutableList.of( + function("ABS", qualifiedName("name")), + aggregate("AVG", qualifiedName("age")))), alias("ABS(name)", function("ABS", qualifiedName("name"))), alias("AVG(age)", 
aggregate("AVG", qualifiedName("age")))), buildAST("SELECT ABS(name), AVG(age) FROM test GROUP BY 1")); @@ -221,7 +229,10 @@ public void can_build_group_by_alias() { ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), emptyList(), ImmutableList.of(alias("abs(name)", function("abs", qualifiedName("name")))), - emptyList()), + emptyList(), + ImmutableList.of( + function("abs", qualifiedName("name")), + aggregate("AVG", qualifiedName("age")))), alias("abs(name)", function("abs", qualifiedName("name")), "n"), alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), buildAST("SELECT abs(name) as n, AVG(age) FROM test GROUP BY n")); @@ -236,7 +247,10 @@ public void can_build_group_by_ordinal() { ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), emptyList(), ImmutableList.of(alias("abs(name)", function("abs", qualifiedName("name")))), - emptyList()), + emptyList(), + ImmutableList.of( + function("abs", qualifiedName("name")), + aggregate("AVG", qualifiedName("age")))), alias("abs(name)", function("abs", qualifiedName("name")), "n"), alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), buildAST("SELECT abs(name) as n, AVG(age) FROM test GROUP BY 1")); @@ -261,14 +275,20 @@ public void can_build_having_clause() { assertEquals( project( filter( - agg( - relation("test"), + having( + agg( + relation("test"), + ImmutableList.of( + alias("AVG(age)", aggregate("AVG", qualifiedName("age"))), + alias("MIN(balance)", aggregate("MIN", qualifiedName("balance")))), + emptyList(), + ImmutableList.of(alias("name", qualifiedName("name"))), + emptyList(), + ImmutableList.of( + qualifiedName("name"), aggregate("AVG", qualifiedName("age")))), ImmutableList.of( - alias("AVG(age)", aggregate("AVG", qualifiedName("age"))), alias("MIN(balance)", aggregate("MIN", qualifiedName("balance")))), - emptyList(), - ImmutableList.of(alias("name", qualifiedName("name"))), - emptyList()), + function(">", aggregate("MIN", 
qualifiedName("balance")), intLiteral(1000))), function(">", aggregate("MIN", qualifiedName("balance")), intLiteral(1000))), alias("name", qualifiedName("name")), alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), @@ -280,18 +300,32 @@ public void can_build_having_condition_using_alias() { assertEquals( project( filter( - agg( - relation("test"), - ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), - emptyList(), - ImmutableList.of(alias("name", qualifiedName("name"))), - emptyList()), + having( + agg( + relation("test"), + ImmutableList.of(alias("AVG(age)", aggregate("AVG", qualifiedName("age")))), + emptyList(), + ImmutableList.of(alias("name", qualifiedName("name"))), + emptyList(), + ImmutableList.of( + qualifiedName("name"), aggregate("AVG", qualifiedName("age")))), + ImmutableList.of(), + function(">", aggregate("AVG", qualifiedName("age")), intLiteral(1000))), function(">", aggregate("AVG", qualifiedName("age")), intLiteral(1000))), alias("name", qualifiedName("name")), alias("AVG(age)", aggregate("AVG", qualifiedName("age")), "a")), buildAST("SELECT name, AVG(age) AS a FROM test GROUP BY name HAVING a > 1000")); } + @Test + public void can_build_having_without_group_by_clause_equals_where_clause() { + assertEquals( + project( + filter(relation("test"), function("=", qualifiedName("name"), stringLiteral("John"))), + alias("name", qualifiedName("name"))), + buildAST("SELECT name FROM test HAVING name = 'John'")); + } + @Test public void can_build_order_by_field_name() { assertEquals( @@ -354,7 +388,8 @@ public void can_build_select_distinct_clause() { emptyList(), ImmutableList.of( alias("name", qualifiedName("name")), alias("age", qualifiedName("age"))), - emptyList()), + emptyList(), + ImmutableList.of(qualifiedName("name"), qualifiedName("age"))), alias("name", qualifiedName("name")), alias("age", qualifiedName("age"))), buildAST("SELECT DISTINCT name, age FROM test")); @@ -373,7 +408,9 @@ public void 
can_build_select_distinct_clause_with_function() { "SUBSTRING(name, 1, 2)", function( "SUBSTRING", qualifiedName("name"), intLiteral(1), intLiteral(2)))), - emptyList()), + emptyList(), + ImmutableList.of( + function("SUBSTRING", qualifiedName("name"), intLiteral(1), intLiteral(2)))), alias( "SUBSTRING(name, 1, 2)", function("SUBSTRING", qualifiedName("name"), intLiteral(1), intLiteral(2)))), diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/ParserUtilsTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/ParserUtilsTest.java new file mode 100644 index 0000000000..a2d4c26025 --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/ParserUtilsTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.ast.dsl.AstDSL.agg; +import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; +import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.having; +import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.project; +import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; +import static org.opensearch.sql.ast.dsl.AstDSL.relation; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.Aggregation; +import 
org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.ast.tree.Limit; + +public class ParserUtilsTest { + @Test + public void testFindNodes() { + Node root = + project( + filter( + having( + agg( + relation("test"), + ImmutableList.of( + alias("AVG(age)", aggregate("AVG", qualifiedName("age"))), + alias("MIN(balance)", aggregate("MIN", qualifiedName("balance")))), + emptyList(), + ImmutableList.of(alias("name", qualifiedName("name"))), + emptyList(), + ImmutableList.of( + qualifiedName("name"), aggregate("AVG", qualifiedName("age")))), + ImmutableList.of( + alias("MIN(balance)", aggregate("MIN", qualifiedName("balance")))), + function(">", aggregate("MIN", qualifiedName("balance")), intLiteral(1000))), + function(">", aggregate("MIN", qualifiedName("balance")), intLiteral(1000))), + alias("name", qualifiedName("name")), + alias("AVG(age)", aggregate("AVG", qualifiedName("age")))); + // test finding an UnresolvedPlan + Node aggNode = ParserUtils.findNodes(root, n -> n instanceof Aggregation).getFirst(); + assertInstanceOf(Aggregation.class, aggNode); + // test finding an UnresolvedExpression + UnresolvedExpression expr = ((Aggregation) aggNode).getAggExprList().getFirst(); + Node aliasNode = ParserUtils.findNodes(expr, n -> n instanceof Alias).getFirst(); + assertInstanceOf(Alias.class, aliasNode); + assertEquals("AVG(age)", ((Alias) aliasNode).getName()); + // test finding a nonexistent node + List nodes = ParserUtils.findNodes(root, n -> n instanceof Limit); + assertTrue(nodes.isEmpty()); + } + + @Test + public void testFindNodesOnNodeWithoutChild() { + Node root = project(new FetchCursor("test"), alias("name", qualifiedName("name"))); + Node child = ParserUtils.findNodes(root, n -> n instanceof FetchCursor).getFirst(); + assertInstanceOf(FetchCursor.class, child); + // FetchCursor.getChild() returns null + assertTrue(ParserUtils.findNodes(child, n -> n instanceof Limit).isEmpty()); + } +}