Add unit tests
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
martin-gaievski committed Nov 28, 2023
1 parent cc3bef8 commit 233a00f
Showing 4 changed files with 237 additions and 13 deletions.
@@ -148,11 +148,6 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

- public void addSubQuery(final Query query) {
- Objects.requireNonNull(subQueries, "collection of queries must not be null");
- subQueries.add(query);
- }

/**
* Create the Weight used to score this query
*
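Context for the hunk above: with the mutating addSubQuery helper removed, a HybridQuery must receive its complete set of sub-queries at construction time. A minimal sketch of the resulting usage, assuming the existing constructor accepts the sub-query collection (that constructor is not shown in this diff):

import java.util.List;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

import org.opensearch.neuralsearch.query.HybridQuery;

final class HybridQueryConstructionSketch {
    static HybridQuery build() {
        // assumption: HybridQuery takes its sub queries up front, so with
        // addSubQuery gone the list must be complete here
        List<Query> subQueries = List.of(
            new TermQuery(new Term("test-text-field-1", "hello")),
            new TermQuery(new Term("test-text-field-1", "world"))
        );
        return new HybridQuery(subQueries);
    }
}

getSubQueries() still exposes the collection, but only as an unmodifiable view, so the query is effectively immutable once built.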
@@ -71,7 +71,7 @@ public boolean searchWith(
query = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
- validateHybridQuery(query);
+ validateQuery(query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

@@ -112,7 +112,7 @@ && mightBeWrappedHybridQuery(query)
return query;
}

- private void validateHybridQuery(final Query query) {
+ private void validateQuery(final Query query) {
if (query instanceof BooleanQuery) {
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
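For reference, the body of the renamed validator is cut off above. A plausible sketch of the full check, inferred from the unit tests added below (they assert the exception message; the one legitimate wrapping, a bool query built around a hybrid query for an index with nested fields, is unwrapped by extractHybridQuery before validation runs). The recursion here is an assumption, not the committed code:

import java.util.List;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;

import org.opensearch.neuralsearch.query.HybridQuery;

final class ValidateQuerySketch {
    static void validateQuery(final Query query) {
        if (query instanceof BooleanQuery) {
            List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
            for (BooleanClause booleanClause : booleanClauses) {
                // a hybrid query hidden inside a boolean clause is illegal;
                // the message matches the assertion in
                // testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail
                if (booleanClause.getQuery() instanceof HybridQuery) {
                    throw new IllegalArgumentException(
                        "hybrid query must be a top level query and cannot be wrapped into other queries"
                    );
                }
                // assumed: boolean queries nested deeper are validated recursively
                validateQuery(booleanClause.getQuery());
            }
        }
    }
}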
@@ -51,6 +51,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1";
private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2";
private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
+ private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user";

private static final int TEST_DIMENSION = 768;
private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2;
@@ -344,7 +345,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
indexName,
buildIndexConfiguration(
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
List.of("user"),
List.of(TEST_NESTED_TYPE_FIELD_NAME_1),
1
),
""
@@ -5,8 +5,10 @@

package org.opensearch.neuralsearch.search.query;

+ import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
+ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
@@ -21,6 +23,7 @@
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
+ import java.util.Set;

import lombok.SneakyThrows;

@@ -30,6 +33,8 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
+ import org.apache.lucene.search.BooleanClause;
+ import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
@@ -38,11 +43,13 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.action.OriginalIndices;
+ import org.opensearch.common.lucene.search.Queries;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
+ import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
@@ -83,7 +90,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
- TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+ MapperService mapperService = createMapperService();
+ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
@@ -126,6 +134,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
+ when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
@@ -152,7 +161,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
- TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+ MapperService mapperService = createMapperService();
+ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
@@ -196,6 +206,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.queryResult()).thenReturn(new QuerySearchResult());
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
+ when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
@@ -205,8 +216,6 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()

Query query = termSubQuery.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
- MapperService mapperService = mock(MapperService.class);
- when(searchContext.mapperService()).thenReturn(mapperService);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

@@ -315,7 +324,8 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
- TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
+ MapperService mapperService = createMapperService();
+ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
@@ -365,6 +375,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
+ when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
@@ -409,6 +420,223 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
MapperService mapperService = mock(MapperService.class);
when(mapperService.hasNested()).thenReturn(false);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));
TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);

BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3);

Query query = boolQueryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> hybridQueryPhaseSearcher.searchWith(
searchContext,
contextIndexSearcher,
query,
collectors,
hasFilterCollector,
hasTimeout
)
);

org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
containsString("hybrid query must be a top level query and cannot be wrapped into other queries")
);

releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);

MapperService mapperService = createMapperService(mapping(b -> {
b.startObject("field");
b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject();
b.endObject();
b.startObject("field2");
b.field("type", "nested");
b.endObject();
}));

TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
when(mockQueryShardContext.getMapperService()).thenReturn(mapperService);
when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME));

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
int docId2 = RandomizedTest.randomInt();
int docId3 = RandomizedTest.randomInt();
int docId4 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));

BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD)
.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER);
Query query = builder.build();

when(searchContext.query()).thenReturn(query);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

assertNotNull(querySearchResult.topDocs());
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
TopDocs topDocs = topDocsAndMaxScore.topDocs;
assertTrue(topDocs.totalHits.value > 0);
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
assertNotNull(scoreDocs);
assertTrue(scoreDocs.length > 0);
assertTrue(isHybridQueryStartStopElement(scoreDocs[0]));
assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1]));
List<TopDocs> compoundTopDocs = getSubQueryResultsForSingleShard(topDocs);
assertNotNull(compoundTopDocs);
assertEquals(2, compoundTopDocs.size());

TopDocs subQueryTopDocs1 = compoundTopDocs.get(0);
List<Integer> expectedIds1 = List.of(docId1);
assertQueryResults(subQueryTopDocs1, expectedIds1, reader);

TopDocs subQueryTopDocs2 = compoundTopDocs.get(1);
List<Integer> expectedIds2 = List.of();
assertQueryResults(subQueryTopDocs2, expectedIds2, reader);

releaseResources(directory, w, reader);
}
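A note on the closing assertions: the hybrid searcher reports one flat ScoreDoc array per shard, framed by the sentinel elements that isHybridQueryStartStopElement checks for, and getSubQueryResultsForSingleShard splits that array back into one TopDocs per sub-query. That is why exactly two compound entries are expected here: one per term sub-query, with only the first sub-query matching a document.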

@SneakyThrows
private void assertQueryResults(TopDocs subQueryTopDocs, List<Integer> expectedDocIds, IndexReader reader) {
assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);