Skip to content

Commit

Permalink
Rebase from main
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Carbonetto <andrewc@bitquilltech.com>
  • Loading branch information
acarbonetto committed Mar 7, 2023
1 parent 444eefd commit 6c480be
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 154 deletions.
101 changes: 20 additions & 81 deletions core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import static org.opensearch.sql.ast.dsl.AstDSL.and;
import static org.opensearch.sql.ast.dsl.AstDSL.compare;
import static org.opensearch.sql.ast.expression.QualifiedName.METADATAFIELD_TYPE_MAP;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.GTE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTE;

Expand All @@ -32,7 +31,6 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -64,14 +62,14 @@
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.ScoreExpression;
import org.opensearch.sql.expression.aggregation.AggregationState;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand Down Expand Up @@ -212,75 +210,9 @@ public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext
return new HighlightExpression(expr);
}

/**
* visitScoreFunction removes the score function from the AST and replaces it with the child
* relevance function node. If the optional boost variable is provided, the boost argument
* of the relevance function is combined (multiplied with an existing boost argument, or
* appended as a new "boost" argument when none exists).
* @param node score function node
* @param context analysis context for the query
* @return resolved relevance function
* @throws SemanticCheckException if the boost literal is neither DOUBLE nor INTEGER
*/
public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) {
// if no function argument given, just accept the relevance query and return
// (the same path is taken when the first argument is not a Literal)
if (node.getFuncArgs().isEmpty() || !(node.getFuncArgs().get(0) instanceof Literal)) {
OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr =
(OpenSearchFunctions.OpenSearchFunction) node
.getRelevanceQuery().accept(this, context);
relevanceQueryExpr.setScoreTracked(true);
return relevanceQueryExpr;
}

// note: if an argument exists, and there should only be one, it will be a boost argument
Literal boostFunctionArg = (Literal) node.getFuncArgs().get(0);
// normalize the boost literal to a Double; only DOUBLE and INTEGER literals are accepted
Double thisBoostValue;
if (boostFunctionArg.getType().equals(DataType.DOUBLE)) {
thisBoostValue = ((Double) boostFunctionArg.getValue());
} else if (boostFunctionArg.getType().equals(DataType.INTEGER)) {
thisBoostValue = ((Integer) boostFunctionArg.getValue()).doubleValue();
} else {
throw new SemanticCheckException(String.format("Expected boost type '%s' but got '%s'",
DataType.DOUBLE.name(), boostFunctionArg.getType().name()));
}

// update the existing unresolved expression to add a boost argument if it doesn't exist
// OR multiply the existing boost argument
Function relevanceQueryUnresolvedExpr = (Function)node.getRelevanceQuery();
List<UnresolvedExpression> relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs();

boolean doesFunctionContainBoostArgument = false;
List<UnresolvedExpression> updatedFuncArgs = new ArrayList<>();
for (UnresolvedExpression expr: relevanceFuncArgs) {
// every argument of the relevance function is cast to UnresolvedArgument here
String argumentName = ((UnresolvedArgument) expr).getArgName();
if (argumentName.equalsIgnoreCase("boost")) {
doesFunctionContainBoostArgument = true;
// the existing boost value is read as a String literal, scaled by the score() boost,
// and written back as a STRING literal under the original argument name
Literal boostArgLiteral = (Literal)((UnresolvedArgument) expr).getValue();
Double boostValue = Double.parseDouble((String)boostArgLiteral.getValue()) * thisBoostValue;
UnresolvedArgument newBoostArg = new UnresolvedArgument(
argumentName,
new Literal(boostValue.toString(), DataType.STRING)
);
updatedFuncArgs.add(newBoostArg);
} else {
updatedFuncArgs.add(expr);
}
}

// since nothing was found, add an argument
if (!doesFunctionContainBoostArgument) {
UnresolvedArgument newBoostArg = new UnresolvedArgument(
"boost", new Literal(Double.toString(thisBoostValue), DataType.STRING));
updatedFuncArgs.add(newBoostArg);
}

// create a new function expression with boost argument and resolve it
Function updatedRelevanceQueryUnresolvedExpr = new Function(
relevanceQueryUnresolvedExpr.getFuncName(),
updatedFuncArgs);
OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr =
(OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr
.accept(this, context);
relevanceQueryExpr.setScoreTracked(true);
return relevanceQueryExpr;
// NOTE(review): the two statements below follow a return and re-declare relevanceQueryExpr;
// they presumably belong to the replacement revision shown in this diff view, which resolves
// the relevance query and wraps it in a ScoreExpression - confirm against the commit.
Expression relevanceQueryExpr = node.getRelevanceQuery().accept(this, context);
return new ScoreExpression(relevanceQueryExpr);
}

@Override
Expand Down Expand Up @@ -392,17 +324,24 @@ public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisConte
return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
}

/**
* If QualifiedName is actually a reserved metadata field, return the expr type associated
* with the metadata field.
* @param ident metadata field name
* @param context analysis context
* @return DSL reference
* @throws SemanticCheckException if ident is not a recognized metadata field
*/
private Expression visitMetadata(String ident, AnalysisContext context) {
// look the type up in METADATAFIELD_TYPE_MAP; an absent entry is a semantic error
ExprCoreType exprCoreType = Optional.ofNullable(METADATAFIELD_TYPE_MAP.get(ident))
.orElseThrow(() -> new SemanticCheckException("invalid metadata field"));
return DSL.ref(ident, exprCoreType);
// NOTE(review): the statements below follow a return; they presumably belong to the other
// revision shown in this diff view (switch-based lookup) - confirm against the commit.
ReferenceExpression ref;
// the field name is matched in lower case, but the reference keeps ident as given
// NOTE(review): toLowerCase() without a Locale argument uses the JVM default locale;
// consider whether Locale.ROOT is intended.
switch (ident.toLowerCase()) {
case "_index":
case "_id":
ref = DSL.ref(ident, ExprCoreType.STRING);
break;
case "_score":
case "_maxscore":
ref = DSL.ref(ident, ExprCoreType.FLOAT);
break;
case "_sort":
ref = DSL.ref(ident, ExprCoreType.LONG);
break;
default:
throw new SemanticCheckException("invalid metadata field");
}
return ref;
}

private Expression visitIdentifier(String ident, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@

package org.opensearch.sql.ast.expression;

import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

/**
* Expression node of Score function.
* Score takes a relevance-search expression as an argument and returns it
* Expression node of Highlight function.
*/
@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
Expand Down
8 changes: 0 additions & 8 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -794,14 +794,6 @@ public static FunctionExpression score(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args);
}

/**
 * Builds the function expression for {@code BuiltinFunctionName.SCOREQUERY} by delegating to
 * {@code compile} with {@code FunctionProperties.None}.
 * @param args arguments forwarded unchanged to {@code compile}
 * @return the expression returned by {@code compile}
 */
public static FunctionExpression scorequery(Expression... args) {
  final BuiltinFunctionName function = BuiltinFunctionName.SCOREQUERY;
  return compile(FunctionProperties.None, function, args);
}

/**
 * Builds the function expression for {@code BuiltinFunctionName.SCORE_QUERY} by delegating to
 * {@code compile} with {@code FunctionProperties.None}.
 * @param args arguments forwarded unchanged to {@code compile}
 * @return the expression returned by {@code compile}
 */
public static FunctionExpression score_query(Expression... args) {
  final BuiltinFunctionName function = BuiltinFunctionName.SCORE_QUERY;
  return compile(FunctionProperties.None, function, args);
}

public static FunctionExpression now(FunctionProperties functionProperties,
Expression... args) {
return compile(functionProperties, BuiltinFunctionName.NOW, args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public T visitHighlight(HighlightExpression node, C context) {
return visitNode(node, context);
}

/**
 * Visits a {@link ScoreExpression}; this implementation hands the node to
 * {@code visitNode} and returns its result.
 * @param node score expression being visited
 * @param context visitor context passed through to {@code visitNode}
 * @return whatever {@code visitNode} returns for this node
 */
public T visitScore(ScoreExpression node, C context) {
  final T visited = visitNode(node, context);
  return visited;
}

/**
 * Visits a {@link ReferenceExpression}; this implementation hands the node to
 * {@code visitNode} and returns its result.
 * @param node reference expression being visited
 * @param context visitor context passed through to {@code visitNode}
 * @return whatever {@code visitNode} returns for this node
 */
public T visitReference(ReferenceExpression node, C context) {
  final T visited = visitNode(node, context);
  return visited;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression;

import lombok.Getter;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Score Expression.
 *
 * <p>Function expression named after {@code BuiltinFunctionName.SCORE} whose single argument
 * is the wrapped relevance query expression.
 */
@Getter
public class ScoreExpression extends FunctionExpression {

  private final Expression relevanceQueryExpr;

  /**
   * ScoreExpression Constructor.
   * @param relevanceQueryExpr : relevanceQueryExpr for expression; passed to the parent as the
   *     only function argument.
   */
  public ScoreExpression(Expression relevanceQueryExpr) {
    super(BuiltinFunctionName.SCORE.getName(), List.of(relevanceQueryExpr));
    this.relevanceQueryExpr = relevanceQueryExpr;
  }

  /**
   * Evaluation is currently a no-op: the environment is not consulted and the value of
   * {@code ExprNullValue.of()} is always returned.
   * @param valueEnv : environment of values; not read by this implementation.
   * @return : the value returned by {@code ExprNullValue.of()}.
   */
  @Override
  public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
    // TODO: this is where we visit relevance function nodes and update BOOST values as necessary
    // Otherwise, this is a no-op
    return ExprNullValue.of();
  }

  /**
   * Dispatches to {@code visitScore} on the given visitor.
   * @param visitor : expression node visitor.
   * @param context : visitor context.
   * @return : result of {@code visitor.visitScore(this, context)}.
   */
  @Override
  public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
    return visitor.visitScore(this, context);
  }

  /**
   * Reports the expression type.
   * @return : always {@code ExprCoreType.UNDEFINED}.
   */
  @Override
  public ExprType type() {
    return ExprCoreType.UNDEFINED;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;

import java.util.List;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprDoubleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
Expand Down Expand Up @@ -93,6 +94,27 @@ private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery
return new RelevanceFunctionResolver(funcName);
}

/**
* Definition of score() function.
* Enables score calculation for the match call
*/
// private static DefaultFunctionResolver score(BuiltinFunctionName score) {
// FunctionName funcName = score.getName();
// return FunctionDSL.define(funcName,
// FunctionDSL.impl(
// FunctionDSL.nullMissingHandling(
// (relevanceFunc) -> new ExprDoubleValue(
// Math.pow(relevanceFunc.shortValue(), 1))
// ),
// BOOLEAN, BOOLEAN),
// FunctionDSL.impl(
// FunctionDSL.nullMissingHandling(
// (relevanceFunc, boost) -> new ExprDoubleValue(
// Math.pow(relevanceFunc.shortValue(), boost.shortValue()))
// ),
// BOOLEAN, BOOLEAN, DOUBLE));
// }

private static FunctionResolver score(BuiltinFunctionName score) {
FunctionName funcName = score.getName();
return new RelevanceFunctionResolver(funcName);
Expand All @@ -102,10 +124,6 @@ public static class OpenSearchFunction extends FunctionExpression {
private final FunctionName functionName;
private final List<Expression> arguments;

@Getter
@Setter
private boolean isScoreTracked;

/**
* Required argument constructor.
* @param functionName name of the function
Expand All @@ -115,7 +133,6 @@ public OpenSearchFunction(FunctionName functionName, List<Expression> arguments)
super(functionName, arguments);
this.functionName = functionName;
this.arguments = arguments;
this.isScoreTracked = false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public class TableScanPushDown<T extends LogicalPlan> implements Rule<T> {
.apply((highlight, scanBuilder) -> scanBuilder.pushDownHighlight(highlight));


// NOTE(review): this rule matches highlight nodes and calls pushDownHighlight, the same action
// as the highlight rule declared just above; presumably a placeholder until a score-specific
// pattern and push-down exist - confirm before relying on it.
public static final Rule<?> PUSH_DOWN_SCORE =
    match(highlight(scanBuilder()))
        .apply((highlightNode, builder) -> builder.pushDownHighlight(highlightNode));


/** Pattern that matches a plan node. */
private final WithPattern<T> pattern;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public class OpenSearchRequestBuilder {
*/
private Integer querySize;

private boolean trackScores;

public OpenSearchRequestBuilder(String indexName,
Integer maxResultWindow,
Settings settings,
Expand All @@ -97,10 +99,11 @@ public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName,
this.sourceBuilder = new SearchSourceBuilder();
this.exprValueFactory = exprValueFactory;
this.querySize = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT);
this.trackScores = true;
sourceBuilder.from(0);
sourceBuilder.size(querySize);
sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT);
sourceBuilder.trackScores(false);
sourceBuilder.trackScores(this.trackScores);
}

/**
Expand Down Expand Up @@ -181,10 +184,6 @@ public void pushDownLimit(Integer limit, Integer offset) {
sourceBuilder.from(offset).size(limit);
}

/**
 * Forwards the track-scores flag to the search source builder.
 * @param trackScores value passed to {@code sourceBuilder.trackScores}
 */
public void pushDownTrackedScore(boolean trackScores) {
  this.sourceBuilder.trackScores(trackScores);
}

/**
* Add highlight to DSL requests.
* @param field name of the field to highlight
Expand Down
Loading

0 comments on commit 6c480be

Please sign in to comment.