Skip to content

Commit

Permalink
Support for post filter in hybrid query (#633)
Browse files Browse the repository at this point in the history
* Post Filter for hybrid query

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Add changelog

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Addressing martin comments

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Addressing martin comments

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Addressing navneet comments

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Addressing navneet comments

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Addressing navneet comments

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding Coverage

Signed-off-by: Varun Jain <varunudr@amazon.com>

---------

Signed-off-by: Varun Jain <varunudr@amazon.com>
  • Loading branch information
vibrantvarun authored Mar 14, 2024
1 parent 759a971 commit d2d4cc6
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630))
- Support for post filter in hybrid query ([#633](https://github.com/opensearch-project/neural-search/pull/633))
### Bug Fixes
- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
Expand Down Expand Up @@ -46,6 +52,9 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
@Nullable
private final Weight filterWeight;
private static final float boost_factor = 1f;

/**
 * Create a new instance of HybridCollectorManager depending on whether concurrent search is enabled or disabled.
Expand All @@ -60,27 +69,44 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

Weight filteringWeight = null;
// Check for post filter to create weight for filter query and later use that weight in the search workflow
if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) {
Query filterQuery = searchContext.parsedPostFilter().query();
ContextIndexSearcher searcher = searchContext.searcher();
// ScoreMode.COMPLETE_NO_SCORES is passed because the post_filter does not contribute to scoring — it is not a
// scoring clause
// Boost factor 1f is taken because if boost is multiplicative of 1 then it means "no boost"
// Previously this code in OpenSearch looked like
// https://github.com/opensearch-project/OpenSearch/commit/36a5cf8f35e5cbaa1ff857b5a5db8c02edc1a187
filteringWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, boost_factor);
}

return searchContext.shouldUseConcurrentSearch()
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort()
searchContext.sort(),
filteringWeight
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort()
searchContext.sort(),
filteringWeight
);
}

@Override
public Collector newCollector() {
    // Base collector that gathers the top-scoring docs for the hybrid query.
    Collector hybridCollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker);
    // When a post_filter was supplied, wrap the hybrid collector in a FilteredCollector so that
    // only documents matching the filter weight are collected; otherwise return it as-is.
    return Objects.nonNull(filterWeight) ? new FilteredCollector(hybridCollector, filterWeight) : hybridCollector;
}

/**
Expand Down Expand Up @@ -108,7 +134,10 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
}
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
}
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
Expand Down Expand Up @@ -216,9 +245,10 @@ public HybridCollectorNonConcurrentManager(
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}

Expand All @@ -245,9 +275,10 @@ public HybridCollectorConcurrentSearchManager(
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats
SortAndFormats sortAndFormats,
Weight filteringWeight
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lombok.SneakyThrows;

import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
Expand Down Expand Up @@ -46,6 +47,7 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT {
"test-neural-aggs-pipeline-multi-doc-index-multiple-shards";
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard";
private static final String TEST_QUERY_TEXT3 = "hello";
private static final String TEST_QUERY_TEXT4 = "everyone";
private static final String TEST_QUERY_TEXT5 = "welcome";
private static final String TEST_DOC_TEXT1 = "Hello world";
private static final String TEST_DOC_TEXT2 = "Hi to this place";
Expand Down Expand Up @@ -182,6 +184,204 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu
}
}

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
    // Verify post_filter on a multi-shard index with concurrent segment search disabled,
    // for both a simple (two term clauses) and a complex (bool/range/match) hybrid query.
    updateClusterSettings("search.concurrent_segment_search.enabled", false);
    testPostFilterWithSimpleHybridQuery(false, true);
    testPostFilterWithComplexHybridQuery(false, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
    // Verify post_filter on a multi-shard index with concurrent segment search enabled,
    // for both a simple (two term clauses) and a complex (bool/range/match) hybrid query.
    updateClusterSettings("search.concurrent_segment_search.enabled", true);
    testPostFilterWithSimpleHybridQuery(false, true);
    testPostFilterWithComplexHybridQuery(false, true);
}

@SneakyThrows
private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) {
    /*
     * Runs a simple hybrid query (two term sub-queries) with or without a range post_filter,
     * against a single-shard or multi-shard index, then verifies both the total hit count and
     * that every returned document honours the post-filter range [2000, 5000].
     *
     * Parameters:
     *   isSingleShard      - true to target the single-shard test index, false for multi-shard
     *   hasPostFilterQuery - true to attach the range post_filter to the search request
     */
    String indexName = isSingleShard
        ? TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD
        : TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS;
    try {
        if (isSingleShard) {
            prepareResourcesForSingleShardIndex(indexName, SEARCH_PIPELINE);
        } else {
            prepareResources(indexName, SEARCH_PIPELINE);
        }

        HybridQueryBuilder simpleHybridQueryBuilder = createHybridQueryBuilder(false);
        QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);

        // Expected hit counts per (shard layout, post-filter) combination:
        // single+filter -> 1, single+no-filter -> 2, multi+filter -> 2, multi+no-filter -> 3
        int expectedHitCount;
        if (isSingleShard) {
            expectedHitCount = hasPostFilterQuery ? 1 : 2;
        } else {
            expectedHitCount = hasPostFilterQuery ? 2 : 3;
        }

        Map<String, Object> searchResponseAsMap = search(
            indexName,
            simpleHybridQueryBuilder,
            null,
            10,
            Map.of("search_pipeline", SEARCH_PIPELINE),
            null,
            hasPostFilterQuery ? rangeFilterQuery : null
        );
        assertHitResultsFromQuery(expectedHitCount, searchResponseAsMap);

        // Collect the integer field of every hit to check the post-filter range below.
        List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
        List<Integer> docIndexes = new ArrayList<>();
        for (Map<String, Object> oneHit : hitsNestedList) {
            assertNotNull(oneHit.get("_source"));
            Map<String, Object> source = (Map<String, Object>) oneHit.get("_source");
            docIndexes.add((int) source.get(INTEGER_FIELD_1));
        }
        // With the post filter applied no hit may fall outside [2000, 5000];
        // without it exactly one out-of-range document is expected in the result set.
        long outOfRangeCount = docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count();
        assertEquals(hasPostFilterQuery ? 0 : 1, outOfRangeCount);
    } finally {
        wipeOfTestResources(indexName, null, null, SEARCH_PIPELINE);
    }
}

@SneakyThrows
private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) {
    /*
     * Runs a complex hybrid query (bool + range + match sub-queries) with or without a range
     * post_filter, against a single-shard or multi-shard index, then verifies both the total
     * hit count and that every returned document honours the post-filter range [2000, 5000].
     *
     * Parameters:
     *   isSingleShard      - true to target the single-shard test index, false for multi-shard
     *   hasPostFilterQuery - true to attach the range post_filter to the search request
     */
    String indexName = isSingleShard
        ? TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD
        : TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS;
    try {
        if (isSingleShard) {
            prepareResourcesForSingleShardIndex(indexName, SEARCH_PIPELINE);
        } else {
            prepareResources(indexName, SEARCH_PIPELINE);
        }

        HybridQueryBuilder complexHybridQueryBuilder = createHybridQueryBuilder(true);
        QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);

        // Expected hit counts per (shard layout, post-filter) combination:
        // single+filter -> 1, single+no-filter -> 2, multi+filter -> 4, multi+no-filter -> 3
        int expectedHitCount;
        if (isSingleShard) {
            expectedHitCount = hasPostFilterQuery ? 1 : 2;
        } else {
            expectedHitCount = hasPostFilterQuery ? 4 : 3;
        }

        Map<String, Object> searchResponseAsMap = search(
            indexName,
            complexHybridQueryBuilder,
            null,
            10,
            Map.of("search_pipeline", SEARCH_PIPELINE),
            null,
            hasPostFilterQuery ? rangeFilterQuery : null
        );
        assertHitResultsFromQuery(expectedHitCount, searchResponseAsMap);

        // Collect the integer field of every hit to check the post-filter range below.
        List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
        List<Integer> docIndexes = new ArrayList<>();
        for (Map<String, Object> oneHit : hitsNestedList) {
            assertNotNull(oneHit.get("_source"));
            Map<String, Object> source = (Map<String, Object>) oneHit.get("_source");
            docIndexes.add((int) source.get(INTEGER_FIELD_1));
        }
        // With the post filter applied no hit may fall outside [2000, 5000];
        // without it exactly one out-of-range document is expected in the result set.
        long outOfRangeCount = docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count();
        assertEquals(hasPostFilterQuery ? 0 : 1, outOfRangeCount);
    } finally {
        wipeOfTestResources(indexName, null, null, SEARCH_PIPELINE);
    }
}

@SneakyThrows
private void testAvgSumMinMaxAggs() {
try {
Expand Down Expand Up @@ -227,6 +427,20 @@ private void testAvgSumMinMaxAggs() {
}
}

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() {
    // Verify post_filter on a single-shard index with concurrent segment search disabled,
    // for both a simple (two term clauses) and a complex (bool/range/match) hybrid query.
    updateClusterSettings("search.concurrent_segment_search.enabled", false);
    testPostFilterWithSimpleHybridQuery(true, true);
    testPostFilterWithComplexHybridQuery(true, true);
}

@SneakyThrows
public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() {
    // Verify post_filter on a single-shard index with concurrent segment search enabled,
    // for both a simple (two term clauses) and a complex (bool/range/match) hybrid query.
    updateClusterSettings("search.concurrent_segment_search.enabled", true);
    testPostFilterWithSimpleHybridQuery(true, true);
    testPostFilterWithComplexHybridQuery(true, true);
}

private void testMaxAggsOnSingleShardCluster() throws Exception {
try {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
Expand Down Expand Up @@ -594,4 +808,29 @@ private void assertHitResultsFromQuery(int expected, Map<String, Object> searchR
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
 * Builds a hybrid query for the post-filter tests.
 *
 * @param isComplex when {@code true}, returns a hybrid query with three sub-queries
 *                  (a bool query with a term should-clause, a range query, and a match query);
 *                  when {@code false}, returns a hybrid query with two term sub-queries.
 * @return the assembled {@link HybridQueryBuilder}
 */
private HybridQueryBuilder createHybridQueryBuilder(boolean isComplex) {
    if (isComplex) {
        // Use the fluent should(...) API instead of mutating the clause list directly.
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder().should(
            QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)
        );
        QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000);
        QueryBuilder matchQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
        return new HybridQueryBuilder().add(boolQueryBuilder).add(rangeFilterQuery).add(matchQuery);
    } else {
        // Simple variant: two term sub-queries on the same text field.
        TermQueryBuilder firstTermQuery = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
        TermQueryBuilder secondTermQuery = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
        return new HybridQueryBuilder().add(firstTermQuery).add(secondTermQuery);
    }
}

}
Loading

0 comments on commit d2d4cc6

Please sign in to comment.