From c539f07ac873e028894236cbd2feff898720f417 Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Mon, 10 Apr 2023 12:41:55 -0700 Subject: [PATCH] #639: Support OpenSearch metadata fields and the score OpenSearch function (#228) (#1456) Allow metadata fields and score OpenSearch function. Signed-off-by: Andrew Carbonetto (cherry picked from commit e80515172586a099b066c1c050b2a6c3395dda5c) --- .../org/opensearch/sql/analysis/Analyzer.java | 7 + .../sql/analysis/ExpressionAnalyzer.java | 95 +++++++++- .../ExpressionReferenceOptimizer.java | 14 +- .../sql/analysis/TypeEnvironment.java | 25 ++- .../sql/ast/AbstractNodeVisitor.java | 5 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 7 +- .../sql/ast/expression/ScoreFunction.java | 36 ++++ .../org/opensearch/sql/expression/DSL.java | 14 +- .../function/BuiltinFunctionName.java | 6 + .../function/OpenSearchFunctions.java | 21 ++- .../org/opensearch/sql/storage/Table.java | 7 + .../opensearch/sql/analysis/AnalyzerTest.java | 164 ++++++++++++++++++ .../sql/analysis/AnalyzerTestBase.java | 4 + .../sql/analysis/ExpressionAnalyzerTest.java | 164 ++++++++++++++++++ docs/user/dql/basics.rst | 40 +++++ docs/user/dql/functions.rst | 42 +++++ .../sql/legacy/CsvFormatResponseIT.java | 8 +- .../opensearch/sql/legacy/MethodQueryIT.java | 6 +- .../sql/legacy/PrettyFormatResponseIT.java | 5 +- .../org/opensearch/sql/sql/IdentifierIT.java | 65 +++++++ .../java/org/opensearch/sql/sql/MatchIT.java | 14 ++ .../org/opensearch/sql/sql/ScoreQueryIT.java | 142 +++++++++++++++ .../request/OpenSearchQueryRequest.java | 12 +- .../request/OpenSearchRequestBuilder.java | 5 + .../request/OpenSearchScrollRequest.java | 10 +- .../response/OpenSearchResponse.java | 64 ++++++- .../opensearch/storage/OpenSearchIndex.java | 25 ++- .../storage/OpenSearchIndexScan.java | 11 +- .../scan/OpenSearchIndexScanQueryBuilder.java | 19 +- .../storage/script/sort/SortQueryBuilder.java | 3 + .../client/OpenSearchNodeClientTest.java | 2 +- .../client/OpenSearchRestClientTest.java | 2 +- .../request/OpenSearchQueryRequestTest.java | 61 +++++++ .../request/OpenSearchRequestBuilderTest.java | 4 +- .../request/OpenSearchScrollRequestTest.java | 87 ++++++++++ .../response/OpenSearchResponseTest.java | 145 ++++++++++++++-- .../storage/OpenSearchIndexTest.java | 15 ++ .../OpenSearchIndexScanOptimizationTest.java | 142 +++++++++++++++ sql/src/main/antlr/OpenSearchSQLLexer.g4 | 8 +- sql/src/main/antlr/OpenSearchSQLParser.g4 | 8 + .../sql/sql/parser/AstExpressionBuilder.java | 34 +++- .../sql/parser/AstExpressionBuilderTest.java | 51 ++++++ 42 files changed, 1537 insertions(+), 62 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/ast/expression/ScoreFunction.java create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index ba40020782..b7c03db6d4 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -62,6 +62,7 @@ import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.data.model.ExprMissingValue; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -150,6 +151,9 @@ public LogicalPlan visitRelation(Relation node, AnalysisContext context) { dataSourceSchemaIdentifierNameResolver.getIdentifierName()); } table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); + table.getReservedFieldTypes().forEach( + (k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v) + ); // Put index name or its alias in index namespace on type environment so qualifier // can be removed when analyzing qualified name. The value (expr type) here doesn't matter. @@ -193,6 +197,9 @@ public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext contex TypeEnvironment curEnv = context.peek(); Table table = tableFunctionImplementation.applyArguments(); table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); + table.getReservedFieldTypes().forEach( + (k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v) + ); curEnv.define(new Symbol(Namespace.INDEX_NAME, dataSourceSchemaIdentifierNameResolver.getIdentifierName()), STRUCT); return new LogicalRelation(dataSourceSchemaIdentifierNameResolver.getIdentifierName(), diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index ff3c01d5b8..436c26374c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -8,8 +8,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.compare; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.GTE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTE; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -31,6 +29,7 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -42,6 +41,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.RelevanceFieldList; +import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedAttribute; @@ -51,6 +51,7 @@ import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -67,6 +68,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.expression.span.SpanExpression; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; @@ -207,6 +209,65 @@ public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext return new HighlightExpression(expr); } + /** + * visitScoreFunction removes the score function from the AST and replaces it with the child + * relevance function node. If the optional boost variable is provided, the boost argument + * of the relevance function is combined. + * + * @param node score function node + * @param context analysis context for the query + * @return resolved relevance function + */ + public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) { + Literal boostArg = node.getRelevanceFieldWeight(); + if (!boostArg.getType().equals(DataType.DOUBLE)) { + throw new SemanticCheckException(String.format("Expected boost type '%s' but got '%s'", + DataType.DOUBLE.name(), boostArg.getType().name())); + } + Double thisBoostValue = ((Double) boostArg.getValue()); + + // update the existing unresolved expression to add a boost argument if it doesn't exist + // OR multiply the existing boost argument + Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery(); + List relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs(); + + boolean doesFunctionContainBoostArgument = false; + List updatedFuncArgs = new ArrayList<>(); + for (UnresolvedExpression expr : relevanceFuncArgs) { + String argumentName = ((UnresolvedArgument) expr).getArgName(); + if (argumentName.equalsIgnoreCase("boost")) { + doesFunctionContainBoostArgument = true; + Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue(); + Double boostValue = + Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue; + UnresolvedArgument newBoostArg = new UnresolvedArgument( + argumentName, + new Literal(boostValue.toString(), DataType.STRING) + ); + updatedFuncArgs.add(newBoostArg); + } else { + updatedFuncArgs.add(expr); + } + } + + // since nothing was found, add an argument + if (!doesFunctionContainBoostArgument) { + UnresolvedArgument newBoostArg = new UnresolvedArgument( + "boost", new Literal(Double.toString(thisBoostValue), DataType.STRING)); + updatedFuncArgs.add(newBoostArg); + } + + // create a new function expression with boost argument and resolve it + Function updatedRelevanceQueryUnresolvedExpr = new Function( + relevanceQueryUnresolvedExpr.getFuncName(), + updatedFuncArgs); + OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr = + (OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr + .accept(this, context); + relevanceQueryExpr.setScoreTracked(true); + return relevanceQueryExpr; + } + @Override public Expression visitIn(In node, AnalysisContext context) { return visitIn(node.getField(), node.getValueList(), context); @@ -297,6 +358,23 @@ public Expression visitAllFields(AllFields node, AnalysisContext context) { @Override public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) { QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); + + // check for reserved words in the identifier + for (String part : node.getParts()) { + for (TypeEnvironment typeEnv = context.peek(); + typeEnv != null; + typeEnv = typeEnv.getParent()) { + Optional exprType = typeEnv.getReservedSymbolTable().lookup( + new Symbol(Namespace.FIELD_NAME, part)); + if (exprType.isPresent()) { + return visitMetadata( + qualifierAnalyzer.unqualified(node), + (ExprCoreType) exprType.get(), + context + ); + } + } + } return visitIdentifier(qualifierAnalyzer.unqualified(node), context); } @@ -313,6 +391,19 @@ public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisConte return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context)); } + /** + * If QualifiedName is actually a reserved metadata field, return the expr type associated + * with the metadata field. + * @param ident metadata field name + * @param context analysis context + * @return DSL reference + */ + private Expression visitMetadata(String ident, + ExprCoreType exprCoreType, + AnalysisContext context) { + return DSL.ref(ident, exprCoreType); + } + private Expression visitIdentifier(String ident, AnalysisContext context) { // ParseExpression will always override ReferenceExpression when ident conflicts for (NamedExpression expr : context.getNamedParseExpressions()) { diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index f75bcd5a1d..eaf5c4abca 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -19,6 +19,7 @@ import org.opensearch.sql.expression.conditional.cases.CaseClause; import org.opensearch.sql.expression.conditional.cases.WhenClause; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; @@ -70,8 +71,17 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context final List args = node.getArguments().stream().map(expr -> expr.accept(this, context)) .collect(Collectors.toList()); - return (Expression) repository.compile(context.getFunctionProperties(), - node.getFunctionName(), args); + Expression optimizedFunctionExpression = (Expression) repository.compile( + context.getFunctionProperties(), + node.getFunctionName(), + args + ); + // Propagate scoreTracked for OpenSearch functions + if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) { + ((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked( + ((OpenSearchFunctions.OpenSearchFunction)node).isScoreTracked()); + } + return optimizedFunctionExpression; } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java index c86d8109ad..c9fd8030e0 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java +++ b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java @@ -29,14 +29,30 @@ public class TypeEnvironment implements Environment { private final TypeEnvironment parent; private final SymbolTable symbolTable; + @Getter + private final SymbolTable reservedSymbolTable; + + /** + * Constructor with empty symbol tables. + * + * @param parent parent environment + */ public TypeEnvironment(TypeEnvironment parent) { this.parent = parent; this.symbolTable = new SymbolTable(); + this.reservedSymbolTable = new SymbolTable(); } + /** + * Constructor with empty reserved symbol table. + * + * @param parent parent environment + * @param symbolTable type table + */ public TypeEnvironment(TypeEnvironment parent, SymbolTable symbolTable) { this.parent = parent; this.symbolTable = symbolTable; + this.reservedSymbolTable = new SymbolTable(); } /** @@ -59,6 +75,7 @@ public ExprType resolve(Symbol symbol) { /** * Resolve all fields in the current environment. + * * @param namespace a namespace * @return all symbols in the namespace */ @@ -102,7 +119,11 @@ public void remove(ReferenceExpression ref) { * Clear all fields in the current environment. */ public void clearAllFields() { - lookupAllFields(FIELD_NAME).keySet().stream() - .forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v))); + lookupAllFields(FIELD_NAME).keySet().forEach( + v -> remove(new Symbol(Namespace.FIELD_NAME, v))); + } + + public void addReservedWord(Symbol symbol, ExprType type) { + reservedSymbolTable.store(symbol, type); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 393de05164..d2ebb9eb99 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -29,6 +29,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.RelevanceFieldList; +import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedAttribute; @@ -278,6 +279,10 @@ public T visitHighlightFunction(HighlightFunction node, C context) { return visitChildren(node, context); } + public T visitScoreFunction(ScoreFunction node, C context) { + return visitChildren(node, context); + } + public T visitStatement(Statement node, C context) { return visit(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 039b6380f7..de2ab5404a 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -34,6 +34,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; @@ -60,7 +61,6 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; -import org.opensearch.sql.expression.function.BuiltinFunctionName; /** * Class of static methods to create specific node instances. @@ -285,6 +285,11 @@ public UnresolvedExpression highlight(UnresolvedExpression fieldName, return new HighlightFunction(fieldName, arguments); } + public UnresolvedExpression score(UnresolvedExpression relevanceQuery, + Literal relevanceFieldWeight) { + return new ScoreFunction(relevanceQuery, relevanceFieldWeight); + } + public UnresolvedExpression window(UnresolvedExpression function, List partitionByList, List> sortList) { diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/ScoreFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/ScoreFunction.java new file mode 100644 index 0000000000..1b73f9bd95 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/ScoreFunction.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of Score function. + * Score takes a relevance-search expression as an argument and returns it + */ +@AllArgsConstructor +@EqualsAndHashCode(callSuper = false) +@Getter +@ToString +public class ScoreFunction extends UnresolvedExpression { + private final UnresolvedExpression relevanceQuery; + private final Literal relevanceFieldWeight; + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitScoreFunction(this, context); + } + + @Override + public List getChild() { + return List.of(relevanceQuery); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 2f791f46b6..f60a06f1d9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -862,7 +862,19 @@ public static FunctionExpression match_bool_prefix(Expression... args) { } public static FunctionExpression wildcard_query(Expression... args) { - return compile(FunctionProperties.None,BuiltinFunctionName.WILDCARD_QUERY, args); + return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args); + } + + public static FunctionExpression score(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args); + } + + public static FunctionExpression scorequery(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args); + } + + public static FunctionExpression score_query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args); } public static FunctionExpression now(FunctionProperties functionProperties, diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 359355a898..5aaae1356f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -122,6 +122,7 @@ public enum BuiltinFunctionName { WEEK_OF_YEAR(FunctionName.of("week_of_year")), YEAR(FunctionName.of("year")), YEARWEEK(FunctionName.of("yearweek")), + // `now`-like functions NOW(FunctionName.of("now")), CURDATE(FunctionName.of("curdate")), @@ -132,6 +133,7 @@ public enum BuiltinFunctionName { CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")), LOCALTIMESTAMP(FunctionName.of("localtimestamp")), SYSDATE(FunctionName.of("sysdate")), + /** * Text Functions. */ @@ -255,6 +257,10 @@ public enum BuiltinFunctionName { MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")), HIGHLIGHT(FunctionName.of("highlight")), MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")), + SCORE(FunctionName.of("score")), + SCOREQUERY(FunctionName.of("scorequery")), + SCORE_QUERY(FunctionName.of("score_query")), + /** * Legacy Relevance Function. */ diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index 842cf25cd6..9a50aca344 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -5,11 +5,14 @@ package org.opensearch.sql.expression.function; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; + import java.util.List; import java.util.stream.Collectors; +import lombok.Getter; +import lombok.Setter; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; @@ -32,6 +35,7 @@ public void register(BuiltinFunctionRepository repository) { repository.register(simple_query_string()); repository.register(query()); repository.register(query_string()); + // Register MATCHPHRASE as MATCH_PHRASE as well for backwards // compatibility. repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); @@ -40,6 +44,9 @@ public void register(BuiltinFunctionRepository repository) { repository.register(match_phrase_prefix()); repository.register(wildcard_query(BuiltinFunctionName.WILDCARD_QUERY)); repository.register(wildcard_query(BuiltinFunctionName.WILDCARDQUERY)); + repository.register(score(BuiltinFunctionName.SCORE)); + repository.register(score(BuiltinFunctionName.SCOREQUERY)); + repository.register(score(BuiltinFunctionName.SCORE_QUERY)); } private static FunctionResolver match_bool_prefix() { @@ -86,10 +93,19 @@ private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery return new RelevanceFunctionResolver(funcName); } + private static FunctionResolver score(BuiltinFunctionName score) { + FunctionName funcName = score.getName(); + return new RelevanceFunctionResolver(funcName); + } + public static class OpenSearchFunction extends FunctionExpression { private final FunctionName functionName; private final List arguments; + @Getter + @Setter + private boolean isScoreTracked; + /** * Required argument constructor. * @param functionName name of the function @@ -99,6 +115,7 @@ public OpenSearchFunction(FunctionName functionName, List arguments) super(functionName, arguments); this.functionName = functionName; this.arguments = arguments; + this.isScoreTracked = false; } @Override @@ -110,7 +127,7 @@ public ExprValue valueOf(Environment valueEnv) { @Override public ExprType type() { - return ExprCoreType.BOOLEAN; + return BOOLEAN; } @Override diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java b/core/src/main/java/org/opensearch/sql/storage/Table.java index 496281fa8d..e2586ed22c 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -43,6 +43,13 @@ default void create(Map schema) { */ Map getFieldTypes(); + /** + * Get the {@link ExprType} for each meta-field (reserved fields) in the table. + */ + default Map getReservedFieldTypes() { + return Map.of(); + } + /** * Implement a {@link LogicalPlan} by {@link PhysicalPlan} in storage engine. * diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 1db29a6a42..f711c2362d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -31,6 +31,7 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -71,17 +72,22 @@ import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; +import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.HighlightExpression; +import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; @@ -102,6 +108,54 @@ public void filter_relation() { AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); } + @Test + public void filter_relation_with_reserved_qualifiedName() { + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("schema", table), + DSL.equal(DSL.ref("_test", STRING), DSL.literal(stringValue("value")))), + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.equalTo(AstDSL.qualifiedName("_test"), AstDSL.stringLiteral("value")))); + } + + @Test + public void filter_relation_with_invalid_qualifiedName_SemanticCheckException() { + UnresolvedPlan invalidFieldPlan = AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.equalTo( + AstDSL.qualifiedName("_invalid"), + AstDSL.stringLiteral("value")) + ); + + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> analyze(invalidFieldPlan)); + assertEquals( + "can't resolve Symbol(namespace=FIELD_NAME, name=_invalid) in type env", + exception.getMessage()); + } + + @Test + public void filter_relation_with_invalid_qualifiedName_ExpressionEvaluationException() { + UnresolvedPlan typeMismatchPlan = AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.equalTo(AstDSL.qualifiedName("_test"), AstDSL.intLiteral(1)) + ); + + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> analyze(typeMismatchPlan)); + assertEquals( + "= function expected {[BYTE,BYTE],[SHORT,SHORT],[INTEGER,INTEGER],[LONG,LONG]," + + "[FLOAT,FLOAT],[DOUBLE,DOUBLE],[STRING,STRING],[BOOLEAN,BOOLEAN],[DATE,DATE]," + + "[TIME,TIME],[DATETIME,DATETIME],[TIMESTAMP,TIMESTAMP],[INTERVAL,INTERVAL]," + + "[STRUCT,STRUCT],[ARRAY,ARRAY]}, but get [STRING,INTEGER]", + exception.getMessage()); + } + @Test public void filter_relation_with_alias() { assertAnalyzeEqual( @@ -214,6 +268,116 @@ public void filter_relation_with_multiple_tables() { AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); } + @Test + public void analyze_filter_visit_score_function() { + UnresolvedPlan unresolvedPlan = AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function("match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3")) + ), AstDSL.doubleLiteral(1.0)) + ); + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("schema", table), + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("boost", "3.0") + ) + ), + unresolvedPlan + ); + + LogicalPlan logicalPlan = analyze(unresolvedPlan); + OpenSearchFunctions.OpenSearchFunction relevanceQuery = + (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + assertEquals(true, relevanceQuery.isScoreTracked()); + } + + @Test + public void analyze_filter_visit_without_score_function() { + UnresolvedPlan unresolvedPlan = AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.function("match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3")) + ) + ); + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("schema", table), + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("boost", "3") + ) + ), + unresolvedPlan + ); + + LogicalPlan logicalPlan = analyze(unresolvedPlan); + OpenSearchFunctions.OpenSearchFunction relevanceQuery = + (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + assertEquals(false, relevanceQuery.isScoreTracked()); + } + + @Test + public void analyze_filter_visit_score_function_with_double_boost() { + UnresolvedPlan unresolvedPlan = AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function("match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("slop", stringLiteral("3")) + ), new Literal(3.0, DataType.DOUBLE) + ) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("schema", table), + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3"), + DSL.namedArgument("boost", "3.0") + ) + ), + unresolvedPlan + ); + + LogicalPlan logicalPlan = analyze(unresolvedPlan); + OpenSearchFunctions.OpenSearchFunction relevanceQuery = + (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + assertEquals(true, relevanceQuery.isScoreTracked()); + } + + @Test + public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() { + UnresolvedPlan unresolvedPlan = AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function("match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3")) + ), AstDSL.stringLiteral("3.0") + ) + ); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> analyze(unresolvedPlan)); + assertEquals( + "Expected boost type 'DOUBLE' but got 'STRING'", + exception.getMessage()); + } + @Test public void head_relation() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 2ec411ba54..d7222d466f 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -112,6 +112,10 @@ public Map getFieldTypes() { public PhysicalPlan implement(LogicalPlan plan) { throw new UnsupportedOperationException(); } + + public Map getReservedFieldTypes() { + return ImmutableMap.of("_test", STRING); + } }); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index c7a11658e3..c7cd8d0556 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -20,7 +20,9 @@ import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.FLOAT; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import static org.opensearch.sql.expression.DSL.ref; @@ -228,6 +230,32 @@ public void qualified_name_with_qualifier() { analysisContext.pop(); } + @Test + public void qualified_name_with_reserved_symbol() { + analysisContext.push(); + + analysisContext.peek().addReservedWord(new Symbol(Namespace.FIELD_NAME, "_reserved"), STRING); + analysisContext.peek().addReservedWord(new Symbol(Namespace.FIELD_NAME, "_priority"), FLOAT); + analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); + assertAnalyzeEqual( + DSL.ref("_priority", FLOAT), + qualifiedName("_priority") + ); + assertAnalyzeEqual( + DSL.ref("_reserved", STRING), + qualifiedName("index_alias", "_reserved") + ); + + // reserved fields take priority over symbol table + analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "_reserved"), LONG); + assertAnalyzeEqual( + DSL.ref("_reserved", STRING), + qualifiedName("index_alias", "_reserved") + ); + + analysisContext.pop(); + } + @Test public void interval() { assertAnalyzeEqual( @@ -600,6 +628,142 @@ public void match_phrase_prefix_all_params() { ); } + @Test void score_function_expression() { + assertAnalyzeEqual( + DSL.score( + DSL.namedArgument("RelevanceQuery", + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3") + ) + )), + AstDSL.function("score", + unresolvedArg("RelevanceQuery", + AstDSL.function("match_phrase_prefix", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")) + ) + ) + ) + ); + } + + @Test void score_function_with_boost() { + assertAnalyzeEqual( + DSL.score( + DSL.namedArgument("RelevanceQuery", + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("boost", "3.0") + )), + DSL.namedArgument("boost", "2") + ), + AstDSL.function("score", + unresolvedArg("RelevanceQuery", + AstDSL.function("match_phrase_prefix", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("boost", stringLiteral("3.0")) + ) + ), + unresolvedArg("boost", stringLiteral("2")) + ) + ); + } + + @Test void score_query_function_expression() { + assertAnalyzeEqual( + DSL.score_query( + DSL.namedArgument("RelevanceQuery", + DSL.wildcard_query( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query") + ) + )), + AstDSL.function("score_query", + unresolvedArg("RelevanceQuery", + AstDSL.function("wildcard_query", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")) + ) + ) + ) + ); + } + + @Test void score_query_function_with_boost() { + assertAnalyzeEqual( + DSL.score_query( + DSL.namedArgument("RelevanceQuery", + DSL.wildcard_query( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query") + ) + ), + DSL.namedArgument("boost", "2.0") + ), + AstDSL.function("score_query", + unresolvedArg("RelevanceQuery", + AstDSL.function("wildcard_query", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")) + ) + ), + unresolvedArg("boost", stringLiteral("2.0")) + ) + ); + } + + @Test void scorequery_function_expression() { + assertAnalyzeEqual( + DSL.scorequery( + DSL.namedArgument("RelevanceQuery", + DSL.simple_query_string( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3") + ) + )), + AstDSL.function("scorequery", + unresolvedArg("RelevanceQuery", + AstDSL.function("simple_query_string", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")) + ) + ) + ) + ); + } + + @Test + void scorequery_function_with_boost() { + assertAnalyzeEqual( + DSL.scorequery( + DSL.namedArgument("RelevanceQuery", + DSL.simple_query_string( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3") + )), + DSL.namedArgument("boost", "2.0") + ), + AstDSL.function("scorequery", + unresolvedArg("RelevanceQuery", + AstDSL.function("simple_query_string", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")) + ) + ), + unresolvedArg("boost", stringLiteral("2.0")) + ) + ); + } + @Test public void function_isnt_calculated_on_analyze() { assertTrue(analyze(function("now")) instanceof FunctionExpression); diff --git a/docs/user/dql/basics.rst b/docs/user/dql/basics.rst index 9762f23988..b7e8cf35a4 100644 --- a/docs/user/dql/basics.rst +++ b/docs/user/dql/basics.rst @@ -155,6 +155,46 @@ Result set: | Nanette| Bates| +---------+--------+ +One can also provide meta-field name(s) to retrieve reserved-fields (beginning with underscore) from OpenSearch documents. Meta-fields are not output +from wildcard calls (`SELECT *`) and must be explicitly included to be returned. + +SQL query:: + + POST /_plugins/_sql + { + "query" : "SELECT firstname, lastname, _id, _index, _sort FROM accounts" + } + +Explain:: + + { + "from" : 0, + "size" : 200, + "_source" : { + "includes" : [ + "firstname", + "_id", + "_index", + "_sort", + "lastname" + ], + "excludes" : [ ] + } + } + + +This produces results like this for example:: + + os> SELECT firstname, lastname, _index, _sort FROM accounts; + fetched rows / total rows = 4/4 + +-------------+------------+----------+---------+ + | firstname | lastname | _index | _sort | + |-------------+------------+----------+---------| + | Amber | Duke | accounts | -2 | + | Hattie | Bond | accounts | -2 | + | Nanette | Bates | accounts | -2 | + | Dale | Adams | accounts | -2 | + +-------------+------------+----------+---------+ Example 3: Using Field Alias ---------------------------- diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 8ba2397f98..68d975a318 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -4209,6 +4209,48 @@ Another example to show how to set custom values for the optional parameters:: | 1 | The House at Pooh Corner | Alan Alexander Milne | +------+--------------------------+----------------------+ +SCORE +------------ + +Description +>>>>>>>>>>> + +``score(relevance_expression[, boost])`` +``score_query(relevance_expression[, boost])`` +``scorequery(relevance_expression[, boost])`` + +The `SCORE()` function calculates the `_score` of any documents matching the enclosed relevance-based expression. The `SCORE()` +function expects one argument with an optional second argument. The first argument is the relevance-based search expression. +The second argument is an optional floating-point boost to the score (the default value is 1.0). + +The `SCORE()` function sets `track_scores=true` for OpenSearch requests. Without it, `_score` fields may return `null` for some +relevance-based search expressions. + +Please refer to examples below: + +| ``score(query('Tags:taste OR Body:taste', ...), 2.0)`` + +The `score_query` and `scorequery` functions are alternative names for the `score` function. + +Example boosting score:: + + os> select *, _score from books where score(query('title:Pooh House', default_operator='AND'), 2.0); + fetched rows / total rows = 1/1 + +------+--------------------------+----------------------+-----------+ + | id | title | author | _score | + |------+--------------------------+----------------------+-----------| + | 1 | The House at Pooh Corner | Alan Alexander Milne | 1.5884793 | + +------+--------------------------+----------------------+-----------+ + + os> select *, _score from books where score(query('title:Pooh House', default_operator='AND'), 5.0) OR score(query('title:Winnie', default_operator='AND'), 1.5); + fetched rows / total rows = 2/2 + +------+--------------------------+----------------------+-----------+ + | id | title | author | _score | + |------+--------------------------+----------------------+-----------| + | 1 | The House at Pooh Corner | Alan Alexander Milne | 3.9711983 | + | 2 | Winnie-the-Pooh | Alan Alexander Milne | 1.1581701 | + +------+--------------------------+----------------------+-----------+ + HIGHLIGHT ------------ diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java index 9a08302577..52dcf9a068 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java @@ -99,7 +99,6 @@ public void specificPercentilesIntAndDouble() throws IOException { } } - @Ignore("only work for legacy engine") public void nestedObjectsAndArraysAreQuoted() throws IOException { final String query = String.format(Locale.ROOT, "SELECT * FROM %s WHERE _id = 5", TEST_INDEX_NESTED_TYPE); @@ -114,7 +113,6 @@ public void nestedObjectsAndArraysAreQuoted() throws IOException { Assert.assertThat(result, containsString(expectedMessage)); } - @Ignore("only work for legacy engine") public void arraysAreQuotedInFlatMode() throws IOException { setFlatOption(true); @@ -492,7 +490,7 @@ public void percentileAggregationTest() throws Exception { @Test public void includeScore() throws Exception { String query = String.format(Locale.ROOT, - "select age , firstname from %s where age > 31 order by _score desc limit 2 ", + "select age, firstname, _score from %s where age > 31 order by _score desc limit 2 ", TEST_INDEX_ACCOUNT); CSVResult csvResult = executeCsvRequest(query, false, true, false); List headers = csvResult.getHeaders(); @@ -546,10 +544,10 @@ public void twoCharsSeperator() throws Exception { } - @Ignore("only work for legacy engine") + @Ignore("tested in @see: org.opensearch.sql.sql.IdentifierIT.testMetafieldIdentifierTest") public void includeIdAndNotTypeOrScore() throws Exception { String query = String.format(Locale.ROOT, - "select age , firstname from %s where lastname = 'Marquez' ", TEST_INDEX_ACCOUNT); + "select age, firstname, _id from %s where lastname = 'Marquez' ", TEST_INDEX_ACCOUNT); CSVResult csvResult = executeCsvRequest(query, false, false, true); List headers = csvResult.getHeaders(); Assert.assertEquals(3, headers.size()); diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/MethodQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/MethodQueryIT.java index 680058d844..027228a92b 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/MethodQueryIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/MethodQueryIT.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.util.Locale; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; /** @@ -111,13 +112,13 @@ public void matchQueryTest() throws IOException { * * @throws IOException */ - // todo @Test + @Ignore("score query no longer maps to constant_score in the V2 engine - @see org.opensearch.sql.sql.ScoreQueryIT") public void scoreQueryTest() throws IOException { final String result = explainQuery(String.format(Locale.ROOT, "select address from %s " + "where score(matchQuery(address, 'Lane'),100) " + - "or score(matchQuery(address,'Street'),0.5) order by _score desc limit 3", + "or score(matchQuery(address,'Street'),0.5) order by _score desc limit 3", TestsConstants.TEST_INDEX_ACCOUNT)); Assert.assertThat(result, both(containsString("{\"constant_score\":" + @@ -176,6 +177,7 @@ public void wildcardQueryTest() throws IOException { * @throws IOException */ @Test + @Ignore("score query no longer handled by legacy engine - @see org.opensearch.sql.sql.ScoreQueryIT") public void matchPhraseQueryTest() throws IOException { final String result = explainQuery(String.format(Locale.ROOT, "select address from %s " + diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java index 226645ce85..1e2073acbd 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java @@ -126,10 +126,11 @@ public void selectWrongField() throws IOException { } @Test + @Ignore("_score tested in V2 engine - @see org.opensearch.sql.sql.ScoreQueryIT") public void selectScore() throws IOException { JSONObject response = executeQuery( - String.format(Locale.ROOT, "SELECT _score FROM %s WHERE balance > 30000", - TestsConstants.TEST_INDEX_ACCOUNT)); + String.format(Locale.ROOT, "SELECT _score FROM %s WHERE SCORE(match_phrase(phrase, 'brown fox'))", + TestsConstants.TEST_INDEX_PHRASE)); List fields = Collections.singletonList("_score"); assertContainsColumns(getSchema(response), fields); diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java index 591364ea19..d5c194968d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java @@ -64,6 +64,65 @@ public void testMultipleQueriesWithSpecialIndexNames() throws IOException { queryAndAssertTheDoc("SELECT * FROM test.two"); } + @Test + public void testDoubleUnderscoreIdentifierTest() throws IOException { + new Index("test.twounderscores") + .addDoc("{\"__age\": 30}"); + final JSONObject result = new JSONObject(executeQuery("SELECT __age FROM test.twounderscores", "jdbc")); + + verifySchema(result, + schema("__age", null, "long")); + verifyDataRows(result, rows(30)); + } + + @Test + public void testMetafieldIdentifierTest() throws IOException { + // create an index, but the contents doesn't matter + String id = "12345"; + String index = "test.metafields"; + new Index(index).addDoc("{\"age\": 30}", id); + + // Execute using field metadata values + final JSONObject result = new JSONObject(executeQuery( + "SELECT *, _id, _index, _score, _maxscore, _sort " + + "FROM " + index, + "jdbc")); + + // Verify that the metadata values are returned when requested + verifySchema(result, + schema("age", null, "long"), + schema("_id", null, "keyword"), + schema("_index", null, "keyword"), + schema("_score", null, "float"), + schema("_maxscore", null, "float"), + schema("_sort", null, "long")); + verifyDataRows(result, rows(30, id, index, 1.0, 1.0, -2)); + } + + @Test + public void testMetafieldIdentifierWithAliasTest() throws IOException { + // create an index, but the contents doesn't matter + String id = "99999"; + String index = "test.aliasmetafields"; + new Index(index).addDoc("{\"age\": 30}", id); + + // Execute using field metadata values + final JSONObject result = new JSONObject(executeQuery( + "SELECT _id AS A, _index AS B, _score AS C, _maxscore AS D, _sort AS E " + + "FROM " + index + " " + + "WHERE _id = \\\"" + id + "\\\"", + "jdbc")); + + // Verify that the metadata values are returned when requested + verifySchema(result, + schema("_id", "A", "keyword"), + schema("_index", "B", "keyword"), + schema("_score", "C", "float"), + schema("_maxscore", "D", "float"), + schema("_sort", "E", "long")); + verifyDataRows(result, rows(id, index, null, null, -2)); + } + private void createIndexWithOneDoc(String... indexNames) throws IOException { for (String indexName : indexNames) { new Index(indexName).addDoc("{\"age\": 30}"); @@ -98,6 +157,12 @@ void addDoc(String doc) { indexDoc.setJsonEntity(doc); performRequest(client(), indexDoc); } + + void addDoc(String doc, String id) { + Request indexDoc = new Request("POST", String.format("/%s/_doc/%s?refresh=true", indexName, id)); + indexDoc.setJsonEntity(doc); + performRequest(client(), indexDoc); + } } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java index 28573fdd10..9885ddfa33 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java @@ -5,6 +5,7 @@ package org.opensearch.sql.sql; +import static org.hamcrest.Matchers.containsString; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; @@ -12,9 +13,12 @@ import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; +import java.util.Locale; import org.json.JSONObject; +import org.junit.Assert; import org.junit.Test; import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; import org.opensearch.sql.legacy.utils.StringUtils; public class MatchIT extends SQLIntegTestCase { @@ -147,4 +151,14 @@ public void match_alternate_syntaxes_return_the_same_results() throws IOExceptio assertEquals(result1.getInt("total"), result2.getInt("total")); assertEquals(result1.getInt("total"), result3.getInt("total")); } + + @Test + public void matchPhraseQueryTest() throws IOException { + final String result = explainQuery(String.format(Locale.ROOT, + "select address from %s " + + "where address= matchPhrase('671 Bristol Street') order by _score desc limit 3", + TestsConstants.TEST_INDEX_ACCOUNT)); + Assert.assertThat(result, + containsString("{\\\"match_phrase\\\":{\\\"address\\\":{\\\"query\\\":\\\"671 Bristol Street\\\"")); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java new file mode 100644 index 0000000000..03df7d0e29 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java @@ -0,0 +1,142 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import org.json.JSONObject; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +import java.io.IOException; +import java.util.Locale; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +public class ScoreQueryIT extends SQLIntegTestCase { + @Override + protected void init() throws Exception { + loadIndex(Index.ACCOUNT); + } + + /** + * "query" : { + * "from": 0, + * "size": 3, + * "timeout": "1m", + * "query": { + * "bool": { + * "should": [ + * { + * "match": { + * "address": { + * "query": "Lane", + * "operator": "OR", + * "prefix_length": 0, + * "max_expansions": 50, + * "fuzzy_transpositions": true, + * "lenient": false, + * "zero_terms_query": "NONE", + * "auto_generate_synonyms_phrase_query": true, + * "boost": 100.0 + * } + * } + * }, + * { + * "match": { + * "address": { + * "query": "Street", + * "operator": "OR", + * "prefix_length": 0, + * "max_expansions": 50, + * "fuzzy_transpositions": true, + * "lenient": false, + * "zero_terms_query": "NONE", + * "auto_generate_synonyms_phrase_query": true, + * "boost": 0.5 + * } + * } + * } + * ], + * "adjust_pure_negative": true, + * "boost": 1.0 + * } + * }, + * "_source": { + * "includes": [ + * "address" + * ], + * "excludes": [] + * }, + * "sort": [ + * { + * "_score": { + * "order": "desc" + * } + * } + * ], + * "track_scores": true + * } + * @throws IOException + */ + @Test + public void scoreQueryExplainTest() throws IOException { + final String result = explainQuery(String.format(Locale.ROOT, + "select address from %s " + + "where score(matchQuery(address, 'Douglass'), 100) " + + "or score(matchQuery(address, 'Hall'), 0.5) order by _score desc limit 2", + TestsConstants.TEST_INDEX_ACCOUNT)); + Assert.assertThat(result, containsString("\\\"match\\\":{\\\"address\\\":{\\\"query\\\":\\\"Douglass\\\"")); + Assert.assertThat(result, containsString("\\\"boost\\\":100.0")); + Assert.assertThat(result, containsString("\\\"match\\\":{\\\"address\\\":{\\\"query\\\":\\\"Hall\\\"")); + Assert.assertThat(result, containsString("\\\"boost\\\":0.5")); + Assert.assertThat(result, containsString("\\\"sort\\\":[{\\\"_score\\\"")); + Assert.assertThat(result, containsString("\\\"track_scores\\\":true")); + } + + @Test + public void scoreQueryTest() throws IOException { + final JSONObject result = new JSONObject(executeQuery(String.format(Locale.ROOT, + "select address, _score from %s " + + "where score(matchQuery(address, 'Douglass'), 100) " + + "or score(matchQuery(address, 'Hall'), 0.5) order by _score desc limit 2", + TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); + verifySchema(result, + schema("address", null, "text"), + schema("_score", null, "float")); + verifyDataRows(result, + rows("154 Douglass Street", 650.1515), + rows("565 Hall Street", 3.2507575)); + } + + @Test + public void scoreQueryDefaultBoostExplainTest() throws IOException { + final String result = explainQuery(String.format(Locale.ROOT, + "select address from %s " + + "where score(matchQuery(address, 'Lane')) order by _score desc limit 2", + TestsConstants.TEST_INDEX_ACCOUNT)); + Assert.assertThat(result, containsString("\\\"match\\\":{\\\"address\\\":{\\\"query\\\":\\\"Lane\\\"")); + Assert.assertThat(result, containsString("\\\"boost\\\":1.0")); + Assert.assertThat(result, containsString("\\\"sort\\\":[{\\\"_score\\\"")); + Assert.assertThat(result, containsString("\\\"track_scores\\\":true")); + } + + @Test + public void scoreQueryDefaultBoostQueryTest() throws IOException { + final JSONObject result = new JSONObject(executeQuery(String.format(Locale.ROOT, + "select address, _score from %s " + + "where score(matchQuery(address, 'Powell')) order by _score desc limit 2", + TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); + verifySchema(result, + schema("address", null, "text"), + schema("_score", null, "float")); + verifyDataRows(result, rows("305 Powell Street", 6.501515)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 6f6fea841b..3976f854fd 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -7,6 +7,8 @@ package org.opensearch.sql.opensearch.request; import com.google.common.annotations.VisibleForTesting; +import java.util.Arrays; +import java.util.List; import java.util.function.Consumer; import java.util.function.Function; import lombok.EqualsAndHashCode; @@ -18,6 +20,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -94,11 +97,16 @@ public OpenSearchQueryRequest(IndexName indexName, SearchSourceBuilder sourceBui @Override public OpenSearchResponse search(Function searchAction, Function scrollAction) { + FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); + List includes = fetchSource != null && fetchSource.includes() != null + ? Arrays.asList(fetchSource.includes()) + : List.of(); if (searchDone) { - return new OpenSearchResponse(SearchHits.empty(), exprValueFactory); + return new OpenSearchResponse(SearchHits.empty(), exprValueFactory, includes); } else { searchDone = true; - return new OpenSearchResponse(searchAction.apply(searchRequest()), exprValueFactory); + return new OpenSearchResponse( + searchAction.apply(searchRequest()), exprValueFactory, includes); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 97aeee3747..95f9fa39b0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -100,6 +100,7 @@ public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, sourceBuilder.from(0); sourceBuilder.size(querySize); sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT); + sourceBuilder.trackScores(false); } /** @@ -180,6 +181,10 @@ public void pushDownLimit(Integer limit, Integer offset) { sourceBuilder.from(offset).size(limit); } + public void pushDownTrackedScore(boolean trackScores) { + sourceBuilder.trackScores(trackScores); + } + /** * Add highlight to DSL requests. * @param field name of the field to highlight diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java index 4509e443c0..dacbecc7b9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java @@ -6,6 +6,8 @@ package org.opensearch.sql.opensearch.request; +import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; @@ -19,6 +21,7 @@ import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -87,8 +90,11 @@ public OpenSearchResponse search(Function searchA openSearchResponse = searchAction.apply(searchRequest()); } setScrollId(openSearchResponse.getScrollId()); - - return new OpenSearchResponse(openSearchResponse, exprValueFactory); + FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); + List includes = fetchSource != null && fetchSource.includes() != null + ? Arrays.asList(this.sourceBuilder.fetchSource().includes()) + : List.of(); + return new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index aadd73efdd..85a6b503f6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -6,16 +6,29 @@ package org.opensearch.sql.opensearch.response; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATAFIELD_TYPE_MAP; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_ID; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_INDEX; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_MAXSCORE; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_SCORE; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_SORT; + import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import lombok.ToString; import org.opensearch.action.search.SearchResponse; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -38,6 +51,11 @@ public class OpenSearchResponse implements Iterable { */ private final Aggregations aggregations; + /** + * List of requested include fields. + */ + private final List includes; + /** * ElasticsearchExprValueFactory used to build ExprValue from search result. */ @@ -48,19 +66,24 @@ public class OpenSearchResponse implements Iterable { * Constructor of ElasticsearchResponse. */ public OpenSearchResponse(SearchResponse searchResponse, - OpenSearchExprValueFactory exprValueFactory) { + OpenSearchExprValueFactory exprValueFactory, + List includes) { this.hits = searchResponse.getHits(); this.aggregations = searchResponse.getAggregations(); this.exprValueFactory = exprValueFactory; + this.includes = includes; } /** * Constructor of ElasticsearchResponse with SearchHits. */ - public OpenSearchResponse(SearchHits hits, OpenSearchExprValueFactory exprValueFactory) { + public OpenSearchResponse(SearchHits hits, + OpenSearchExprValueFactory exprValueFactory, + List includes) { this.hits = hits; this.aggregations = null; this.exprValueFactory = exprValueFactory; + this.includes = includes; } /** @@ -92,14 +115,37 @@ public Iterator iterator() { return (ExprValue) ExprTupleValue.fromExprValueMap(builder.build()); }).iterator(); } else { + List metaDataFieldSet = includes.stream() + .filter(include -> METADATAFIELD_TYPE_MAP.containsKey(include)) + .collect(Collectors.toList()); + ExprFloatValue maxScore = Float.isNaN(hits.getMaxScore()) + ? null : new ExprFloatValue(hits.getMaxScore()); return Arrays.stream(hits.getHits()) .map(hit -> { - ExprValue docData = exprValueFactory.construct(hit.getSourceAsString()); - if (hit.getHighlightFields().isEmpty()) { - return docData; - } else { - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - builder.putAll(docData.tupleValue()); + String source = hit.getSourceAsString(); + ExprValue docData = exprValueFactory.construct(source); + + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + builder.putAll(docData.tupleValue()); + metaDataFieldSet.forEach(metaDataField -> { + if (metaDataField.equals(METADATA_FIELD_INDEX)) { + builder.put(METADATA_FIELD_INDEX, new ExprStringValue(hit.getIndex())); + } else if (metaDataField.equals(METADATA_FIELD_ID)) { + builder.put(METADATA_FIELD_ID, new ExprStringValue(hit.getId())); + } else if (metaDataField.equals(METADATA_FIELD_SCORE)) { + if (!Float.isNaN(hit.getScore())) { + builder.put(METADATA_FIELD_SCORE, new ExprFloatValue(hit.getScore())); + } + } else if (metaDataField.equals(METADATA_FIELD_MAXSCORE)) { + if (maxScore != null) { + builder.put(METADATA_FIELD_MAXSCORE, maxScore); + } + } else { // if (metaDataField.equals(METADATA_FIELD_SORT)) { + builder.put(METADATA_FIELD_SORT, new ExprLongValue(hit.getSeqNo())); + } + }); + + if (!hit.getHighlightFields().isEmpty()) { var hlBuilder = ImmutableMap.builder(); for (var es : hit.getHighlightFields().entrySet()) { hlBuilder.put(es.getKey(), ExprValueUtils.collectionValue( @@ -107,8 +153,8 @@ public Iterator iterator() { t -> (t.toString())).collect(Collectors.toList()))); } builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); - return ExprTupleValue.fromExprValueMap(builder.build()); } + return (ExprValue) ExprTupleValue.fromExprValueMap(builder.build()); }).iterator(); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 9ed8adf3ee..cf09b32de9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -12,6 +12,7 @@ import java.util.Map; import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; @@ -34,6 +35,20 @@ /** OpenSearch table (index) implementation. */ public class OpenSearchIndex implements Table { + public static final String METADATA_FIELD_ID = "_id"; + public static final String METADATA_FIELD_INDEX = "_index"; + public static final String METADATA_FIELD_SCORE = "_score"; + public static final String METADATA_FIELD_MAXSCORE = "_maxscore"; + public static final String METADATA_FIELD_SORT = "_sort"; + + public static final java.util.Map METADATAFIELD_TYPE_MAP = Map.of( + METADATA_FIELD_ID, ExprCoreType.STRING, + METADATA_FIELD_INDEX, ExprCoreType.STRING, + METADATA_FIELD_SCORE, ExprCoreType.FLOAT, + METADATA_FIELD_MAXSCORE, ExprCoreType.FLOAT, + METADATA_FIELD_SORT, ExprCoreType.LONG + ); + /** OpenSearch client connection. */ private final OpenSearchClient client; @@ -111,6 +126,11 @@ public Map getFieldTypes() { return cachedFieldTypes; } + @Override + public Map getReservedFieldTypes() { + return METADATAFIELD_TYPE_MAP; + } + /** * Get parsed mapping info. * @return A complete map between field names and their types. @@ -151,8 +171,11 @@ public LogicalPlan optimize(LogicalPlan plan) { @Override public TableScanBuilder createScanBuilder() { + Map allFields = new HashMap<>(); + getReservedFieldTypes().forEach((k, v) -> allFields.put(k, OpenSearchDataType.of(v))); + allFields.putAll(getFieldOpenSearchTypes()); OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, - getMaxResultWindow(), new OpenSearchExprValueFactory(getFieldOpenSearchTypes())); + getMaxResultWindow(), new OpenSearchExprValueFactory(allFields)); return new OpenSearchIndexScanBuilder(indexScan); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java index e9746e1fae..a26e64a809 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java @@ -58,8 +58,13 @@ public class OpenSearchIndexScan extends TableScanOperator { public OpenSearchIndexScan(OpenSearchClient client, Settings settings, String indexName, Integer maxResultWindow, OpenSearchExprValueFactory exprValueFactory) { - this(client, settings, - new OpenSearchRequest.IndexName(indexName),maxResultWindow, exprValueFactory); + this( + client, + settings, + new OpenSearchRequest.IndexName(indexName), + maxResultWindow, + exprValueFactory + ); } /** @@ -70,7 +75,7 @@ public OpenSearchIndexScan(OpenSearchClient client, Settings settings, OpenSearchExprValueFactory exprValueFactory) { this.client = client; this.requestBuilder = new OpenSearchRequestBuilder( - indexName, maxResultWindow, settings,exprValueFactory); + indexName, maxResultWindow, settings, exprValueFactory); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java index 7190d58000..d5a0c72f20 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -18,8 +18,10 @@ import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; @@ -61,8 +63,11 @@ public TableScanOperator build() { public boolean pushDownFilter(LogicalFilter filter) { FilterQueryBuilder queryBuilder = new FilterQueryBuilder( new DefaultExpressionSerializer()); - QueryBuilder query = queryBuilder.build(filter.getCondition()); + Expression queryCondition = filter.getCondition(); + QueryBuilder query = queryBuilder.build(queryCondition); indexScan.getRequestBuilder().pushDown(query); + indexScan.getRequestBuilder().pushDownTrackedScore( + trackScoresFromOpenSearchFunction(queryCondition)); return true; } @@ -99,6 +104,18 @@ public boolean pushDownHighlight(LogicalHighlight highlight) { return true; } + private boolean trackScoresFromOpenSearchFunction(Expression condition) { + if (condition instanceof OpenSearchFunctions.OpenSearchFunction + && ((OpenSearchFunctions.OpenSearchFunction) condition).isScoreTracked()) { + return true; + } + if (condition instanceof FunctionExpression) { + return ((FunctionExpression) condition).getArguments().stream() + .anyMatch(this::trackScoresFromOpenSearchFunction); + } + return false; + } + /** * Find reference expression from expression. * @param expressions a list of expression. diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java index ab8f086dff..1415fc22c6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilder.java @@ -49,6 +49,9 @@ public class SortQueryBuilder { */ public SortBuilder build(Expression expression, Sort.SortOption option) { if (expression instanceof ReferenceExpression) { + if (((ReferenceExpression) expression).getAttr().equalsIgnoreCase("_score")) { + return SortBuilders.scoreSort().order(sortOrderMap.get(option.getSortOrder())); + } return fieldBuild((ReferenceExpression) expression, option); } else { throw new IllegalStateException("unsupported expression " + expression.getClass()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 1c79a28f3f..aa603157a8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -324,7 +324,7 @@ void search() { Iterator hits = response1.iterator(); assertTrue(hits.hasNext()); - assertEquals(exprTupleValue, hits.next()); + assertEquals(exprTupleValue.tupleValue().get("id"), hits.next().tupleValue().get("id")); assertFalse(hits.hasNext()); // Verify response for second scroll request diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index f2da6fd1e0..a86399ed32 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -305,7 +305,7 @@ void search() throws IOException { Iterator hits = response1.iterator(); assertTrue(hits.hasNext()); - assertEquals(exprTupleValue, hits.next()); + assertEquals(exprTupleValue.tupleValue().get("id"), hits.next().tupleValue().get("id")); assertFalse(hits.hasNext()); // Verify response for second scroll request diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index 1ba26e33dc..be83622578 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.Iterator; import java.util.function.Consumer; import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -28,6 +29,8 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -52,6 +55,12 @@ public class OpenSearchQueryRequestTest { @Mock private SearchHit searchHit; + @Mock + private SearchSourceBuilder sourceBuilder; + + @Mock + private FetchSourceContext fetchSourceContext; + @Mock private OpenSearchExprValueFactory factory; @@ -60,17 +69,69 @@ public class OpenSearchQueryRequestTest { @Test void search() { + OpenSearchQueryRequest request = new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); + when(fetchSourceContext.includes()).thenReturn(null); when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(fetchSourceContext, times(1)).includes(); assertFalse(searchResponse.isEmpty()); searchResponse = request.search(searchAction, scrollAction); assertTrue(searchResponse.isEmpty()); verify(searchAction, times(1)).apply(any()); } + @Test + void search_withoutContext() { + OpenSearchQueryRequest request = new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + when(sourceBuilder.fetchSource()).thenReturn(null); + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(sourceBuilder, times(1)).fetchSource(); + assertFalse(searchResponse.isEmpty()); + } + + @Test + void search_withIncludes() { + OpenSearchQueryRequest request = new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + String[] includes = {"_id", "_index"}; + when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); + when(fetchSourceContext.includes()).thenReturn(includes); + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(fetchSourceContext, times(2)).includes(); + assertFalse(searchResponse.isEmpty()); + + searchResponse = request.search(searchAction, scrollAction); + assertTrue(searchResponse.isEmpty()); + + verify(searchAction, times(1)).apply(any()); + } + @Test void clean() { request.clean(cleanAction); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 980d68ed80..85e259a400 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -72,6 +72,7 @@ void buildQueryRequest() { Integer limit = 200; Integer offset = 0; requestBuilder.pushDownLimit(limit, offset); + requestBuilder.pushDownTrackedScore(true); assertEquals( new OpenSearchQueryRequest( @@ -79,7 +80,8 @@ void buildQueryRequest() { new SearchSourceBuilder() .from(offset) .size(limit) - .timeout(DEFAULT_QUERY_TIMEOUT), + .timeout(DEFAULT_QUERY_TIMEOUT) + .trackScores(true), exprValueFactory), requestBuilder.build()); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java index 0fc9c92810..b3c049ce03 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java @@ -9,20 +9,50 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import java.util.function.Function; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) class OpenSearchScrollRequestTest { + @Mock + private Function searchAction; + + @Mock + private Function scrollAction; + + @Mock + private SearchResponse searchResponse; + + @Mock + private SearchHits searchHits; + + @Mock + private SearchHit searchHit; + + @Mock + private SearchSourceBuilder sourceBuilder; + + @Mock + private FetchSourceContext fetchSourceContext; @Mock private OpenSearchExprValueFactory factory; @@ -58,4 +88,61 @@ void scrollRequest() { .scrollId("scroll123"), request.scrollRequest()); } + + @Test + void search() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + String[] includes = {"_id", "_index"}; + when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); + when(fetchSourceContext.includes()).thenReturn(includes); + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(fetchSourceContext, times(2)).includes(); + assertFalse(searchResponse.isEmpty()); + } + + @Test + void search_withoutContext() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + when(sourceBuilder.fetchSource()).thenReturn(null); + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(sourceBuilder, times(1)).fetchSource(); + assertFalse(searchResponse.isEmpty()); + } + + @Test + void search_withoutIncludes() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory + ); + + when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); + when(fetchSourceContext.includes()).thenReturn(null); + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + verify(fetchSourceContext, times(1)).includes(); + assertFalse(searchResponse.isEmpty()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 0a60503415..92b47bc7da 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; import org.apache.lucene.search.TotalHits; @@ -31,7 +32,10 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -56,6 +60,8 @@ class OpenSearchResponseTest { @Mock private Aggregations aggregations; + private List includes = List.of(); + @Mock private OpenSearchAggregationResponseParser parser; @@ -74,20 +80,20 @@ void isEmpty() { new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F)); - assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty()); + assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(null); - assertTrue(new OpenSearchResponse(searchResponse, factory).isEmpty()); + assertTrue(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); when(searchResponse.getHits()) .thenReturn(new SearchHits(null, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0)); - OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory); + OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory, includes); assertTrue(response3.isEmpty()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(new Aggregations(emptyList())); - assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty()); + assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); } @Test @@ -104,11 +110,126 @@ void iterator() { when(factory.construct(any())).thenReturn(exprTupleValue1).thenReturn(exprTupleValue2); int i = 0; - for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, List.of("id1"))) { if (i == 0) { - assertEquals(exprTupleValue1, hit); + assertEquals(exprTupleValue1.tupleValue().get("id"), hit.tupleValue().get("id")); } else if (i == 1) { - assertEquals(exprTupleValue2, hit); + assertEquals(exprTupleValue2.tupleValue().get("id"), hit.tupleValue().get("id")); + } else { + fail("More search hits returned than expected"); + } + i++; + } + } + + @Test + void iterator_metafields() { + + ExprTupleValue exprTupleHit = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1) + )); + + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit1}, + new TotalHits(1L, TotalHits.Relation.EQUAL_TO), + 3.75F)); + + when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); + when(searchHit1.getId()).thenReturn("testId"); + when(searchHit1.getIndex()).thenReturn("testIndex"); + when(searchHit1.getScore()).thenReturn(3.75F); + when(searchHit1.getSeqNo()).thenReturn(123456L); + + when(factory.construct(any())).thenReturn(exprTupleHit); + + ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1), + "_index", new ExprStringValue("testIndex"), + "_id", new ExprStringValue("testId"), + "_sort", new ExprLongValue(123456L), + "_score", new ExprFloatValue(3.75F), + "_maxscore", new ExprFloatValue(3.75F) + )); + List includes = List.of("id1", "_index", "_id", "_sort", "_score", "_maxscore"); + int i = 0; + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { + if (i == 0) { + assertEquals(exprTupleResponse, hit); + } else { + fail("More search hits returned than expected"); + } + i++; + } + } + + @Test + void iterator_metafields_withoutIncludes() { + + ExprTupleValue exprTupleHit = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1) + )); + + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit1}, + new TotalHits(1L, TotalHits.Relation.EQUAL_TO), + 3.75F)); + + when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); + + when(factory.construct(any())).thenReturn(exprTupleHit); + + List includes = List.of("id1"); + ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1) + )); + int i = 0; + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { + if (i == 0) { + assertEquals(exprTupleResponse, hit); + } else { + fail("More search hits returned than expected"); + } + i++; + } + } + + @Test + void iterator_metafields_scoreNaN() { + + ExprTupleValue exprTupleHit = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1) + )); + + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit1}, + new TotalHits(1L, TotalHits.Relation.EQUAL_TO), + Float.NaN)); + + when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); + when(searchHit1.getId()).thenReturn("testId"); + when(searchHit1.getIndex()).thenReturn("testIndex"); + when(searchHit1.getScore()).thenReturn(Float.NaN); + when(searchHit1.getSeqNo()).thenReturn(123456L); + + when(factory.construct(any())).thenReturn(exprTupleHit); + + List includes = List.of("id1", "_index", "_id", "_sort", "_score", "_maxscore"); + ExprTupleValue exprTupleResponse = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "id1", new ExprIntegerValue(1), + "_index", new ExprStringValue("testIndex"), + "_id", new ExprStringValue("testId"), + "_sort", new ExprLongValue(123456L) + )); + int i = 0; + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { + if (i == 0) { + assertEquals(exprTupleResponse, hit); } else { fail("More search hits returned than expected"); } @@ -120,7 +241,7 @@ void iterator() { void response_is_aggregation_when_aggregation_not_empty() { when(searchResponse.getAggregations()).thenReturn(aggregations); - OpenSearchResponse response = new OpenSearchResponse(searchResponse, factory); + OpenSearchResponse response = new OpenSearchResponse(searchResponse, factory, includes); assertTrue(response.isAggregationResponse()); } @@ -128,12 +249,14 @@ void response_is_aggregation_when_aggregation_not_empty() { void response_isnot_aggregation_when_aggregation_is_empty() { when(searchResponse.getAggregations()).thenReturn(null); - OpenSearchResponse response = new OpenSearchResponse(searchResponse, factory); + OpenSearchResponse response = new OpenSearchResponse(searchResponse, factory, includes); assertFalse(response.isAggregationResponse()); } @Test void aggregation_iterator() { + final List includes = List.of("id1", "id2"); + when(parser.parse(any())) .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); when(searchResponse.getAggregations()).thenReturn(aggregations); @@ -143,7 +266,7 @@ void aggregation_iterator() { .thenReturn(new ExprIntegerValue(2)); int i = 0; - for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { if (i == 0) { assertEquals(exprTupleValue1, hit); } else if (i == 1) { @@ -176,7 +299,7 @@ void highlight_iterator() { when(searchHit1.getHighlightFields()).thenReturn(highlightMap); when(factory.construct(any())).thenReturn(resultTuple); - for (ExprValue resultHit : new OpenSearchResponse(searchResponse, factory)) { + for (ExprValue resultHit : new OpenSearchResponse(searchResponse, factory, includes)) { var expected = ExprValueUtils.collectionValue( Arrays.stream(searchHit.getHighlightFields().get("highlights").getFragments()) .map(t -> (t.toString())).collect(Collectors.toList())); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 8d4dad48a9..3d856cb1e2 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -183,6 +183,21 @@ void checkCacheUsedForFieldMappings() { hasEntry("name", OpenSearchDataType.of(STRING)))); } + @Test + void getReservedFieldTypes() { + Map fieldTypes = index.getReservedFieldTypes(); + assertThat( + fieldTypes, + allOf( + aMapWithSize(5), + hasEntry("_id", ExprCoreType.STRING), + hasEntry("_index", ExprCoreType.STRING), + hasEntry("_sort", ExprCoreType.LONG), + hasEntry("_score", ExprCoreType.FLOAT), + hasEntry("_maxscore", ExprCoreType.FLOAT) + )); + } + @Test void implementRelationOperatorOnly() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index b90ca8836d..852a5a71bc 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -7,6 +7,7 @@ package org.opensearch.sql.opensearch.storage.scan; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -33,9 +34,11 @@ import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_SORT; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -48,6 +51,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.SpanOrQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; @@ -57,10 +61,15 @@ import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; @@ -68,6 +77,7 @@ import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; @@ -134,6 +144,119 @@ void test_filter_push_down() { ); } + /** + * SELECT intV as i FROM schema WHERE query_string(["intV^1.5", "QUERY", boost=12.5). + */ + @Test + void test_filter_on_opensearchfunction_with_trackedscores_push_down() { + LogicalPlan expectedPlan = + project( + indexScanBuilder( + withFilterPushedDown( + QueryBuilders.queryStringQuery("QUERY") + .field("intV", 1.5F) + .boost(12.5F) + ), + withTrackedScoresPushedDown(true) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + FunctionExpression queryString = DSL.query_string( + DSL.namedArgument("fields", DSL.literal( + new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( + "intV", ExprValueUtils.floatValue(1.5F)))))), + DSL.namedArgument("query", "QUERY"), + DSL.namedArgument("boost", "12.5")); + + ((OpenSearchFunctions.OpenSearchFunction) queryString).setScoreTracked(true); + + LogicalPlan logicalPlan = project( + filter( + relation("schema", table), + queryString + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + assertEqualsAfterOptimization(expectedPlan, logicalPlan); + } + + @Test + void test_filter_on_multiple_opensearchfunctions_with_trackedscores_push_down() { + LogicalPlan expectedPlan = + project( + indexScanBuilder( + withFilterPushedDown( + QueryBuilders.boolQuery() + .should( + QueryBuilders.queryStringQuery("QUERY") + .field("intV", 1.5F) + .boost(12.5F)) + .should( + QueryBuilders.queryStringQuery("QUERY") + .field("intV", 1.5F) + .boost(12.5F) + ) + ), + withTrackedScoresPushedDown(true) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + FunctionExpression firstQueryString = DSL.query_string( + DSL.namedArgument("fields", DSL.literal( + new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( + "intV", ExprValueUtils.floatValue(1.5F)))))), + DSL.namedArgument("query", "QUERY"), + DSL.namedArgument("boost", "12.5")); + ((OpenSearchFunctions.OpenSearchFunction) firstQueryString).setScoreTracked(false); + FunctionExpression secondQueryString = DSL.query_string( + DSL.namedArgument("fields", DSL.literal( + new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( + "intV", ExprValueUtils.floatValue(1.5F)))))), + DSL.namedArgument("query", "QUERY"), + DSL.namedArgument("boost", "12.5")); + ((OpenSearchFunctions.OpenSearchFunction) secondQueryString).setScoreTracked(true); + + LogicalPlan logicalPlan = project( + filter( + relation("schema", table), + DSL.or(firstQueryString, secondQueryString) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + assertEqualsAfterOptimization(expectedPlan, logicalPlan); + } + + @Test + void test_filter_on_opensearchfunction_without_trackedscores_push_down() { + LogicalPlan expectedPlan = + project( + indexScanBuilder( + withFilterPushedDown( + QueryBuilders.queryStringQuery("QUERY") + .field("intV", 1.5F) + .boost(12.5F) + ), + withTrackedScoresPushedDown(false) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + FunctionExpression queryString = DSL.query_string( + DSL.namedArgument("fields", DSL.literal( + new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( + "intV", ExprValueUtils.floatValue(1.5F)))))), + DSL.namedArgument("query", "QUERY"), + DSL.namedArgument("boost", "12.5")); + + LogicalPlan logicalPlan = project( + filter( + relation("schema", table), + queryString + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ); + assertEqualsAfterOptimization(expectedPlan, logicalPlan); + } + /** * SELECT avg(intV) FROM schema GROUP BY string_value. */ @@ -210,6 +333,21 @@ void test_sort_push_down() { ); } + @Test + void test_score_sort_push_down() { + assertEqualsAfterOptimization( + indexScanBuilder( + withSortPushedDown( + SortBuilders.scoreSort().order(SortOrder.ASC) + ) + ), + sort( + relation("schema", table), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("_score", INTEGER)) + ) + ); + } + @Test void test_limit_push_down() { assertEqualsAfterOptimization( @@ -577,6 +715,10 @@ private Runnable withHighlightPushedDown(String field, Map argu return () -> verify(requestBuilder, times(1)).pushDownHighlight(field, arguments); } + private Runnable withTrackedScoresPushedDown(boolean trackScores) { + return () -> verify(requestBuilder, times(1)).pushDownTrackedScore(trackScores); + } + private static AggregationAssertHelper.AggregationAssertHelperBuilder aggregate(String aggName) { var aggBuilder = new AggregationAssertHelper.AggregationAssertHelperBuilder(); aggBuilder.aggregateName = aggName; diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 616bfa8a79..b65f60e289 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -134,7 +134,6 @@ STDDEV_SAMP: 'STDDEV_SAMP'; SUBSTRING: 'SUBSTRING'; TRIM: 'TRIM'; - // Keywords, but can be ID // Common Keywords, but can be ID @@ -328,6 +327,8 @@ REVERSE_NESTED: 'REVERSE_NESTED'; QUERY: 'QUERY'; RANGE: 'RANGE'; SCORE: 'SCORE'; +SCOREQUERY: 'SCOREQUERY'; +SCORE_QUERY: 'SCORE_QUERY'; SECOND_OF_MINUTE: 'SECOND_OF_MINUTE'; STATS: 'STATS'; TERM: 'TERM'; @@ -465,7 +466,6 @@ BACKTICK_QUOTE_ID: BQUOTA_STRING; // Fragments for Literal primitives fragment EXPONENT_NUM_PART: 'E' [-+]? DEC_DIGIT+; -fragment ID_LITERAL: [@*A-Z]+?[*A-Z_\-0-9]*; fragment DQUOTA_STRING: '"' ( '\\'. | '""' | ~('"'| '\\') )* '"'; fragment SQUOTA_STRING: '\'' ('\\'. | '\'\'' | ~('\'' | '\\'))* '\''; fragment BQUOTA_STRING: '`' ( '\\'. | '``' | ~('`'|'\\'))* '`'; @@ -473,6 +473,10 @@ fragment HEX_DIGIT: [0-9A-F]; fragment DEC_DIGIT: [0-9]; fragment BIT_STRING_L: 'B' '\'' [01]+ '\''; +// Identifiers cannot start with a single '_' since this an OpenSearch reserved +// metadata field. Two underscores (or more) is acceptable, such as '__field'. +fragment ID_LITERAL: ([@*A-Z_])+?[*A-Z_\-0-9]*; + // Last tokens must generate Errors ERROR_RECOGNITION: . -> channel(ERRORCHANNEL); diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 3861716ac9..2b5bf8c478 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -311,6 +311,7 @@ functionCall | windowFunctionClause #windowFunctionCall | aggregateFunction #aggregateFunctionCall | aggregateFunction (orderByClause)? filterClause #filteredAggregationFunctionCall + | scoreRelevanceFunction #scoreRelevanceFunctionCall | relevanceFunction #relevanceFunctionCall | highlightFunction #highlightFunctionCall | positionFunction #positionFunctionCall @@ -404,7 +405,10 @@ specificFunction relevanceFunction : noFieldRelevanceFunction | singleFieldRelevanceFunction | multiFieldRelevanceFunction | altSingleFieldRelevanceFunction | altMultiFieldRelevanceFunction + ; +scoreRelevanceFunction + : scoreRelevanceFunctionName LR_BRACKET relevanceFunction (COMMA weight=relevanceFieldWeight)? RR_BRACKET ; noFieldRelevanceFunction @@ -562,6 +566,10 @@ systemFunctionName : TYPEOF ; +scoreRelevanceFunctionName + : SCORE | SCOREQUERY | SCORE_QUERY + ; + singleFieldRelevanceFunctionName : MATCH | MATCHQUERY | MATCH_QUERY | MATCH_PHRASE | MATCHPHRASE | MATCHPHRASEQUERY diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index d15a9dea0b..a5d071ef95 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.between; import static org.opensearch.sql.ast.dsl.AstDSL.not; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; -import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; @@ -53,6 +52,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RelevanceFieldAndWeightContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarWindowFunctionContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScoreRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ShowDescribePatternContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SignedDecimalContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SignedRealContext; @@ -93,6 +93,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.RelevanceFieldList; +import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; @@ -188,7 +189,7 @@ public UnresolvedExpression visitPositionFunction( return new Function( POSITION.getName().getFunctionName(), Arrays.asList(visitFunctionArg(ctx.functionArg(0)), - visitFunctionArg(ctx.functionArg(1)))); + visitFunctionArg(ctx.functionArg(1)))); } @Override @@ -470,7 +471,7 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( if ((funcName.equalsIgnoreCase(BuiltinFunctionName.MULTI_MATCH.toString()) || funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCH.toString()) || funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCHQUERY.toString())) - && ! ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) + && !ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) .isEmpty()) { return new Function( ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), @@ -490,6 +491,20 @@ public UnresolvedExpression visitAltMultiFieldRelevanceFunction( altMultiFieldRelevanceFunctionArguments(ctx)); } + /** + * Visit score-relevance function and collect children. + * + * @param ctx the parse tree + * @return children + */ + public UnresolvedExpression visitScoreRelevanceFunction(ScoreRelevanceFunctionContext ctx) { + Literal weight = + ctx.weight == null + ? new Literal(Double.valueOf(1.0), DataType.DOUBLE) + : new Literal(Double.parseDouble(ctx.weight.getText()), DataType.DOUBLE); + return new ScoreFunction(visit(ctx.relevanceFunction()), weight); + } + private Function buildFunction(String functionName, List arg) { return new Function( @@ -514,8 +529,7 @@ private QualifiedName visitIdentifiers(List identifiers) { identifiers.stream() .map(RuleContext::getText) .map(StringUtils::unquoteIdentifier) - .collect(Collectors.toList()) - ); + .collect(Collectors.toList())); } private void fillRelevanceArgs(List args, @@ -609,6 +623,7 @@ private List timestampFunctionArguments( /** * Adds support for multi_match alternate syntax like * MULTI_MATCH('query'='Dale', 'fields'='*name'). + * * @param ctx : Context for multi field relevance function. * @return : Returns list of all arguments for relevance function. */ @@ -621,7 +636,7 @@ private List alternateMultiMatchArguments( String[] fieldAndWeights = StringUtils.unquoteText( ctx.getRuleContexts(AlternateMultiMatchFieldContext.class) - .stream().findFirst().get().argVal.getText()).split(","); + .stream().findFirst().get().argVal.getText()).split(","); for (var fieldAndWeight : fieldAndWeights) { String[] splitFieldAndWeights = fieldAndWeight.split("\\^"); @@ -633,9 +648,10 @@ private List alternateMultiMatchArguments( ctx.getRuleContexts(AlternateMultiMatchQueryContext.class) .stream().findFirst().ifPresent( - arg -> - builder.add(new UnresolvedArgument("query", - new Literal(StringUtils.unquoteText(arg.argVal.getText()), DataType.STRING))) + arg -> + builder.add(new UnresolvedArgument("query", + new Literal( + StringUtils.unquoteText(arg.argVal.getText()), DataType.STRING))) ); fillRelevanceArgs(ctx.relevanceArg(), builder); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index 52dd5e3572..20655bc020 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -36,6 +36,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.stream.Stream; import org.antlr.v4.runtime.CommonTokenStream; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; @@ -463,6 +464,26 @@ public void canBuildKeywordsAsIdentInQualifiedName() { ); } + @Test + public void canBuildMetaDataFieldAsQualifiedName() { + Stream.of("_id", "_index", "_sort", "_score", "_maxscore").forEach( + field -> assertEquals( + qualifiedName(field), + buildExprAst(field) + ) + ); + } + + @Test + public void canBuildNonMetaDataFieldAsQualifiedName() { + Stream.of("id", "__id", "_routing", "___field").forEach( + field -> assertEquals( + qualifiedName(field), + buildExprAst(field) + ) + ); + } + @Test public void canCastFieldAsString() { assertEquals( @@ -798,6 +819,36 @@ public void relevanceWildcard_query() { ); } + @Test + public void relevanceScore_query() { + assertEquals( + AstDSL.score( + AstDSL.function("query_string", + unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( + "field1", 1.F, "field2", 3.2F))), + unresolvedArg("query", stringLiteral("search query")) + ), + AstDSL.doubleLiteral(1.0) + ), + buildExprAst("score(query_string(['field1', 'field2' ^ 3.2], 'search query'))") + ); + } + + @Test + public void relevanceScore_withBoost_query() { + assertEquals( + AstDSL.score( + AstDSL.function("query_string", + unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( + "field1", 1.F, "field2", 3.2F))), + unresolvedArg("query", stringLiteral("search query")) + ), + doubleLiteral(1.0) + ), + buildExprAst("score(query_string(['field1', 'field2' ^ 3.2], 'search query'), 1.0)") + ); + } + @Test public void relevanceQuery() { assertEquals(AstDSL.function("query",