Skip to content

Commit

Permalink
Add some IT, optimized the get tokens logic in NeuralSparseQueryBuilder.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <congguan@amazon.com>
  • Loading branch information
conggguan committed Apr 24, 2024
1 parent 31c349e commit feb9606
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 48 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.13...2.x)
### Features
- Support k-NN radial search parameters in neural search([#697](https://github.com/opensearch-project/neural-search/pull/697))
- Enhance neural_sparse query's latency performance with two-phase rescore query([#695](https://github.com/opensearch-project/neural-search/pull/695/files)).
- Enhance neural_sparse query's latency performance with two-phase rescore query([#695](https://github.com/opensearch-project/neural-search/pull/695)).
### Enhancements
- BWC tests for text chunking processor ([#661](https://github.com/opensearch-project/neural-search/pull/661))
- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,25 +340,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
protected Query doToQuery(QueryShardContext context) throws IOException {
final MappedFieldType ft = context.fieldMapper(fieldName);
validateFieldType(ft);
Map<String, Float> allTokens = getAllTokens();
Query allTokenQuery = buildFeatureFieldQueryFromTokens(allTokens, fieldName);
if (!NeuralSparseTwoPhaseParameters.isEnabled(neuralSparseTwoPhaseParameters)) {
return buildFeatureFieldQueryFromTokens(getAllTokens(), fieldName);
return allTokenQuery;
}
// in the last step we make sure neuralSparseTwoPhaseParameters is not null
float ratio = neuralSparseTwoPhaseParameters.pruning_ratio();
Map<String, Float> highScoreTokens = getHighScoreTokens(ratio);
Map<String, Float> lowScoreTokens = getLowScoreTokens(ratio);
Map<String, Float> allTokens = getAllTokens();
Query allTokenQuery = buildFeatureFieldQueryFromTokens(allTokens, fieldName);
Map<String, Float> highScoreTokens = getHighScoreTokens(allTokens, ratio);
Map<String, Float> lowScoreTokens = getLowScoreTokens(allTokens, ratio);
// if all token are valid score that we don't need the two-phase optimize, return allTokenQuery.
if (lowScoreTokens.isEmpty()) {
return allTokenQuery;
}
Query highScoreTokenQuery = buildFeatureFieldQueryFromTokens(highScoreTokens, fieldName);
Query lowScoreTokenQuery = buildFeatureFieldQueryFromTokens(lowScoreTokens, fieldName);
return new NeuralSparseQuery(
allTokenQuery,
highScoreTokenQuery,
lowScoreTokenQuery,
buildFeatureFieldQueryFromTokens(highScoreTokens, fieldName),
buildFeatureFieldQueryFromTokens(lowScoreTokens, fieldName),
neuralSparseTwoPhaseParameters.window_size_expansion()
);
}
Expand Down Expand Up @@ -439,17 +437,15 @@ private Map<String, Float> getAllTokens() {
return queryTokens;
}

private Map<String, Float> getHighScoreTokens(float ratio) {
return getFilteredScoreTokens(true, ratio);
private Map<String, Float> getHighScoreTokens(Map<String, Float> queryTokens, float ratio) {
return getFilteredScoreTokens(queryTokens, true, ratio);
}

private Map<String, Float> getLowScoreTokens(float ratio) {
return getFilteredScoreTokens(false, ratio);
private Map<String, Float> getLowScoreTokens(Map<String, Float> queryTokens, float ratio) {
return getFilteredScoreTokens(queryTokens, false, ratio);
}

private Map<String, Float> getFilteredScoreTokens(boolean aboveThreshold, float ratio) {
Map<String, Float> queryTokens = queryTokensSupplier.get();
validateQueryTokens(queryTokens);
private Map<String, Float> getFilteredScoreTokens(Map<String, Float> queryTokens, boolean aboveThreshold, float ratio) {
float max = queryTokens.values().stream().max(Float::compare).orElse(0f);
float threshold = ratio * max;
if (max == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {

private static final String FIELD_NAME = "testField";
private static final String QUERY_TEXT = "Hello world!";
private static final String QUERY_TEXT_LONG_VERSION =
"The ID of the sparse encoding model or tokenizer model that will be used to generate vector embeddings from the query text. The model must be deployed in OpenSearch before it can be used in sparse neural search. For more information, see Using custom models within OpenSearch and Neural sparse search.";
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final float BOOST = 1.8f;
private static final String QUERY_NAME = "queryName";
Expand Down Expand Up @@ -505,9 +503,9 @@ public void testFromXContent_whenBuiltWithEmptyTwoPhaseParams_thenThrowException
"query_text": "string",
"model_id": "string",
"two_phase_settings":{
"window_size_expansion": 5,
"pruning_ratio": 0.4,
"enabled": false
"window_size_expansion": null,
"pruning_ratio": null,
"enabled": null
}
}
}
Expand Down Expand Up @@ -996,30 +994,26 @@ public void testTokenDividedByScores_whenDefaultSettings() {
@SneakyThrows
public void testDoToQuery_whenTwoPhaseParaDisabled_thenDegradeSuccess() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.maxTokenScore(MAX_TOKEN_SCORE)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER)
.neuralSparseTwoPhaseParameters(
new NeuralSparseTwoPhaseParameters().enabled(false).pruning_ratio(0.4f).window_size_expansion(6.0f)
new NeuralSparseTwoPhaseParameters().enabled(false).pruning_ratio(0.7f).window_size_expansion(6.0f)
);
Query query = sparseEncodingQueryBuilder.doToQuery(mockQueryShardContext);
assertTrue(query instanceof BooleanQuery);
List<BooleanClause> booleanClauseList = ((BooleanQuery) query).clauses();
assertEquals(2, ((BooleanQuery) query).clauses().size());
BooleanClause firstClause = booleanClauseList.get(0);
BooleanClause secondClause = booleanClauseList.get(1);

Query firstFeatureQuery = firstClause.getQuery();
assertEquals(firstFeatureQuery, FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f));
Query secondFeatureQuery = secondClause.getQuery();
assertEquals(secondFeatureQuery, FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f));
List<Query> actualQueries = booleanClauseList.stream().map(BooleanClause::getQuery).collect(Collectors.toList());
Query expectedQuery1 = FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f);
Query expectedQuery2 = FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f);
assertTrue("Expected query for 'world' not found", actualQueries.contains(expectedQuery1));
assertTrue("Expected query for 'hello' not found", actualQueries.contains(expectedQuery2));
}

@SneakyThrows
public void testDoToQuery_whenTwoPhaseParaEmpty_thenDegradeSuccess() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.maxTokenScore(MAX_TOKEN_SCORE)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
Expand All @@ -1036,6 +1030,29 @@ public void testDoToQuery_whenTwoPhaseParaEmpty_thenDegradeSuccess() {
assertEquals(secondFeatureQuery, FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f));
}

@SneakyThrows
public void testDoToQuery_whenTwoPhaseEnabled_thenBuildCorrectQuery() {
Map<String, Float> map = new HashMap<>();
for (int i = 1; i < 3; i++) {
map.put(String.valueOf(i), (float) i);
}
final Supplier<Map<String, Float>> tokenSupplier = () -> map;
// token with score [1.0,2.0] will build degrade to allTokenQuery
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(tokenSupplier)
.neuralSparseTwoPhaseParameters(NeuralSparseTwoPhaseParameters.getDefaultSettings());
Query allTokenQuery = sparseEncodingQueryBuilder.doToQuery(mockQueryShardContext);
assertTrue(allTokenQuery instanceof BooleanQuery);
assertEquals(((BooleanQuery) allTokenQuery).clauses().size(), 2);
map.put("Temp", 9.f);
// token with score [1.0,2.0,9.0] will build a NeuralSparseQuery whose lowTokenQuery including [1.0,2.0]
Query query = sparseEncodingQueryBuilder.doToQuery(mockQueryShardContext);
assertTrue(query instanceof NeuralSparseQuery);
assertEquals(((NeuralSparseQuery) query).getLowScoreTokenQuery(), allTokenQuery);
}

@SneakyThrows
public void testDoToQuery_successfulDoToQuery() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
Expand Down
Loading

0 comments on commit feb9606

Please sign in to comment.