Skip to content

Commit

Permalink
Fixed nested field case
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Nov 28, 2023
1 parent b3c73bd commit 31462ee
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

/**
 * Appends one more sub-query to the collection of sub-queries held by this query.
 *
 * @param query sub-query to append
 * @throws NullPointerException if the internal collection of sub-queries is null
 */
public void addSubQuery(final Query query) {
    Objects.requireNonNull(subQueries, "collection of queries must not be null").add(query);
}

/**
* Create the Weight used to score this query
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
import lombok.extern.log4j.Log4j2;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.index.mapper.SeqNoFieldMapper;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
Expand All @@ -48,24 +53,88 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

final static int MAX_NESTED_SUBQUERY_LIMIT = 20;

/**
 * Creates the searcher; all initialization is delegated to the parent class constructor.
 */
public HybridQueryPhaseSearcher() {
super();
}

/**
 * Runs the query phase, routing hybrid queries to the hybrid collector path.
 * When isHybridQuery recognizes the query, the hybrid query is unwrapped with extractHybridQuery and executed through
 * searchWithCollector. Any other query is first checked by validateHybridQuery, which rejects hybrid queries nested
 * inside boolean queries, and is then handed to the parent implementation.
 *
 * @return the result of searchWithCollector for hybrid queries, otherwise the result of the parent searchWith
 * @throws IOException declared by this method; presumably propagated from the underlying search calls
 */
public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
// NOTE(review): diff view - the next line looks like the pre-change parameter, replaced by the non-final one after it
final Query query,
Query query,
final LinkedList<QueryCollectorContext> collectors,
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
// NOTE(review): diff view - the next line looks like the pre-change condition, replaced by the isHybridQuery check after it
if (query instanceof HybridQuery) {
if (isHybridQuery(query, searchContext)) {
// the parameter is reassigned here, which is why it is no longer declared final
query = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
// non-hybrid path: fail fast if a hybrid query is hidden inside a boolean query
validateHybridQuery(query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

/**
 * Decides whether the query should be executed through the hybrid search path.
 * A query qualifies when it is a {@link HybridQuery} itself, or when nested fields/docs are involved and it is a
 * {@link BooleanQuery} holding a hybrid query whose every other clause is a FILTER clause with a
 * {@link FieldExistsQuery} on the primary term field.
 *
 * @param query query to inspect
 * @param searchContext context used to look up nested mapping information
 * @return true if the query is a hybrid query or an accepted wrapper around one
 */
private boolean isHybridQuery(final Query query, final SearchContext searchContext) {
    if (query instanceof HybridQuery) {
        return true;
    }
    if (!hasNestedFieldOrNestedDocs(query, searchContext) || !mightBeWrappedHybridQuery(query)) {
        return false;
    }
    // mightBeWrappedHybridQuery has already established that the query is a BooleanQuery
    for (BooleanClause clause : ((BooleanQuery) query).clauses()) {
        final Query clauseQuery = clause.getQuery();
        if (clauseQuery instanceof HybridQuery) {
            continue;
        }
        // presumably this is the filter the search layer adds for indices with nested mappings - confirm
        final boolean isPrimaryTermFilter = clause.getOccur() == BooleanClause.Occur.FILTER
            && clauseQuery instanceof FieldExistsQuery
            && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clauseQuery).getField());
        if (!isPrimaryTermFilter) {
            return false;
        }
    }
    return true;
}

/**
 * Checks whether the index mapping has nested fields and the query might match nested documents.
 *
 * @param query query to test against nested documents
 * @param searchContext context that provides the mapper service
 * @return true only if the mapping has nested fields and the query might match nested docs
 */
private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
    if (!searchContext.mapperService().hasNested()) {
        return false;
    }
    final NestedHelper nestedHelper = new NestedHelper(searchContext.mapperService());
    return nestedHelper.mightMatchNestedDocs(query);
}

/**
 * Checks whether the query is a {@link BooleanQuery} with at least one clause holding a {@link HybridQuery}.
 *
 * @param query query to inspect
 * @return true if the query is a boolean query that directly contains a hybrid query clause
 */
private boolean mightBeWrappedHybridQuery(final Query query) {
    if (!(query instanceof BooleanQuery)) {
        return false;
    }
    for (BooleanClause clause : ((BooleanQuery) query).clauses()) {
        if (clause.getQuery() instanceof HybridQuery) {
            return true;
        }
    }
    return false;
}

/**
 * Unwraps the hybrid query from a boolean query wrapper.
 * When nested fields/docs are involved and the query is a {@link BooleanQuery} holding a {@link HybridQuery},
 * the hybrid query clause is returned; in every other case the query is returned unchanged.
 * The hybrid clause is selected by type rather than by position, so a wrapper whose first clause is the
 * filter clause does not yield the filter query.
 *
 * @param searchContext context used to look up nested mapping information
 * @param query query that may be a boolean wrapper around a hybrid query
 * @return the wrapped hybrid query if one is found, otherwise the original query
 */
private Query extractHybridQuery(final SearchContext searchContext, final Query query) {
    if (hasNestedFieldOrNestedDocs(query, searchContext) && mightBeWrappedHybridQuery(query)) {
        // mightBeWrappedHybridQuery guarantees a BooleanQuery with at least one HybridQuery clause,
        // so no separate emptiness check is needed; orElse keeps the method total regardless
        return ((BooleanQuery) query).clauses()
            .stream()
            .map(BooleanClause::getQuery)
            .filter(clauseQuery -> clauseQuery instanceof HybridQuery)
            .findFirst()
            .orElse(query);
    }
    return query;
}

/**
 * Validates a non-hybrid top level query: if it is a {@link BooleanQuery}, each clause is checked by
 * validateNestedBooleanQuery starting at nesting level 1. Other query types are not inspected.
 *
 * @param query top level query to validate
 */
private void validateHybridQuery(final Query query) {
    if (!(query instanceof BooleanQuery)) {
        return;
    }
    ((BooleanQuery) query).clauses().forEach(clause -> validateNestedBooleanQuery(clause.getQuery(), 1));
}

/**
 * Recursively walks the clauses of boolean queries and fails if a hybrid query is found among them.
 *
 * @param query query at the current nesting level
 * @param level current nesting depth, compared against MAX_NESTED_SUBQUERY_LIMIT
 * @throws IllegalArgumentException if the query is a {@link HybridQuery}
 * @throws IllegalStateException if the nesting depth reaches MAX_NESTED_SUBQUERY_LIMIT
 */
private void validateNestedBooleanQuery(final Query query, int level) {
    // the hybrid check comes first so a hybrid query at the depth limit is still reported as a wrapping error
    if (query instanceof HybridQuery) {
        throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
    }
    if (level >= MAX_NESTED_SUBQUERY_LIMIT) {
        throw new IllegalStateException("reached max nested query limit, cannot process query");
    }
    if (!(query instanceof BooleanQuery)) {
        return;
    }
    final int childLevel = level + 1;
    ((BooleanQuery) query).clauses().forEach(clause -> validateNestedBooleanQuery(clause.getQuery(), childLevel));
}

@VisibleForTesting
protected boolean searchWithCollector(
final SearchContext searchContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,16 @@ protected boolean checkComplete(Map<String, Object> node) {
}

@SneakyThrows
private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int numberOfShards) {
protected String buildIndexConfiguration(final List<KNNFieldConfig> knnFieldConfigs, final int numberOfShards) {
// Convenience overload: delegates to the three-argument variant with an empty list of nested fields.
return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards);
}

@SneakyThrows
protected String buildIndexConfiguration(
final List<KNNFieldConfig> knnFieldConfigs,
final List<String> nestedFields,
final int numberOfShards
) {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("settings")
Expand All @@ -544,6 +553,11 @@ private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs, int
.endObject()
.endObject();
}

for (String nestedField : nestedFields) {
xContentBuilder.startObject(nestedField).field("type", "nested").endObject();
}

xContentBuilder.endObject().endObject().endObject();
return xContentBuilder.toString();
}
Expand Down
80 changes: 74 additions & 6 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.query;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;

Expand All @@ -21,6 +23,7 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand All @@ -35,6 +38,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD =
"test-neural-multi-doc-nested-type--single-shard-index";
private static final String TEST_QUERY_TEXT = "greetings";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final String TEST_QUERY_TEXT3 = "hello";
Expand Down Expand Up @@ -191,7 +196,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult(
}

@SneakyThrows
public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() {
public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenFail() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);

MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand All @@ -202,23 +207,71 @@ public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess()
MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder);

ResponseException exceptionNoNestedTypes = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, boolQueryBuilder, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE))
);

org.hamcrest.MatcherAssert.assertThat(
exceptionNoNestedTypes.getMessage(),
allOf(
containsString("hybrid query must be a top level query and cannot be wrapped into other queries"),
containsString("illegal_argument_exception")
)
);

initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD);

ResponseException exceptionQWithNestedTypes = expectThrows(
ResponseException.class,
() -> search(
TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD,
boolQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
)
);

org.hamcrest.MatcherAssert.assertThat(
exceptionQWithNestedTypes.getMessage(),
allOf(
containsString("hybrid query must be a top level query and cannot be wrapped into other queries"),
containsString("illegal_argument_exception")
)
);
}

// Runs a top level hybrid query made of two term sub-queries against the index that is created with a nested field
// and checks the hit count, max score and total hits of the response.
// NOTE(review): this span is a diff view; pre- and post-change argument and assertion lines appear interleaved
// below (e.g. both index names, both hit-count assertions) - confirm against the post-change file.
@SneakyThrows
public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD);

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT);
TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2);
HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
hybridQueryBuilderOnlyTerm.add(termQueryBuilder);
hybridQueryBuilderOnlyTerm.add(termQuery2Builder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
boolQueryBuilder,
TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD,
hybridQueryBuilderOnlyTerm,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertTrue(getHitCount(searchResponseAsMap) > 0);
assertEquals(0, getHitCount(searchResponseAsMap));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f);
assertEquals(0.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION);

Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertTrue((int) total.get("value") > 0);
assertEquals(0, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -284,6 +337,21 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
);
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
}

if (TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD.equals(indexName)
&& !indexExists(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD)) {
createIndexWithConfiguration(
indexName,
buildIndexConfiguration(
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
List.of("user"),
1
),
""
);

addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD);
}
}

private void addDocsToIndex(final String testMultiDocIndexName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -204,6 +205,8 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()

Query query = termSubQuery.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
MapperService mapperService = mock(MapperService.class);
when(searchContext.mapperService()).thenReturn(mapperService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

Expand All @@ -217,7 +220,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
Expand Down Expand Up @@ -265,6 +269,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down

0 comments on commit 31462ee

Please sign in to comment.