diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 35abe12879..455f2ca444 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.between; import static org.opensearch.sql.ast.dsl.AstDSL.not; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; -import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; @@ -41,7 +40,6 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MultiFieldRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NoFieldRelevanceFunctionContext; -import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScoreRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NotExpressionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NullLiteralContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OverClauseContext; @@ -53,6 +51,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RelevanceFieldAndWeightContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarWindowFunctionContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScoreRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ShowDescribePatternContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SignedDecimalContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SignedRealContext; @@ -100,7 +99,6 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.function.BuiltinFunctionName; -import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchQueryContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; @@ -108,7 +106,6 @@ import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IntervalLiteralContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NestedExpressionAtomContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OrExpressionContext; -import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RelevanceFunctionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.TableNameContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; @@ -178,11 +175,11 @@ public UnresolvedExpression visitHighlightFunctionCall( @Override public UnresolvedExpression visitPositionFunction( - PositionFunctionContext ctx) { + PositionFunctionContext ctx) { return new Function( - POSITION.getName().getFunctionName(), - Arrays.asList(visitFunctionArg(ctx.functionArg(0)), - visitFunctionArg(ctx.functionArg(1)))); + POSITION.getName().getFunctionName(), + Arrays.asList(visitFunctionArg(ctx.functionArg(0)), + visitFunctionArg(ctx.functionArg(1)))); } @Override @@ -219,20 +216,20 @@ public UnresolvedExpression visitWindowFunctionClause(WindowFunctionClauseContex List partitionByList = Collections.emptyList(); if (overClause.partitionByClause() != null) { partitionByList = overClause.partitionByClause() - .expression() - .stream() - .map(this::visit) - .collect(Collectors.toList()); + .expression() + .stream() + .map(this::visit) + .collect(Collectors.toList()); } List> sortList = Collections.emptyList(); if (overClause.orderByClause() != null) { sortList = overClause.orderByClause() - .orderByElement() - .stream() - .map(item -> ImmutablePair.of( - createSortOption(item), visit(item.expression()))) - .collect(Collectors.toList()); + .orderByElement() + .stream() + .map(item -> ImmutablePair.of( + createSortOption(item), visit(item.expression()))) + .collect(Collectors.toList()); } return new WindowFunction(visit(ctx.function), partitionByList, sortList); } @@ -301,7 +298,7 @@ public UnresolvedExpression visitLikePredicate(LikePredicateContext ctx) { @Override public UnresolvedExpression visitRegexpPredicate(RegexpPredicateContext ctx) { return new Function(REGEXP.getName().getFunctionName(), - Arrays.asList(visit(ctx.left), visit(ctx.right))); + Arrays.asList(visit(ctx.left), visit(ctx.right))); } @Override @@ -402,9 +399,9 @@ public UnresolvedExpression visitBinaryComparisonPredicate( public UnresolvedExpression visitCaseFunctionCall(CaseFunctionCallContext ctx) { UnresolvedExpression caseValue = (ctx.expression() == null) ? null : visit(ctx.expression()); List whenStatements = ctx.caseFuncAlternative() - .stream() - .map(when -> (When) visit(when)) - .collect(Collectors.toList()); + .stream() + .map(when -> (When) visit(when)) + .collect(Collectors.toList()); UnresolvedExpression elseStatement = (ctx.elseArg == null) ? null : visit(ctx.elseArg); return new Case(caseValue, whenStatements, elseStatement); @@ -429,10 +426,10 @@ public UnresolvedExpression visitConvertedDataType( @Override public UnresolvedExpression visitNoFieldRelevanceFunction( - NoFieldRelevanceFunctionContext ctx) { + NoFieldRelevanceFunctionContext ctx) { return new Function( - ctx.noFieldRelevanceFunctionName().getText().toLowerCase(), - noFieldRelevanceArguments(ctx)); + ctx.noFieldRelevanceFunctionName().getText().toLowerCase(), + noFieldRelevanceArguments(ctx)); } @Override @@ -460,7 +457,7 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( if ((funcName.equalsIgnoreCase(BuiltinFunctionName.MULTI_MATCH.toString()) || funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCH.toString()) || funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCHQUERY.toString())) - && ! ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) + && !ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) .isEmpty()) { return new Function( ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), @@ -480,13 +477,16 @@ public UnresolvedExpression visitAltMultiFieldRelevanceFunction( altMultiFieldRelevanceFunctionArguments(ctx)); } + /** + * Visit score-relevance function and collect children. + * + * @param ctx the parse tree + * @return children + */ public UnresolvedExpression visitScoreRelevanceFunction(ScoreRelevanceFunctionContext ctx) { - RelevanceFunctionContext relevanceFunction = ctx.relevanceFunction(); - List functionArgs = ctx.functionArg(); - return new ScoreFunction( - visit(ctx.relevanceFunction()), - ctx.functionArg().stream().map(this::visitFunctionArg).collect(Collectors.toList()) + visit(ctx.relevanceFunction()), + ctx.functionArg().stream().map(this::visitFunctionArg).collect(Collectors.toList()) ); } @@ -502,13 +502,14 @@ private Function buildFunction(String functionName, } private QualifiedName visitIdentifiers(List identifiers) { - Boolean isMetadataField = identifiers.stream().filter(id -> id.metadataField() != null).findFirst().isPresent(); + Boolean isMetadataField = + identifiers.stream().anyMatch(id -> id.metadataField() != null); return new QualifiedName( - identifiers.stream() - .map(RuleContext::getText) - .map(StringUtils::unquoteIdentifier) - .collect(Collectors.toList()), - isMetadataField); + identifiers.stream() + .map(RuleContext::getText) + .map(StringUtils::unquoteIdentifier) + .collect(Collectors.toList()), + isMetadataField); } private void fillRelevanceArgs(List args, @@ -523,18 +524,18 @@ private void fillRelevanceArgs(List args, } private List noFieldRelevanceArguments( - NoFieldRelevanceFunctionContext ctx) { + NoFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); builder.add(new UnresolvedArgument("query", - new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING))); + new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING))); fillRelevanceArgs(ctx.relevanceArg(), builder); return builder.build(); } private List singleFieldRelevanceArguments( - SingleFieldRelevanceFunctionContext ctx) { + SingleFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); @@ -590,6 +591,7 @@ private List getFormatFunctionArguments( /** * Adds support for multi_match alternate syntax like * MULTI_MATCH('query'='Dale', 'fields'='*name'). + * * @param ctx : Context for multi field relevance function. * @return : Returns list of all arguments for relevance function. */ @@ -602,7 +604,7 @@ private List alternateMultiMatchArguments( String[] fieldAndWeights = StringUtils.unquoteText( ctx.getRuleContexts(AlternateMultiMatchFieldContext.class) - .stream().findFirst().get().argVal.getText()).split(","); + .stream().findFirst().get().argVal.getText()).split(","); for (var fieldAndWeight : fieldAndWeights) { String[] splitFieldAndWeights = fieldAndWeight.split("\\^"); @@ -614,9 +616,9 @@ private List alternateMultiMatchArguments( ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) .stream().findFirst().ifPresent( - arg -> - builder.add(new UnresolvedArgument("query", - new Literal(StringUtils.unquoteText(arg.argVal.getText()), DataType.STRING))) + arg -> + builder.add(new UnresolvedArgument("query", + new Literal(StringUtils.unquoteText(arg.argVal.getText()), DataType.STRING))) ); fillRelevanceArgs(ctx.relevanceArg(), builder);