diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/DecayFunctionScoreIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/DecayFunctionScoreIT.java index e31515898418f..712026eaf5c43 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/DecayFunctionScoreIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/DecayFunctionScoreIT.java @@ -51,6 +51,7 @@ import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; import org.opensearch.index.query.functionscore.ScoreFunctionBuilders; import org.opensearch.search.MultiValueMode; +import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.VersionUtils; @@ -77,7 +78,9 @@ import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; @@ -616,6 +619,76 @@ public void testCombineModes() throws Exception { } + public void testCombineModesExplain() throws Exception { + assertAcked( + prepareCreate("test").addMapping( + "type1", + jsonBuilder().startObject() + .startObject("type1") + .startObject("properties") + .startObject("test") + .field("type", "text") + .endObject() + .startObject("num") + .field("type", "double") + .endObject() + .endObject() + .endObject() + .endObject() + ) + ); + + client().prepareIndex() + .setId("1") + .setIndex("test") + .setRefreshPolicy(IMMEDIATE) + .setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject()) + .get(); + + FunctionScoreQueryBuilder baseQuery = functionScoreQuery( + constantScoreQuery(termQuery("test", "value")).queryName("query1"), + ScoreFunctionBuilders.weightFactorFunction(2, "weight1") + ); + // decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score + ActionFuture response = client().search( + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) + .source( + searchSource().explain(true) + .query( + functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode( + CombineFunction.MULTIPLY + ) + ) + ) + ); + SearchResponse sr = response.actionGet(); + SearchHits sh = sr.getHits(); + assertThat(sh.getTotalHits().value, equalTo((long) (1))); + assertThat(sh.getAt(0).getId(), equalTo("1")); + assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2)); + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2)); + // "description": "ConstantScore(test:value) (_name: query1)" + assertThat( + sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(), + equalTo("ConstantScore(test:value) (_name: query1)") + ); + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2)); + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2)); + // "description": "constant score 1.0(_name: func1) - no function provided" + assertThat( + sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(), + equalTo("constant score 1.0(_name: weight1) - no function provided") + ); + // "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817, + // _name: func2)" + assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); + assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1)); + assertThat( + sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(), + containsString("_name: func2") + ); + } + public void testExceptionThrownIfScaleLE0() throws Exception { assertAcked( prepareCreate("test").addMapping( @@ -1195,4 +1268,132 @@ public void testMultiFieldOptions() throws Exception { sh = sr.getHits(); assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d)); } + + public void testDistanceScoreGeoLinGaussExplain() throws Exception { + assertAcked( + prepareCreate("test").addMapping( + "type1", + jsonBuilder().startObject() + .startObject("type1") + .startObject("properties") + .startObject("test") + .field("type", "text") + .endObject() + .startObject("loc") + .field("type", "geo_point") + .endObject() + .endObject() + .endObject() + .endObject() + ) + ); + + List indexBuilders = new ArrayList<>(); + indexBuilders.add( + client().prepareIndex() + .setId("1") + .setIndex("test") + .setSource( + jsonBuilder().startObject() + .field("test", "value") + .startObject("loc") + .field("lat", 10) + .field("lon", 20) + .endObject() + .endObject() + ) + ); + indexBuilders.add( + client().prepareIndex() + .setId("2") + .setIndex("test") + .setSource( + jsonBuilder().startObject() + .field("test", "value") + .startObject("loc") + .field("lat", 11) + .field("lon", 22) + .endObject() + .endObject() + ) + ); + + indexRandom(true, indexBuilders); + + // Test Gauss + List lonlat = new ArrayList<>(); + lonlat.add(20f); + lonlat.add(11f); + + final String queryName = "query1"; + final String functionName = "func1"; + ActionFuture response = client().search( + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) + .source( + searchSource().explain(true) + .query( + functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName)) + ) + ) + ); + SearchResponse sr = response.actionGet(); + SearchHits sh = sr.getHits(); + assertThat(sh.getTotalHits().value, equalTo(2L)); + assertThat(sh.getAt(0).getId(), equalTo("1")); + assertThat(sh.getAt(1).getId(), equalTo("2")); + assertExplain(queryName, functionName, sr); + + response = client().search( + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) + .source( + searchSource().explain(true) + .query( + functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName)) + ) + ) + ); + + sr = response.actionGet(); + sh = sr.getHits(); + assertThat(sh.getTotalHits().value, equalTo(2L)); + assertThat(sh.getAt(0).getId(), equalTo("1")); + assertThat(sh.getAt(1).getId(), equalTo("2")); + assertExplain(queryName, functionName, sr); + + response = client().search( + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) + .source( + searchSource().explain(true) + .query( + functionScoreQuery( + baseQuery.queryName(queryName), + exponentialDecayFunction("loc", lonlat, "1000km", functionName) + ) + ) + ) + ); + + sr = response.actionGet(); + sh = sr.getHits(); + assertThat(sh.getTotalHits().value, equalTo(2L)); + assertThat(sh.getAt(0).getId(), equalTo("1")); + assertThat(sh.getAt(1).getId(), equalTo("2")); + assertExplain(queryName, functionName, sr); + } + + private void assertExplain(final String queryName, final String functionName, SearchResponse sr) { + SearchHit firstHit = sr.getHits().getAt(0); + assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2)); + // "description": "*:* (_name: query1)" + assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName)); + assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); + // "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)" + assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1)); + // "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin)) + // - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)" + assertThat( + firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(), + containsString("_name: " + functionName) + ); + } } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java index c58f99a67a3fd..f67b913a75871 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java @@ -38,6 +38,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchType; import org.opensearch.common.lucene.search.function.CombineFunction; +import org.opensearch.common.lucene.search.function.Functions; import org.opensearch.common.settings.Settings; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.plugins.Plugin; @@ -72,6 +73,7 @@ import static org.opensearch.index.query.QueryBuilders.termQuery; import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction; import static org.opensearch.search.builder.SearchSourceBuilder.searchSource; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -121,8 +123,17 @@ static class MyScript extends ScoreScript implements ExplainableScoreScript { @Override public Explanation explain(Explanation subQueryScore) throws IOException { + return explain(subQueryScore, null); + } + + @Override + public Explanation explain(Explanation subQueryScore, String functionName) throws IOException { Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore); - return Explanation.match((float) (execute(null)), "This script returned " + execute(null), scoreExp); + return Explanation.match( + (float) (execute(null)), + "This script" + Functions.nameOrEmptyFunc(functionName) + " returned " + execute(null), + scoreExp + ); } @Override @@ -174,4 +185,36 @@ public void testExplainScript() throws InterruptedException, IOException, Execut idCounter--; } } + + public void testExplainScriptWithName() throws InterruptedException, IOException, ExecutionException { + List indexRequests = new ArrayList<>(); + indexRequests.add( + client().prepareIndex("test") + .setId(Integer.toString(1)) + .setSource(jsonBuilder().startObject().field("number_field", 1).field("text", "text").endObject()) + ); + indexRandom(true, true, indexRequests); + client().admin().indices().prepareRefresh().get(); + ensureYellow(); + SearchResponse response = client().search( + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) + .source( + searchSource().explain(true) + .query( + functionScoreQuery( + termQuery("text", "text"), + scriptFunction(new Script(ScriptType.INLINE, "test", "explainable_script", Collections.emptyMap()), "func1") + ).boostMode(CombineFunction.REPLACE) + ) + ) + ).actionGet(); + + OpenSearchAssertions.assertNoFailures(response); + SearchHits hits = response.getHits(); + assertThat(hits.getTotalHits().value, equalTo(1L)); + assertThat(hits.getHits()[0].getId(), equalTo("1")); + assertThat(hits.getHits()[0].getExplanation().getDetails(), arrayWithSize(2)); + assertThat(hits.getHits()[0].getExplanation().getDetails()[0].getDescription(), containsString("_name: func1")); + } + } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreFieldValueIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreFieldValueIT.java index 8714e13471c21..8e0a14b7062a7 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreFieldValueIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreFieldValueIT.java @@ -35,10 +35,13 @@ import org.opensearch.action.search.SearchPhaseExecutionException; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.lucene.search.function.FieldValueFactorFunction; +import org.opensearch.search.SearchHit; import org.opensearch.test.OpenSearchIntegTestCase; import java.io.IOException; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.arrayWithSize; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.index.query.QueryBuilders.functionScoreQuery; import static org.opensearch.index.query.QueryBuilders.matchAllQuery; @@ -163,4 +166,47 @@ public void testFieldValueFactor() throws IOException { // locally, instead of just having failures } } + + public void testFieldValueFactorExplain() throws IOException { + assertAcked( + prepareCreate("test").addMapping( + "type1", + jsonBuilder().startObject() + .startObject("type1") + .startObject("properties") + .startObject("test") + .field("type", randomFrom(new String[] { "short", "float", "long", "integer", "double" })) + .endObject() + .startObject("body") + .field("type", "text") + .endObject() + .endObject() + .endObject() + .endObject() + ).get() + ); + + client().prepareIndex("test").setId("1").setSource("test", 5, "body", "foo").get(); + client().prepareIndex("test").setId("2").setSource("test", 17, "body", "foo").get(); + client().prepareIndex("test").setId("3").setSource("body", "bar").get(); + + refresh(); + + // document 2 scores higher because 17 > 5 + final String functionName = "func1"; + final String queryName = "query"; + SearchResponse response = client().prepareSearch("test") + .setExplain(true) + .setQuery( + functionScoreQuery(simpleQueryStringQuery("foo").queryName(queryName), fieldValueFactorFunction("test", functionName)) + ) + .get(); + assertOrderedSearchHits(response, "2", "1"); + SearchHit firstHit = response.getHits().getAt(0); + assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2)); + // "description": "sum of: (_name: query)" + assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: " + queryName)); + // "description": "field value function(_name: func1): none(doc['test'].value * factor=1.0)" + assertThat(firstHit.getExplanation().getDetails()[1].toString(), containsString("_name: " + functionName)); + } } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreIT.java index 0537ab8f0da7a..3d24933f66d17 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreIT.java @@ -43,6 +43,7 @@ import org.opensearch.script.MockScriptPlugin; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; +import org.opensearch.search.SearchHit; import org.opensearch.search.aggregations.bucket.terms.Terms; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchTestCase; @@ -66,6 +67,8 @@ import static org.opensearch.search.builder.SearchSourceBuilder.searchSource; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -140,6 +143,35 @@ public void testScriptScoresWithAgg() throws IOException { assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L)); } + public void testScriptScoresWithAggWithExplain() throws IOException { + createIndex(INDEX); + index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject()); + refresh(); + + Script script = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "get score value", Collections.emptyMap()); + + SearchResponse response = client().search( + searchRequest().source( + searchSource().explain(true) + .query(functionScoreQuery(scriptFunction(script, "func1"), "query1")) + .aggregation(terms("score_agg").script(script)) + ) + ).actionGet(); + assertSearchResponse(response); + + final SearchHit firstHit = response.getHits().getAt(0); + assertThat(firstHit.getScore(), equalTo(1.0f)); + assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2)); + // "description": "*:* (_name: query1)" + assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: query1")); + assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); + // "description": "script score function(_name: func1), computed with script:\"Script{ ... }\"" + assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription(), containsString("_name: func1")); + + assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsString(), equalTo("1.0")); + assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L)); + } + public void testMinScoreFunctionScoreBasic() throws IOException { float score = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat); float minScore = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat); diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScorePluginIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScorePluginIT.java index af7633628dab1..885f1aa7ff7a0 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScorePluginIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScorePluginIT.java @@ -171,7 +171,7 @@ public double evaluate(double value, double scale) { } @Override - public Explanation explainFunction(String distanceString, double distanceVal, double scale) { + public Explanation explainFunction(String distanceString, double distanceVal, double scale, String functionName) { return Explanation.match((float) distanceVal, "" + distanceVal); } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/RandomScoreFunctionIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/RandomScoreFunctionIT.java index 669477f670f98..670f5e65eb575 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/RandomScoreFunctionIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/RandomScoreFunctionIT.java @@ -63,6 +63,7 @@ import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -289,6 +290,37 @@ public void testSeedReportedInExplain() throws Exception { assertThat(firstHit.getExplanation().toString(), containsString("" + seed)); } + public void testSeedAndNameReportedInExplain() throws Exception { + createIndex("test"); + ensureGreen(); + index("test", "type", "1", jsonBuilder().startObject().endObject()); + flush(); + refresh(); + + int seed = 12345678; + + final String queryName = "query1"; + final String functionName = "func1"; + SearchResponse resp = client().prepareSearch("test") + .setQuery( + functionScoreQuery( + matchAllQuery().queryName(queryName), + randomFunction(functionName).seed(seed).setField(SeqNoFieldMapper.NAME) + ) + ) + .setExplain(true) + .get(); + assertNoFailures(resp); + assertEquals(1, resp.getHits().getTotalHits().value); + SearchHit firstHit = resp.getHits().getAt(0); + assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2)); + // "description": "*:* (_name: query1)" + assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName)); + assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); + // "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)" + assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription().toString(), containsString("seed: " + seed)); + } + public void testNoDocs() throws Exception { createIndex("test"); ensureGreen(); diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/FieldValueFactorFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/FieldValueFactorFunction.java index a015b24d73e5a..3233fc9f8cecc 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/FieldValueFactorFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/FieldValueFactorFunction.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; @@ -55,6 +56,8 @@ public class FieldValueFactorFunction extends ScoreFunction { private final String field; private final float boostFactor; private final Modifier modifier; + private final String functionName; + /** * Value used if the document is missing the field. */ @@ -67,6 +70,17 @@ public FieldValueFactorFunction( Modifier modifierType, Double missing, IndexNumericFieldData indexFieldData + ) { + this(field, boostFactor, modifierType, missing, indexFieldData, null); + } + + public FieldValueFactorFunction( + String field, + float boostFactor, + Modifier modifierType, + Double missing, + IndexNumericFieldData indexFieldData, + @Nullable String functionName ) { super(CombineFunction.MULTIPLY); this.field = field; @@ -74,6 +88,7 @@ public FieldValueFactorFunction( this.modifier = modifierType; this.indexFieldData = indexFieldData; this.missing = missing; + this.functionName = functionName; } @Override @@ -127,7 +142,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE (float) score, String.format( Locale.ROOT, - "field value function: %s(doc['%s'].value%s * factor=%s)", + "field value function" + Functions.nameOrEmptyFunc(functionName) + ": %s(doc['%s'].value%s * factor=%s)", modifierStr, field, defaultStr, diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java index 36ecf690862cc..f7b91db2e712f 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java @@ -46,6 +46,7 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.Bits; import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; @@ -70,11 +71,28 @@ public class FunctionScoreQuery extends Query { public static class FilterScoreFunction extends ScoreFunction { public final Query filter; public final ScoreFunction function; + public final String queryName; + /** + * Creates a FilterScoreFunction with query and function. + * @param filter filter query + * @param function score function + */ public FilterScoreFunction(Query filter, ScoreFunction function) { + this(filter, function, null); + } + + /** + * Creates a FilterScoreFunction with query and function. + * @param filter filter query + * @param function score function + * @param queryName filter query name + */ + public FilterScoreFunction(Query filter, ScoreFunction function, @Nullable String queryName) { super(function.getDefaultScoreCombiner()); this.filter = filter; this.function = function; + this.queryName = queryName; } @Override @@ -93,12 +111,14 @@ protected boolean doEquals(ScoreFunction other) { return false; } FilterScoreFunction that = (FilterScoreFunction) other; - return Objects.equals(this.filter, that.filter) && Objects.equals(this.function, that.function); + return Objects.equals(this.filter, that.filter) + && Objects.equals(this.function, that.function) + && Objects.equals(this.queryName, that.queryName); } @Override protected int doHashCode() { - return Objects.hash(filter, function); + return Objects.hash(filter, function, queryName); } @Override @@ -107,7 +127,7 @@ protected ScoreFunction rewrite(IndexReader reader) throws IOException { if (newFilter == filter) { return this; } - return new FilterScoreFunction(newFilter, function); + return new FilterScoreFunction(newFilter, function, queryName); } @Override @@ -144,6 +164,7 @@ public static ScoreMode fromString(String scoreMode) { final float maxBoost; private final Float minScore; private final CombineFunction combineFunction; + private final String queryName; /** * Creates a FunctionScoreQuery without function. @@ -152,7 +173,18 @@ public static ScoreMode fromString(String scoreMode) { * @param maxBoost The maximum applicable boost. */ public FunctionScoreQuery(Query subQuery, Float minScore, float maxBoost) { - this(subQuery, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost); + this(subQuery, null, minScore, maxBoost); + } + + /** + * Creates a FunctionScoreQuery without function. + * @param subQuery The query to match. + * @param queryName filter query name + * @param minScore The minimum score to consider a document. + * @param maxBoost The maximum applicable boost. + */ + public FunctionScoreQuery(Query subQuery, @Nullable String queryName, Float minScore, float maxBoost) { + this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost); } /** @@ -161,7 +193,17 @@ public FunctionScoreQuery(Query subQuery, Float minScore, float maxBoost) { * @param function The {@link ScoreFunction} to apply. */ public FunctionScoreQuery(Query subQuery, ScoreFunction function) { - this(subQuery, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST); + this(subQuery, null, function); + } + + /** + * Creates a FunctionScoreQuery with a single {@link ScoreFunction} + * @param subQuery The query to match. + * @param queryName filter query name + * @param function The {@link ScoreFunction} to apply. + */ + public FunctionScoreQuery(Query subQuery, @Nullable String queryName, ScoreFunction function) { + this(subQuery, queryName, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST); } /** @@ -173,12 +215,53 @@ public FunctionScoreQuery(Query subQuery, ScoreFunction function) { * @param maxBoost The maximum applicable boost. */ public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunction combineFunction, Float minScore, float maxBoost) { - this(subQuery, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost); + this(subQuery, null, function, combineFunction, minScore, maxBoost); + } + + /** + * Creates a FunctionScoreQuery with a single function + * @param subQuery The query to match. + * @param queryName filter query name + * @param function The {@link ScoreFunction} to apply. + * @param combineFunction Defines how the query and function score should be applied. + * @param minScore The minimum score to consider a document. + * @param maxBoost The maximum applicable boost. + */ + public FunctionScoreQuery( + Query subQuery, + @Nullable String queryName, + ScoreFunction function, + CombineFunction combineFunction, + Float minScore, + float maxBoost + ) { + this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost); + } + + /** + * Creates a FunctionScoreQuery with multiple score functions + * @param subQuery The query to match. + * @param scoreMode Defines how the different score functions should be combined. + * @param functions The {@link ScoreFunction}s to apply. + * @param combineFunction Defines how the query and function score should be applied. + * @param minScore The minimum score to consider a document. + * @param maxBoost The maximum applicable boost. + */ + public FunctionScoreQuery( + Query subQuery, + ScoreMode scoreMode, + ScoreFunction[] functions, + CombineFunction combineFunction, + Float minScore, + float maxBoost + ) { + this(subQuery, null, scoreMode, functions, combineFunction, minScore, maxBoost); } /** * Creates a FunctionScoreQuery with multiple score functions * @param subQuery The query to match. + * @param queryName filter query name * @param scoreMode Defines how the different score functions should be combined. * @param functions The {@link ScoreFunction}s to apply. * @param combineFunction Defines how the query and function score should be applied. @@ -187,6 +270,7 @@ public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunctio */ public FunctionScoreQuery( Query subQuery, + @Nullable String queryName, ScoreMode scoreMode, ScoreFunction[] functions, CombineFunction combineFunction, @@ -197,6 +281,7 @@ public FunctionScoreQuery( throw new IllegalArgumentException("Score function should not be null"); } this.subQuery = subQuery; + this.queryName = queryName; this.scoreMode = scoreMode; this.functions = functions; this.maxBoost = maxBoost; @@ -240,7 +325,7 @@ public Query rewrite(IndexReader reader) throws IOException { needsRewrite |= (newFunctions[i] != functions[i]); } if (needsRewrite) { - return new FunctionScoreQuery(newQ, scoreMode, newFunctions, combineFunction, minScore, maxBoost); + return new FunctionScoreQuery(newQ, queryName, scoreMode, newFunctions, combineFunction, minScore, maxBoost); } return this; } @@ -332,8 +417,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - - Explanation expl = subQueryWeight.explain(context, doc); + Explanation expl = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName); if (!expl.isMatch()) { return expl; } @@ -355,11 +439,15 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio Explanation functionExplanation = function.getLeafScoreFunction(context).explainScore(doc, expl); if (function instanceof FilterScoreFunction) { float factor = functionExplanation.getValue().floatValue(); - Query filterQuery = ((FilterScoreFunction) function).filter; + final FilterScoreFunction filterScoreFunction = (FilterScoreFunction) function; + Query filterQuery = filterScoreFunction.filter; Explanation filterExplanation = Explanation.match( factor, "function score, product of:", - Explanation.match(1.0f, "match filter: " + filterQuery.toString()), + Explanation.match( + 1.0f, + "match filter" + Functions.nameOrEmptyFunc(filterScoreFunction.queryName) + ": " + filterQuery.toString() + ), functionExplanation ); functionsExplanations.add(filterExplanation); @@ -543,11 +631,12 @@ public boolean equals(Object o) { && Objects.equals(this.combineFunction, other.combineFunction) && Objects.equals(this.minScore, other.minScore) && Objects.equals(this.scoreMode, other.scoreMode) - && Arrays.equals(this.functions, other.functions); + && Arrays.equals(this.functions, other.functions) + && Objects.equals(this.queryName, other.queryName); } @Override public int hashCode() { - return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions)); + return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions), queryName); } } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/Functions.java b/server/src/main/java/org/opensearch/common/lucene/search/function/Functions.java new file mode 100644 index 0000000000000..a9de8ead31e2a --- /dev/null +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/Functions.java @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.lucene.search.function; + +import org.apache.lucene.search.Explanation; +import org.opensearch.common.Strings; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder; + +/** + * Helper utility class for functions + */ +public final class Functions { + private Functions() {} + + /** + * Return function name wrapped into brackets or empty string, for example: '(_name: func1)' + * @param functionName function name + * @return function name wrapped into brackets or empty string + */ + public static String nameOrEmptyFunc(final String functionName) { + if (!Strings.isNullOrEmpty(functionName)) { + return "(" + AbstractQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName + ")"; + } else { + return ""; + } + } + + /** + * Return function name as an argument or empty string, for example: ', _name: func1' + * @param functionName function name + * @return function name as an argument or empty string + */ + public static String nameOrEmptyArg(final String functionName) { + if (!Strings.isNullOrEmpty(functionName)) { + return ", " + FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName; + } else { + return ""; + } + } + + /** + * Enrich explanation with query name + * @param explanation explanation + * @param queryName query name + * @return explanation enriched with query name + */ + public static Explanation explainWithName(Explanation explanation, String queryName) { + if (Strings.isNullOrEmpty(queryName)) { + return explanation; + } else { + final String description = explanation.getDescription() + " " + nameOrEmptyFunc(queryName); + if (explanation.isMatch()) { + return Explanation.match(explanation.getValue(), description, explanation.getDetails()); + } else { + return Explanation.noMatch(description, explanation.getDetails()); + } + } + } +} diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/RandomScoreFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/RandomScoreFunction.java index 78df111393394..f4fcda47b0078 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/RandomScoreFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/RandomScoreFunction.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.util.StringHelper; +import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -50,6 +51,7 @@ public class RandomScoreFunction extends ScoreFunction { private final int originalSeed; private final int saltedSeed; private final IndexFieldData fieldData; + private final String functionName; /** * Creates a RandomScoreFunction. @@ -59,10 +61,23 @@ public class RandomScoreFunction extends ScoreFunction { * @param uidFieldData The field data for _uid to use for generating consistent random values for the same id */ public RandomScoreFunction(int seed, int salt, IndexFieldData uidFieldData) { + this(seed, salt, uidFieldData, null); + } + + /** + * Creates a RandomScoreFunction. + * + * @param seed A seed for randomness + * @param salt A value to salt the seed with, ideally unique to the running node/index + * @param uidFieldData The field data for _uid to use for generating consistent random values for the same id + * @param functionName The function name + */ + public RandomScoreFunction(int seed, int salt, IndexFieldData uidFieldData, @Nullable String functionName) { super(CombineFunction.MULTIPLY); this.originalSeed = seed; this.saltedSeed = BitMixer.mix(seed, salt); this.fieldData = uidFieldData; + this.functionName = functionName; } @Override @@ -97,7 +112,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE String field = fieldData == null ? null : fieldData.getFieldName(); return Explanation.match( (float) score(docId, subQueryScore.getValue().floatValue()), - "random score function (seed: " + originalSeed + ", field: " + field + ")" + "random score function (seed: " + originalSeed + ", field: " + field + Functions.nameOrEmptyArg(functionName) + ")" ); } }; diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java index 5ce50844b3dcc..3a7cc970908a5 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java @@ -39,6 +39,7 @@ import org.opensearch.script.ScoreScript; import org.opensearch.script.Script; import org.opensearch.Version; +import org.opensearch.common.Nullable; import java.io.IOException; import java.util.Objects; @@ -67,14 +68,23 @@ public float score() { private final int shardId; private final String indexName; private final Version indexVersion; - - public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) { + private final String functionName; + + public ScriptScoreFunction( + Script sScript, + ScoreScript.LeafFactory script, + String indexName, + int shardId, + Version indexVersion, + @Nullable String functionName + ) { super(CombineFunction.REPLACE); this.sScript = sScript; this.script = script; this.indexName = indexName; this.shardId = shardId; this.indexVersion = indexVersion; + this.functionName = functionName; } @Override @@ -105,11 +115,15 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE leafScript.setDocument(docId); scorer.docid = docId; scorer.score = subQueryScore.getValue().floatValue(); - exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore); + exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName); } else { double score = score(docId, subQueryScore.getValue().floatValue()); // info about params already included in sScript - String explanation = "script score function, computed with script:\"" + sScript + "\""; + String explanation = "script score function" + + Functions.nameOrEmptyFunc(functionName) + + ", computed with script:\"" + + sScript + + "\""; Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore); return Explanation.match((float) score, explanation, scoreExp); } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java index 7d9f293b0c17b..44c76e74d5a41 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreQuery.java @@ -50,6 +50,7 @@ import org.apache.lucene.search.BulkScorer; import org.apache.lucene.util.Bits; import org.opensearch.Version; +import org.opensearch.common.Nullable; import org.opensearch.script.ScoreScript; import org.opensearch.script.ScoreScript.ExplanationHolder; import org.opensearch.script.Script; @@ -69,6 +70,7 @@ public class ScriptScoreQuery extends Query { private final String indexName; private final int shardId; private final Version indexVersion; + private final String queryName; public ScriptScoreQuery( Query subQuery, @@ -78,8 +80,22 @@ public ScriptScoreQuery( String indexName, int shardId, Version indexVersion + ) { + this(subQuery, null, script, scriptBuilder, minScore, indexName, shardId, indexVersion); + } + + public ScriptScoreQuery( + Query subQuery, + @Nullable String queryName, + Script script, + ScoreScript.LeafFactory scriptBuilder, + Float minScore, + String indexName, + int shardId, + Version indexVersion ) { this.subQuery = subQuery; + this.queryName = queryName; this.script = script; this.scriptBuilder = scriptBuilder; this.minScore = minScore; @@ -92,7 +108,7 @@ public ScriptScoreQuery( public Query rewrite(IndexReader reader) throws IOException { Query newQ = subQuery.rewrite(reader); if (newQ != subQuery) { - return new ScriptScoreQuery(newQ, script, scriptBuilder, minScore, indexName, shardId, indexVersion); + return new ScriptScoreQuery(newQ, queryName, script, scriptBuilder, minScore, indexName, shardId, indexVersion); } return super.rewrite(reader); } @@ -140,7 +156,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - Explanation subQueryExplanation = subQueryWeight.explain(context, doc); + Explanation subQueryExplanation = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName); if (subQueryExplanation.isMatch() == false) { return subQueryExplanation; } @@ -210,7 +226,8 @@ public void visit(QueryVisitor visitor) { @Override public String toString(String field) { StringBuilder sb = new StringBuilder(); - sb.append("script_score (").append(subQuery.toString(field)).append(", script: "); + sb.append("script_score (").append(subQuery.toString(field)); + sb.append(Functions.nameOrEmptyArg(queryName)).append(", script: "); sb.append("{" + script.toString() + "}"); return sb.toString(); } @@ -225,12 +242,13 @@ public boolean equals(Object o) { && script.equals(that.script) && Objects.equals(minScore, that.minScore) && indexName.equals(that.indexName) - && indexVersion.equals(that.indexVersion); + && indexVersion.equals(that.indexVersion) + && Objects.equals(queryName, that.queryName); } @Override public int hashCode() { - return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion); + return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion, queryName); } private static class ScriptScorer extends Scorer { diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/WeightFactorFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/WeightFactorFunction.java index 9ef33efdfd9f5..71968a0545cff 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/WeightFactorFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/WeightFactorFunction.java @@ -34,6 +34,8 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; +import org.opensearch.common.Strings; import java.io.IOException; import java.util.Objects; @@ -45,9 +47,17 @@ public class WeightFactorFunction extends ScoreFunction { private float weight = 1.0f; public WeightFactorFunction(float weight, ScoreFunction scoreFunction) { + this(weight, scoreFunction, null); + } + + public WeightFactorFunction(float weight, ScoreFunction scoreFunction, @Nullable String functionName) { super(CombineFunction.MULTIPLY); if (scoreFunction == null) { - this.scoreFunction = SCORE_ONE; + if (Strings.isNullOrEmpty(functionName)) { + this.scoreFunction = SCORE_ONE; + } else { + this.scoreFunction = new ScoreOne(CombineFunction.MULTIPLY, functionName); + } } else { this.scoreFunction = scoreFunction; } @@ -55,9 +65,11 @@ public WeightFactorFunction(float weight, ScoreFunction scoreFunction) { } public WeightFactorFunction(float weight) { - super(CombineFunction.MULTIPLY); - this.scoreFunction = SCORE_ONE; - this.weight = weight; + this(weight, null, null); + } + + public WeightFactorFunction(float weight, @Nullable String functionName) { + this(weight, null, functionName); } @Override @@ -112,9 +124,15 @@ protected int doHashCode() { } private static class ScoreOne extends ScoreFunction { + private final String functionName; protected ScoreOne(CombineFunction scoreCombiner) { + this(scoreCombiner, null); + } + + protected ScoreOne(CombineFunction scoreCombiner, @Nullable String functionName) { super(scoreCombiner); + this.functionName = functionName; } @Override @@ -127,7 +145,10 @@ public double score(int docId, float subQueryScore) { @Override public Explanation explainScore(int docId, Explanation subQueryScore) { - return Explanation.match(1.0f, "constant score 1.0 - no function provided"); + return Explanation.match( + 1.0f, + "constant score 1.0" + Functions.nameOrEmptyFunc(functionName) + " - no function provided" + ); } }; } diff --git a/server/src/main/java/org/opensearch/index/query/QueryBuilders.java b/server/src/main/java/org/opensearch/index/query/QueryBuilders.java index 5b386564df1e8..a80e4c1c9f745 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryBuilders.java +++ b/server/src/main/java/org/opensearch/index/query/QueryBuilders.java @@ -33,6 +33,7 @@ package org.opensearch.index.query; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.common.Nullable; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.geo.ShapeRelation; @@ -452,7 +453,17 @@ public static FunctionScoreQueryBuilder functionScoreQuery(FunctionScoreQueryBui * @param function The function builder used to custom score */ public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function) { - return new FunctionScoreQueryBuilder(function); + return functionScoreQuery(function, null); + } + + /** + * A query that allows to define a custom scoring function. + * + * @param function The function builder used to custom score + * @param queryName The query name + */ + public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function, @Nullable String queryName) { + return new FunctionScoreQueryBuilder(function, queryName); } /** diff --git a/server/src/main/java/org/opensearch/index/query/ScriptQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/ScriptQueryBuilder.java index 881323b05e536..8739e48eb411b 100644 --- a/server/src/main/java/org/opensearch/index/query/ScriptQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/ScriptQueryBuilder.java @@ -43,9 +43,11 @@ import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.lucene.search.function.Functions; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.script.FilterScript; @@ -153,17 +155,19 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } FilterScript.Factory factory = context.compile(script, FilterScript.CONTEXT); FilterScript.LeafFactory filterScript = factory.newFactory(script.getParams(), context.lookup()); - return new ScriptQuery(script, filterScript); + return new ScriptQuery(script, filterScript, queryName); } static class ScriptQuery extends Query { final Script script; final FilterScript.LeafFactory filterScript; + final String queryName; - ScriptQuery(Script script, FilterScript.LeafFactory filterScript) { + ScriptQuery(Script script, FilterScript.LeafFactory filterScript, @Nullable String queryName) { this.script = script; this.filterScript = filterScript; + this.queryName = queryName; } @Override @@ -171,6 +175,7 @@ public String toString(String field) { StringBuilder buffer = new StringBuilder(); buffer.append("ScriptQuery("); buffer.append(script); + buffer.append(Functions.nameOrEmptyArg(queryName)); buffer.append(")"); return buffer.toString(); } diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunction.java b/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunction.java index 8a595dda07979..02d01ef470b61 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunction.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunction.java @@ -33,6 +33,7 @@ package org.opensearch.index.query.functionscore; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; /** * Implement this interface to provide a decay function that is executed on a @@ -45,7 +46,7 @@ public interface DecayFunction { double evaluate(double value, double scale); - Explanation explainFunction(String valueString, double value, double scale); + Explanation explainFunction(String valueString, double value, double scale, @Nullable String functionName); /** * The final scale parameter is computed from the scale parameter given by diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunctionBuilder.java index 3ddacb1305536..0ee61b34cd279 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/DecayFunctionBuilder.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.opensearch.OpenSearchParseException; +import org.opensearch.common.Nullable; import org.opensearch.common.ParsingException; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.geo.GeoDistance; @@ -93,10 +94,31 @@ protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Ob this(fieldName, origin, scale, offset, DEFAULT_DECAY); } + /** + * Convenience constructor that converts its parameters into json to parse on the data nodes. + */ + protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) { + this(fieldName, origin, scale, offset, DEFAULT_DECAY, functionName); + } + /** * Convenience constructor that converts its parameters into json to parse on the data nodes. */ protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { + this(fieldName, origin, scale, offset, decay, null); + } + + /** + * Convenience constructor that converts its parameters into json to parse on the data nodes. + */ + protected DecayFunctionBuilder( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { if (fieldName == null) { throw new IllegalArgumentException("decay function: field name must not be null"); } @@ -123,6 +145,7 @@ protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Ob } catch (IOException e) { throw new IllegalArgumentException("unable to build inner function object", e); } + setFunctionName(functionName); } protected DecayFunctionBuilder(String fieldName, BytesReference functionBytes) { @@ -285,7 +308,16 @@ private AbstractDistanceScoreFunction parseNumberVariable( ); } IndexNumericFieldData numericFieldData = context.getForField(fieldType); - return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode); + return new NumericFieldDataScoreFunction( + origin, + scale, + decay, + offset, + getDecayFunction(), + numericFieldData, + mode, + getFunctionName() + ); } private AbstractDistanceScoreFunction parseGeoVariable( @@ -325,7 +357,7 @@ private AbstractDistanceScoreFunction parseGeoVariable( double scale = DistanceUnit.DEFAULT.parse(scaleString, DistanceUnit.DEFAULT); double offset = DistanceUnit.DEFAULT.parse(offsetString, DistanceUnit.DEFAULT); IndexGeoPointFieldData indexFieldData = context.getForField(fieldType); - return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode); + return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode, getFunctionName()); } @@ -375,7 +407,16 @@ private AbstractDistanceScoreFunction parseDateVariable( val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24), DecayFunctionParser.class.getSimpleName() + ".offset"); double offset = val.getMillis(); IndexNumericFieldData numericFieldData = context.getForField(dateFieldType); - return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode); + return new NumericFieldDataScoreFunction( + origin, + scale, + decay, + offset, + getDecayFunction(), + numericFieldData, + mode, + getFunctionName() + ); } static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction { @@ -392,9 +433,10 @@ static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction { double offset, DecayFunction func, IndexGeoPointFieldData fieldData, - MultiValueMode mode + MultiValueMode mode, + @Nullable String functionName ) { - super(scale, decay, offset, func, mode); + super(scale, decay, offset, func, mode, functionName); this.origin = origin; this.fieldData = fieldData; } @@ -485,9 +527,10 @@ static class NumericFieldDataScoreFunction extends AbstractDistanceScoreFunction double offset, DecayFunction func, IndexNumericFieldData fieldData, - MultiValueMode mode + MultiValueMode mode, + @Nullable String functionName ) { - super(scale, decay, offset, func, mode); + super(scale, decay, offset, func, mode, functionName); this.fieldData = fieldData; this.origin = origin; } @@ -569,13 +612,15 @@ public abstract static class AbstractDistanceScoreFunction extends ScoreFunction protected final double offset; private final DecayFunction func; protected final MultiValueMode mode; + protected final String functionName; public AbstractDistanceScoreFunction( double userSuppiedScale, double decay, double offset, DecayFunction func, - MultiValueMode mode + MultiValueMode mode, + @Nullable String functionName ) { super(CombineFunction.MULTIPLY); this.mode = mode; @@ -591,6 +636,7 @@ public AbstractDistanceScoreFunction( throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : offset must be > 0.0"); } this.offset = offset; + this.functionName = functionName; } /** @@ -624,7 +670,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE return Explanation.match( (float) score(docId, subQueryScore.getValue().floatValue()), "Function for field " + getFieldName() + ":", - func.explainFunction(getDistanceString(ctx, docId), value, scale) + func.explainFunction(getDistanceString(ctx, docId), value, scale, functionName) ); } }; diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java index 7f0a9c3a58d59..b78e75762fe11 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java @@ -33,8 +33,10 @@ package org.opensearch.index.query.functionscore; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.lucene.search.function.Functions; import java.io.IOException; @@ -45,6 +47,10 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder scoreFunctionBuilder) { - this(new MatchAllQueryBuilder(), new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) }); + this(scoreFunctionBuilder, null); + } + + /** + * Creates a function_score query that will execute the function provided on all documents + * + * @param scoreFunctionBuilder score function that is executed + * @param queryName the query name + */ + public FunctionScoreQueryBuilder(ScoreFunctionBuilder scoreFunctionBuilder, @Nullable String queryName) { + this( + new MatchAllQueryBuilder().queryName(queryName), + new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) } + ); } /** @@ -316,15 +340,17 @@ protected Query doToQuery(QueryShardContext context) throws IOException { int i = 0; for (FilterFunctionBuilder filterFunctionBuilder : filterFunctionBuilders) { ScoreFunction scoreFunction = filterFunctionBuilder.getScoreFunction().toFunction(context); - if (filterFunctionBuilder.getFilter().getName().equals(MatchAllQueryBuilder.NAME)) { + final QueryBuilder builder = filterFunctionBuilder.getFilter(); + if (builder.getName().equals(MatchAllQueryBuilder.NAME)) { filterFunctions[i++] = scoreFunction; } else { - Query filter = filterFunctionBuilder.getFilter().toQuery(context); - filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction); + Query filter = builder.toQuery(context); + filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction, builder.queryName()); } } - Query query = this.query.toQuery(context); + final QueryBuilder builder = this.query; + Query query = builder.toQuery(context); if (query == null) { query = new MatchAllDocsQuery(); } @@ -332,12 +358,12 @@ protected Query doToQuery(QueryShardContext context) throws IOException { CombineFunction boostMode = this.boostMode == null ? DEFAULT_BOOST_MODE : this.boostMode; // handle cases where only one score function and no filter was provided. In this case we create a FunctionScoreQuery. if (filterFunctions.length == 0) { - return new FunctionScoreQuery(query, minScore, maxBoost); + return new FunctionScoreQuery(query, builder.queryName(), minScore, maxBoost); } else if (filterFunctions.length == 1 && filterFunctions[0] instanceof FunctionScoreQuery.FilterScoreFunction == false) { - return new FunctionScoreQuery(query, filterFunctions[0], boostMode, minScore, maxBoost); + return new FunctionScoreQuery(query, builder.queryName(), filterFunctions[0], boostMode, minScore, maxBoost); } // in all other cases we create a FunctionScoreQuery with filters - return new FunctionScoreQuery(query, scoreMode, filterFunctions, boostMode, minScore, maxBoost); + return new FunctionScoreQuery(query, builder.queryName(), scoreMode, filterFunctions, boostMode, minScore, maxBoost); } /** @@ -606,6 +632,7 @@ private static String parseFiltersAndFunctions( QueryBuilder filter = null; ScoreFunctionBuilder scoreFunction = null; Float functionWeight = null; + String functionName = null; if (token != XContentParser.Token.START_OBJECT) { throw new ParsingException( parser.getTokenLocation(), @@ -635,6 +662,8 @@ private static String parseFiltersAndFunctions( } else if (token.isValue()) { if (WEIGHT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { functionWeight = parser.floatValue(); + } else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + functionName = parser.text(); } else { throw new ParsingException( parser.getTokenLocation(), @@ -652,6 +681,10 @@ private static String parseFiltersAndFunctions( scoreFunction.setWeight(functionWeight); } } + + if (functionName != null && scoreFunction != null) { + scoreFunction.setFunctionName(functionName); + } } if (filter == null) { filter = new MatchAllQueryBuilder(); diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/GaussDecayFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/GaussDecayFunctionBuilder.java index c208083da08f5..ac6ae33cb4ed0 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/GaussDecayFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/GaussDecayFunctionBuilder.java @@ -33,9 +33,11 @@ package org.opensearch.index.query.functionscore; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; import org.opensearch.common.ParseField; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.lucene.search.function.Functions; import java.io.IOException; @@ -49,10 +51,25 @@ public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, super(fieldName, origin, scale, offset); } + public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) { + super(fieldName, origin, scale, offset, functionName); + } + public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { super(fieldName, origin, scale, offset, decay); } + public GaussDecayFunctionBuilder( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { + super(fieldName, origin, scale, offset, decay, functionName); + } + GaussDecayFunctionBuilder(String fieldName, BytesReference functionBytes) { super(fieldName, functionBytes); } @@ -75,7 +92,6 @@ public DecayFunction getDecayFunction() { } private static final class GaussScoreFunction implements DecayFunction { - @Override public double evaluate(double value, double scale) { // note that we already computed scale^2 in processScale() so we do @@ -84,8 +100,11 @@ public double evaluate(double value, double scale) { } @Override - public Explanation explainFunction(String valueExpl, double value, double scale) { - return Explanation.match((float) evaluate(value, scale), "exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")"); + public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) { + return Explanation.match( + (float) evaluate(value, scale), + "exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + Functions.nameOrEmptyArg(functionName) + ")" + ); } @Override diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/LinearDecayFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/LinearDecayFunctionBuilder.java index 762757eb156e4..03102e45a41ba 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/LinearDecayFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/LinearDecayFunctionBuilder.java @@ -33,8 +33,10 @@ package org.opensearch.index.query.functionscore; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.lucene.search.function.Functions; import java.io.IOException; @@ -47,10 +49,25 @@ public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, super(fieldName, origin, scale, offset); } + public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) { + super(fieldName, origin, scale, offset, functionName); + } + public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { super(fieldName, origin, scale, offset, decay); } + public LinearDecayFunctionBuilder( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { + super(fieldName, origin, scale, offset, decay, functionName); + } + LinearDecayFunctionBuilder(String fieldName, BytesReference functionBytes) { super(fieldName, functionBytes); } @@ -80,8 +97,11 @@ public double evaluate(double value, double scale) { } @Override - public Explanation explainFunction(String valueExpl, double value, double scale) { - return Explanation.match((float) evaluate(value, scale), "max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + ")"); + public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) { + return Explanation.match( + (float) evaluate(value, scale), + "max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + Functions.nameOrEmptyArg(functionName) + ")" + ); } @Override diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/RandomScoreFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/RandomScoreFunctionBuilder.java index 730be404feb14..26495c93082ae 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/RandomScoreFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/RandomScoreFunctionBuilder.java @@ -31,6 +31,7 @@ package org.opensearch.index.query.functionscore; +import org.opensearch.common.Nullable; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -58,6 +59,10 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder> implements ToXContentFragment, NamedWriteable { private Float weight; + private String functionName; /** * Standard empty constructor. @@ -58,11 +60,17 @@ public ScoreFunctionBuilder() {} */ public ScoreFunctionBuilder(StreamInput in) throws IOException { weight = checkWeight(in.readOptionalFloat()); + if (in.getVersion().onOrAfter(Version.V_2_0_0)) { + functionName = in.readOptionalString(); + } } @Override public final void writeTo(StreamOutput out) throws IOException { out.writeOptionalFloat(weight); + if (out.getVersion().onOrAfter(Version.V_2_0_0)) { + out.writeOptionalString(functionName); + } doWriteTo(out); } @@ -99,11 +107,30 @@ public final Float getWeight() { return weight; } + /** + * The name of this function + */ + public String getFunctionName() { + return functionName; + } + + /** + * Set the name of this function + */ + public void setFunctionName(String functionName) { + this.functionName = functionName; + } + @Override public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { if (weight != null) { builder.field(FunctionScoreQueryBuilder.WEIGHT_FIELD.getPreferredName(), weight); } + + if (functionName != null) { + builder.field(FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName(), functionName); + } + doXContent(builder, params); return builder; } @@ -128,7 +155,7 @@ public final boolean equals(Object obj) { } @SuppressWarnings("unchecked") FB other = (FB) obj; - return Objects.equals(weight, other.getWeight()) && doEquals(other); + return Objects.equals(weight, other.getWeight()) && Objects.equals(functionName, other.getFunctionName()) && doEquals(other); } /** @@ -139,7 +166,7 @@ public final boolean equals(Object obj) { @Override public final int hashCode() { - return Objects.hash(getClass(), weight, doHashCode()); + return Objects.hash(getClass(), weight, functionName, doHashCode()); } /** @@ -156,7 +183,7 @@ public final ScoreFunction toFunction(QueryShardContext context) throws IOExcept if (weight == null) { return scoreFunction; } - return new WeightFactorFunction(weight, scoreFunction); + return new WeightFactorFunction(weight, scoreFunction, getFunctionName()); } /** diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScoreFunctionBuilders.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScoreFunctionBuilders.java index 54dca40208c00..59d02e9381d7e 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScoreFunctionBuilders.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScoreFunctionBuilders.java @@ -32,6 +32,7 @@ package org.opensearch.index.query.functionscore; +import org.opensearch.common.Nullable; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -46,10 +47,29 @@ public static ExponentialDecayFunctionBuilder exponentialDecayFunction(String fi return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null); } + public static ExponentialDecayFunctionBuilder exponentialDecayFunction( + String fieldName, + Object origin, + Object scale, + @Nullable String functionName + ) { + return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null, functionName); + } + public static ExponentialDecayFunctionBuilder exponentialDecayFunction(String fieldName, Object origin, Object scale, Object offset) { return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset); } + public static ExponentialDecayFunctionBuilder exponentialDecayFunction( + String fieldName, + Object origin, + Object scale, + Object offset, + @Nullable String functionName + ) { + return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, functionName); + } + public static ExponentialDecayFunctionBuilder exponentialDecayFunction( String fieldName, Object origin, @@ -60,10 +80,30 @@ public static ExponentialDecayFunctionBuilder exponentialDecayFunction( return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay); } + public static ExponentialDecayFunctionBuilder exponentialDecayFunction( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { + return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName); + } + public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale) { return new GaussDecayFunctionBuilder(fieldName, origin, scale, null); } + public static GaussDecayFunctionBuilder gaussDecayFunction( + String fieldName, + Object origin, + Object scale, + @Nullable String functionName + ) { + return new GaussDecayFunctionBuilder(fieldName, origin, scale, null, functionName); + } + public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale, Object offset) { return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset); } @@ -72,6 +112,26 @@ public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Obj return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay); } + public static GaussDecayFunctionBuilder gaussDecayFunction( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { + return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName); + } + + public static LinearDecayFunctionBuilder linearDecayFunction( + String fieldName, + Object origin, + Object scale, + @Nullable String functionName + ) { + return new LinearDecayFunctionBuilder(fieldName, origin, scale, null, functionName); + } + public static LinearDecayFunctionBuilder linearDecayFunction(String fieldName, Object origin, Object scale) { return new LinearDecayFunctionBuilder(fieldName, origin, scale, null); } @@ -80,6 +140,16 @@ public static LinearDecayFunctionBuilder linearDecayFunction(String fieldName, O return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset); } + public static LinearDecayFunctionBuilder linearDecayFunction( + String fieldName, + Object origin, + Object scale, + Object offset, + @Nullable String functionName + ) { + return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, functionName); + } + public static LinearDecayFunctionBuilder linearDecayFunction( String fieldName, Object origin, @@ -90,23 +160,54 @@ public static LinearDecayFunctionBuilder linearDecayFunction( return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay); } + public static LinearDecayFunctionBuilder linearDecayFunction( + String fieldName, + Object origin, + Object scale, + Object offset, + double decay, + @Nullable String functionName + ) { + return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName); + } + public static ScriptScoreFunctionBuilder scriptFunction(Script script) { - return (new ScriptScoreFunctionBuilder(script)); + return scriptFunction(script, null); } public static ScriptScoreFunctionBuilder scriptFunction(String script) { - return (new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap()))); + return scriptFunction(script, null); } public static RandomScoreFunctionBuilder randomFunction() { - return new RandomScoreFunctionBuilder(); + return randomFunction(null); } public static WeightBuilder weightFactorFunction(float weight) { - return (WeightBuilder) (new WeightBuilder().setWeight(weight)); + return weightFactorFunction(weight, null); } public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName) { - return new FieldValueFactorFunctionBuilder(fieldName); + return fieldValueFactorFunction(fieldName, null); + } + + public static ScriptScoreFunctionBuilder scriptFunction(Script script, @Nullable String functionName) { + return new ScriptScoreFunctionBuilder(script, functionName); + } + + public static ScriptScoreFunctionBuilder scriptFunction(String script, @Nullable String functionName) { + return new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap()), functionName); + } + + public static RandomScoreFunctionBuilder randomFunction(@Nullable String functionName) { + return new RandomScoreFunctionBuilder(functionName); + } + + public static WeightBuilder weightFactorFunction(float weight, @Nullable String functionName) { + return (WeightBuilder) (new WeightBuilder(functionName).setWeight(weight)); + } + + public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName, @Nullable String functionName) { + return new FieldValueFactorFunctionBuilder(fieldName, functionName); } } diff --git a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index 8b6cbe3a1bafd..2701e5867edde 100644 --- a/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -32,6 +32,7 @@ package org.opensearch.index.query.functionscore; +import org.opensearch.common.Nullable; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -57,10 +58,15 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder { */ public WeightBuilder() {} + /** + * Standard constructor. + */ + public WeightBuilder(@Nullable String functionName) { + setFunctionName(functionName); + } + /** * Read from a stream. */ diff --git a/server/src/main/java/org/opensearch/script/ExplainableScoreScript.java b/server/src/main/java/org/opensearch/script/ExplainableScoreScript.java index fb7dd7ded501b..6ea3a322449e5 100644 --- a/server/src/main/java/org/opensearch/script/ExplainableScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ExplainableScoreScript.java @@ -33,6 +33,7 @@ package org.opensearch.script; import org.apache.lucene.search.Explanation; +import org.opensearch.common.Nullable; import java.io.IOException; @@ -49,7 +50,21 @@ public interface ExplainableScoreScript { * want to explain how that was computed. * * @param subQueryScore the Explanation for _score + * @deprecated please use {@code explain(Explanation subQueryScore, @Nullable String scriptName)} */ + @Deprecated Explanation explain(Explanation subQueryScore) throws IOException; + /** + * Build the explanation of the current document being scored + * The script score needs the Explanation of the sub query score because it might use _score and + * want to explain how that was computed. + * + * @param subQueryScore the Explanation for _score + * @param scriptName the script name + */ + default Explanation explain(Explanation subQueryScore, @Nullable String scriptName) throws IOException { + return explain(subQueryScore); + } + } diff --git a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java index 62557da2adb62..2bfcec1bf786c 100644 --- a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java +++ b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java @@ -88,6 +88,7 @@ import java.util.concurrent.ExecutionException; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.core.Is.is; import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsNot.not; @@ -283,7 +284,8 @@ protected boolean sortRequiresCustomComparator() { 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), - MultiValueMode.MAX + MultiValueMode.MAX, + null ); private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction( 0, @@ -292,7 +294,8 @@ protected boolean sortRequiresCustomComparator() { 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), - MultiValueMode.MAX + MultiValueMode.MAX, + null ); private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction( 0, @@ -301,7 +304,48 @@ protected boolean sortRequiresCustomComparator() { 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), - MultiValueMode.MAX + MultiValueMode.MAX, + null + ); + + private static final ScoreFunction RANDOM_SCORE_FUNCTION_NAMED = new RandomScoreFunction(0, 0, new IndexFieldDataStub(), "func1"); + private static final ScoreFunction FIELD_VALUE_FACTOR_FUNCTION_NAMED = new FieldValueFactorFunction( + "test", + 1, + FieldValueFactorFunction.Modifier.LN, + 1.0, + null, + "func1" + ); + private static final ScoreFunction GAUSS_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction( + 0, + 1, + 0.1, + 0, + GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, + new IndexNumericFieldDataStub(), + MultiValueMode.MAX, + "func1" + ); + private static final ScoreFunction EXP_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction( + 0, + 1, + 0.1, + 0, + ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, + new IndexNumericFieldDataStub(), + MultiValueMode.MAX, + "func1" + ); + private static final ScoreFunction LIN_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction( + 0, + 1, + 0.1, + 0, + LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, + new IndexNumericFieldDataStub(), + MultiValueMode.MAX, + "func1" ); private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4); private static final String TEXT = "The way out is through."; @@ -383,6 +427,58 @@ public void testExplainFunctionScoreQuery() throws IOException { assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); } + public void testExplainFunctionScoreQueryWithName() throws IOException { + Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION_NAMED); + checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test, _name: func1)"); + assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION_NAMED); + checkFunctionScoreExplanation(functionExplanation, "field value function(_name: func1): ln(doc['test'].value?:1.0 * factor=1.0)"); + assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION_NAMED); + checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); + assertThat( + functionExplanation.getDetails()[0].getDetails()[0].toString(), + equalTo( + "0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs" + + "(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594, _name: func1)\n" + ) + ); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION_NAMED); + checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); + assertThat( + functionExplanation.getDetails()[0].getDetails()[0].toString(), + equalTo( + "0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455, _name: func1)\n" + ) + ); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION_NAMED); + checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); + assertThat( + functionExplanation.getDetails()[0].getDetails()[0].toString(), + equalTo( + "0.1 = max(0.0, ((1.1111111111111112" + + " - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112, _name: func1)\n" + ) + ); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, new WeightFactorFunction(4, RANDOM_SCORE_FUNCTION_NAMED)); + checkFunctionScoreExplanation(functionExplanation, "product of:"); + assertThat( + functionExplanation.getDetails()[0].getDetails()[0].toString(), + endsWith("random score function (seed: 0, field: test, _name: func1)\n") + ); + assertThat(functionExplanation.getDetails()[0].getDetails()[1].toString(), equalTo("4.0 = weight\n")); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); + } + public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException { FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, CombineFunction.AVG, 0.0f, 100); Weight weight = searcher.createWeight(searcher.rewrite(functionScoreQuery), org.apache.lucene.search.ScoreMode.COMPLETE, 1f); diff --git a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java index c80ce807bf736..e1002e114822e 100644 --- a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java @@ -110,6 +110,34 @@ public void testExplain() throws IOException { assertThat(explanation.getValue(), equalTo(1.0)); } + public void testExplainWithName() throws IOException { + Script script = new Script("script using explain"); + ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> { + assertNotNull(explanation); + explanation.set("this explains the score"); + return 1.0; + }); + + ScriptScoreQuery query = new ScriptScoreQuery( + Queries.newMatchAllQuery(), + "query1", + script, + factory, + null, + "index", + 0, + Version.CURRENT + ); + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + Explanation explanation = weight.explain(leafReaderContext, 0); + assertNotNull(explanation); + assertThat(explanation.getDescription(), equalTo("this explains the score")); + assertThat(explanation.getValue(), equalTo(1.0)); + + assertThat(explanation.getDetails(), arrayWithSize(1)); + assertThat(explanation.getDetails()[0].getDescription(), equalTo("*:* (_name: query1)")); + } + public void testExplainDefault() throws IOException { Script script = new Script("script without setting explanation"); ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> 1.5);