Skip to content

Commit

Permalink
[Backport] manually backport 1013 to 2.x (#1028)
Browse files Browse the repository at this point in the history
* Support of new k-NN query parameter expand_nested.

Signed-off-by: Bo Zhang <bzhangam@amazon.com>
(cherry picked from commit fa149d4)

* Remove mistakenly added code from HybridSearchIT.

Signed-off-by: Bo Zhang <bzhangam@amazon.com>
(cherry picked from commit eaa7779)
  • Loading branch information
bzhangam authored Dec 18, 2024
1 parent 6481b60 commit 0cffa25
Show file tree
Hide file tree
Showing 14 changed files with 230 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import java.util.Map;

import org.opensearch.index.query.MatchQueryBuilder;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand Down Expand Up @@ -71,9 +74,9 @@ private void validateNormalizationProcessor(final String fileName, final String
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
Expand Down Expand Up @@ -115,12 +118,20 @@ private void validateTestIndex(final String index, final String searchPipeline,
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters, RescoreContext rescoreContext) {
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -78,6 +79,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.opensearch.index.query.MatchQueryBuilder;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand All @@ -17,6 +19,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
Expand All @@ -31,6 +35,8 @@ public class HybridSearchIT extends AbstractRollingUpgradeTestCase {
private static final String TEXT_UPGRADED = "Hi earth";
private static final String QUERY = "Hi world";
private static final int NUM_DOCS_PER_ROUND = 1;
private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding";
protected static final String RESCORE_QUERY = "hi";
private static String modelId = "";

// Test rolling-upgrade normalization processor when index with multiple shards
Expand Down Expand Up @@ -61,13 +67,14 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
case UPGRADED:
Expand All @@ -76,10 +83,11 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
hybridQueryBuilder = getQueryBuilder(modelId, Boolean.FALSE, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -89,15 +97,19 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder)
throws Exception {
private void validateTestIndexOnUpgrade(
final int numberOfDocs,
final String modelId,
HybridQueryBuilder hybridQueryBuilder,
QueryBuilder rescorer
) throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
null,
rescorer,
1,
Map.of("search_pipeline", SEARCH_PIPELINE_NAME)
);
Expand All @@ -112,19 +124,23 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod

private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Boolean expandNestedDocs,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
final RescoreContext rescoreContextForNeuralQuery
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.fieldName(VECTOR_EMBEDDING_FIELD);
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (expandNestedDocs != null) {
neuralQueryBuilder.expandNested(expandNestedDocs);
}
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
if (Objects.nonNull(rescoreContextForNeuralQuery)) {
neuralQueryBuilder.rescoreContext(rescoreContextForNeuralQuery);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -104,6 +105,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
Expand Down Expand Up @@ -98,6 +99,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private Integer k = null;
private Float maxDistance = null;
private Float minScore = null;
private Boolean expandNested;
@VisibleForTesting
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
Expand Down Expand Up @@ -132,6 +134,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.maxDistance = in.readOptionalFloat();
this.minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
this.expandNested = in.readOptionalBoolean();
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -158,6 +163,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(this.maxDistance);
out.writeOptionalFloat(this.minScore);
}
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
out.writeOptionalBoolean(this.expandNested);
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -184,6 +192,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(minScore)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (Objects.nonNull(expandNested)) {
xContentBuilder.field(EXPAND_NESTED_FIELD.getPreferredName(), expandNested);
}
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
Expand Down Expand Up @@ -274,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.maxDistance(parser.floatValue());
} else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.minScore(parser.floatValue());
} else if (EXPAND_NESTED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.expandNested(parser.booleanValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -318,6 +331,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
Expand Down Expand Up @@ -346,6 +360,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
k(),
maxDistance(),
minScore(),
expandNested(),
vectorSetOnce::get,
filter(),
methodParameters(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -150,6 +151,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -191,6 +193,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -249,7 +262,20 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down Expand Up @@ -299,7 +325,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -324,7 +363,20 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_DOC_TEXT1,
"",
modelId,
5,
null,
null,
null,
null,
null,
null,
null
)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down
Loading

0 comments on commit 0cffa25

Please sign in to comment.