diff --git a/CHANGELOG.md b/CHANGELOG.md index 5290345db..6b2dfbcf5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x) ### Features ### Enhancements +- Adds rescore parameter support ([#885](https://github.com/opensearch-project/neural-search/pull/885)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 845396dd0..fe69c577e 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -18,6 +18,8 @@ import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; + +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -69,8 +71,10 @@ private void validateNormalizationProcessor(final String fileName, final String modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR); loadModel(modelId); addDocuments(getIndexNameForTest(), false); - validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName); - validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100)); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); + validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder); + hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault()); + validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder); } finally { wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName); } @@ -98,15 +102,10 @@ private void createSearchPipeline(final String pipelineName) { ); } - private void validateTestIndex(final String modelId, final String index, final String searchPipeline) { - validateTestIndex(modelId, index, searchPipeline, null); - } - - private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map methodParameters) { + private void validateTestIndex(final String index, final String searchPipeline, HybridQueryBuilder queryBuilder) { int docCount = getDocCount(index); assertEquals(6, docCount); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters); - Map searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline)); + Map searchResponseAsMap = search(index, queryBuilder, null, 1, Map.of("search_pipeline", searchPipeline)); assertNotNull(searchResponseAsMap); int hits = getHitCount(searchResponseAsMap); assertEquals(1, hits); @@ -116,7 +115,7 @@ private void validateTestIndex(final String modelId, final String index, final S } } - private HybridQueryBuilder getQueryBuilder(final String modelId, Map methodParameters) { + private HybridQueryBuilder getQueryBuilder(final String modelId, Map methodParameters, RescoreContext rescoreContext) { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); neuralQueryBuilder.fieldName("passage_embedding"); neuralQueryBuilder.modelId(modelId); @@ -125,6 +124,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, Map if (methodParameters != null) { neuralQueryBuilder.methodParameters(methodParameters); } + if (rescoreContext != null) { + neuralQueryBuilder.rescoreContext(rescoreContext); + } MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY); diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index ece2bbb9e..d0994e711 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -61,6 +61,7 @@ private void validateIndexQuery(final String modelId) { 0.01f, null, null, + null, null ); Map responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); @@ -76,6 +77,7 @@ private void validateIndexQuery(final String modelId) { null, null, null, + null, null ); Map responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index 54d993b35..f35227041 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -63,6 +63,7 @@ private void validateTestIndex(final String modelId) throws Exception { null, null, null, + null, null ); Map response = search(getIndexNameForTest(), neuralQueryBuilder, 1); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index ba2ff7979..eeae7f7dd 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -16,6 +16,8 @@ import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; + +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -59,11 +61,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr int totalDocsCountMixed; if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder); addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); + validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder); } break; case UPGRADED: @@ -72,8 +76,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND; loadModel(modelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100)); + HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder); + hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault()); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder); } finally { wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); } @@ -83,16 +89,11 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr } } - private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception { - validateTestIndexOnUpgrade(numberOfDocs, modelId, null); - } - - private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map methodParameters) + private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder) throws Exception { int docCount = getDocCount(getIndexNameForTest()); assertEquals(numberOfDocs, docCount); loadModel(modelId); - HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters); Map searchResponseAsMap = search( getIndexNameForTest(), hybridQueryBuilder, @@ -109,7 +110,11 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod } } - private HybridQueryBuilder getQueryBuilder(final String modelId, final Map methodParameters) { + private HybridQueryBuilder getQueryBuilder( + final String modelId, + final Map methodParameters, + final RescoreContext rescoreContext + ) { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); neuralQueryBuilder.fieldName("passage_embedding"); neuralQueryBuilder.modelId(modelId); @@ -118,6 +123,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, final Map responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1); @@ -102,6 +103,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo null, null, null, + null, null ); Map responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index 8e0ff7568..72976770d 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -86,6 +86,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod null, null, null, + null, null ); Map responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 8e1b6b36b..915a79117 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -8,6 +8,7 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion; import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport; import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch; @@ -40,6 +41,8 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.query.parser.MethodParametersParser; +import org.opensearch.knn.index.query.parser.RescoreParser; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.common.MinClusterVersionUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -101,6 +104,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private Supplier vectorSupplier; private QueryBuilder filter; private Map methodParameters; + private RescoreContext rescoreContext; /** * Constructor from stream input @@ -131,6 +135,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) { this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion); } + this.rescoreContext = RescoreParser.streamInput(in); } @Override @@ -156,6 +161,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) { MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion); } + RescoreParser.streamOutput(out, rescoreContext); } @Override @@ -181,6 +187,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws if (Objects.nonNull(methodParameters)) { MethodParametersParser.doXContent(xContentBuilder, methodParameters); } + if (Objects.nonNull(rescoreContext)) { + RescoreParser.doXContent(xContentBuilder, rescoreContext); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -276,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n neuralQueryBuilder.filter(parseInnerQueryBuilder(parser)); } else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser)); + } else if (RESCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.rescoreContext(RescoreParser.fromXContent(parser)); } } else { throw new ParsingException( @@ -308,6 +319,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { .maxDistance(maxDistance) .minScore(minScore) .k(k) + .methodParameters(methodParameters) + .rescoreContext(rescoreContext) .build(); } @@ -335,7 +348,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { minScore(), vectorSetOnce::get, filter(), - methodParameters() + methodParameters(), + rescoreContext() ); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 7477fe63b..0fe9a77d6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -97,6 +97,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -148,6 +149,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -188,6 +190,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() { null, null, null, + null, null ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index ad2460103..4ddf8d2c3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -249,7 +249,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -299,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); hybridQueryBuilderDefaultNorm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -324,7 +324,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); hybridQueryBuilderL2Norm.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 7700c9f6a..d851131f1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -85,7 +85,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -110,7 +110,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -135,7 +135,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -185,7 +185,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); hybridQueryBuilderArithmeticMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -210,7 +210,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); hybridQueryBuilderHarmonicMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); @@ -235,7 +235,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); hybridQueryBuilderGeometricMean.add( - new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null) + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null) ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 638a34a3c..8fd87f091 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -124,6 +124,7 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws null, null, null, + null, null ); QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 9ecb93b81..e0c51e106 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -14,6 +14,8 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD; import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; @@ -52,6 +54,7 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; @@ -136,10 +139,52 @@ public void testFromXContent_withMethodParameters_thenBuildSuccessfully() { contentParser.nextToken(); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + assertEquals(Map.of("ef_search", 1000), neuralQueryBuilder.methodParameters()); assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(K, neuralQueryBuilder.k()); + assertNull(neuralQueryBuilder.rescoreContext()); + } + + @SneakyThrows + public void testFromXContent_withRescoreContext_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_image": "string", + "model_id": "string", + "k": int, + "rescore": { + "oversample_factor" : int + } + } + } + */ + setUpClusterService(Version.V_2_10_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .startObject(RESCORE_FIELD.getPreferredName()) + .field(RESCORE_OVERSAMPLE_FIELD.getPreferredName(), 1) + .endObject() + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(K, neuralQueryBuilder.k()); + assertEquals(RescoreContext.getDefault(), neuralQueryBuilder.rescoreContext()); + assertNull(neuralQueryBuilder.methodParameters()); } @SneakyThrows @@ -679,13 +724,20 @@ public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) + .methodParameters(Map.of("ef_search", 100)) + .rescoreContext(RescoreContext.getDefault()) .vectorSupplier(TEST_VECTOR_SUPPLIER); + + KNNQueryBuilder expected = KNNQueryBuilder.builder() + .k(K) + .fieldName(neuralQueryBuilder.fieldName()) + .methodParameters(neuralQueryBuilder.methodParameters()) + .rescoreContext(neuralQueryBuilder.rescoreContext()) + .vector(TEST_VECTOR_SUPPLIER.get()) + .build(); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); - assertTrue(queryBuilder instanceof KNNQueryBuilder); - KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; - assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK()); - assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); + assertEquals(expected, queryBuilder); } public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 0e5d86e72..210abd7ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -18,6 +18,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.primitives.Floats; @@ -111,6 +112,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, null, null, + null, null ); @@ -133,7 +135,8 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, null, null, - Map.of("ef_search", 10) + Map.of("ef_search", 10), + RescoreContext.getDefault() ); Map searchResponseAsMapMultimodalQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderMultimodalQuery, 1); Map firstInnerHitMultimodalQuery = getFirstInnerHit(searchResponseAsMapMultimodalQuery); @@ -160,6 +163,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { null, null, null, + null, null ); @@ -189,6 +193,7 @@ public void testQueryWithBoostAndImageQueryAndRadialQuery() { 0.01f, null, null, + null, null ); @@ -244,6 +249,7 @@ public void testRescoreQuery() { null, null, null, + null, null ); @@ -322,6 +328,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( @@ -334,6 +341,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); @@ -362,6 +370,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null, null, null, + null, null ); @@ -418,6 +427,7 @@ public void testNestedQuery() { null, null, null, + null, null ); @@ -469,6 +479,7 @@ public void testFilterQuery() { null, null, new MatchQueryBuilder("_id", "3"), + null, null ); Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, neuralQueryBuilder, 3);