Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fixed Hybrid query for cases when it's wrapped into other compound queries (#498) #501

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)
### Infrastructure
### Documentation
### Maintenance
Expand All @@ -17,6 +16,8 @@ Fixed exception for case when Hybrid query being wrapped into bool query ([#490]
### Features
### Enhancements
### Bug Fixes
- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.SeqNoFieldMapper;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
Expand Down Expand Up @@ -60,12 +67,120 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
if (isHybridQuery(query, searchContext)) {
Query hybridQuery = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

private boolean isHybridQuery(final Query query, final SearchContext searchContext) {
if (query instanceof HybridQuery) {
return true;
} else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) {
/* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code
https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370.
main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks
hybrid query for indexes with nested field types.
in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for
this search request.
below is sample structure of such query:

Boolean {
should: {
hybrid: {
sub_query1 {}
sub_query2 {}
}
}
filter: {
exists: {
field: "_primary_term"
}
}
}
TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */
// we have already checked if query in instance of Boolean in higher level else if condition
return ((BooleanQuery) query).clauses()
.stream()
.filter(clause -> clause.getQuery() instanceof HybridQuery == false)
.allMatch(clause -> {
return clause.getOccur() == BooleanClause.Occur.FILTER
&& clause.getQuery() instanceof FieldExistsQuery
&& SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField());
});
}
return false;
}

private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query);
}

private boolean isWrappedHybridQuery(final Query query) {
return query instanceof BooleanQuery
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
}

private Query extractHybridQuery(final SearchContext searchContext, final Query query) {
if (hasNestedFieldOrNestedDocs(query, searchContext)
&& isWrappedHybridQuery(query)
&& ((BooleanQuery) query).clauses().size() > 0) {
// extract hybrid query and replace bool with hybrid query
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) {
throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query");
}
return booleanClauses.get(0).getQuery();
}
return query;
}

/**
* Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases
* when hybrid query is wrapped into other compound queries.
* For example, if we have Bool query like below we need to throw an error
* bool: {
* should: [
* match: {},
* hybrid: {
* sub_query1 {}
* sub_query2 {}
* }
* ]
* }
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
*
* @param query query to validate
*/
private void validateQuery(final SearchContext searchContext, final Query query) {
if (query instanceof BooleanQuery) {
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), getMaxDepthLimit(searchContext));
}
}
}

private void validateNestedBooleanQuery(final Query query, final int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level <= 0) {
// ideally we should throw an error here but this code is on the main search workflow path and that might block
// execution of some queries. Instead, we're silently exit and allow such query to execute and potentially produce incorrect
// results in case hybrid query is wrapped into such bool query
log.error("reached max nested query limit, cannot process bool query with that many nested clauses");
return;
}
if (query instanceof BooleanQuery) {
for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
validateNestedBooleanQuery(booleanClause.getQuery(), level - 1);
}
}
}

@VisibleForTesting
protected boolean searchWithCollector(
final SearchContext searchContext,
Expand Down Expand Up @@ -209,4 +324,9 @@ private float getMaxScore(final List<TopDocs> topDocs) {
private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
return sortAndFormats == null ? null : sortAndFormats.formats;
}

private int getMaxDepthLimit(final SearchContext searchContext) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,18 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList());
}

@SneakyThrows
protected void addKnnDoc(
String index,
String docId,
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
) {
addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList());
}

/**
* Add a set of knn vectors and text to an index
*
Expand All @@ -422,6 +434,8 @@ protected void addKnnDoc(String index, String docId, List<String> vectorFieldNam
* @param vectors List of vectors corresponding to those fields
* @param textFieldNames List of text fields to be added
* @param texts List of text corresponding to those fields
* @param nestedFieldNames List of nested fields to be added
* @param nestedFields List of fields and values corresponding to those fields
*/
@SneakyThrows
protected void addKnnDoc(
Expand All @@ -430,7 +444,9 @@ protected void addKnnDoc(
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
List<String> texts,
List<String> nestedFieldNames,
List<Map<String, String>> nestedFields
) {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
Expand All @@ -441,6 +457,16 @@ protected void addKnnDoc(
for (int i = 0; i < textFieldNames.size(); i++) {
builder.field(textFieldNames.get(i), texts.get(i));
}

for (int i = 0; i < nestedFieldNames.size(); i++) {
builder.field(nestedFieldNames.get(i));
builder.startObject();
Map<String, String> nestedValues = nestedFields.get(i);
for (Map.Entry<String, String> entry : nestedValues.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();

request.setJsonEntity(builder.toString());
Expand Down Expand Up @@ -523,7 +549,16 @@ protected boolean checkComplete(Map<String, Object> node) {
}

@SneakyThrows
private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int numberOfShards) {
protected String buildIndexConfiguration(final List<KNNFieldConfig> knnFieldConfigs, final int numberOfShards) {
return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards);
}

@SneakyThrows
protected String buildIndexConfiguration(
final List<KNNFieldConfig> knnFieldConfigs,
final List<String> nestedFields,
final int numberOfShards
) {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("settings")
Expand All @@ -544,6 +579,11 @@ private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int
.endObject()
.endObject();
}

for (String nestedField : nestedFields) {
xContentBuilder.startObject(nestedField).field("type", "nested").endObject();
}

xContentBuilder.endObject().endObject().endObject();
return xContentBuilder.toString();
}
Expand Down
Loading
Loading