Fixing multiple issues reported in #497 (#524) (#527)
* Allow multiple identical sub-queries in a hybrid query; remove validation for total hits


(cherry picked from commit 585fbbe)

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
martin-gaievski authored Jan 2, 2024
1 parent 63fe67f commit 31b3f66
Showing 9 changed files with 411 additions and 19 deletions.
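For context, the issues in #497 involve hybrid queries whose sub-queries compare as equal. A minimal sketch of such a request, assuming the plugin's `HybridQueryBuilder` API (field name and values are hypothetical):

```java
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;

// Hypothetical reproduction: two sub-queries that compare as equal.
// Before this fix, both mapped to a single entry in the scorer's
// query-to-index map, so one of the score slots was lost.
HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
hybridQuery.add(QueryBuilders.termQuery("category", "shoes"));
hybridQuery.add(QueryBuilders.termQuery("category", "shoes"));
```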
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
- Fixed exception for the case when a hybrid query is wrapped into a bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498))
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
### Infrastructure
### Documentation
### Maintenance
@@ -173,8 +173,20 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
        // validate that both collections are of the same size
-        if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
-            throw new IllegalStateException("Score normalization processor cannot produce final query result");
+        if (Objects.isNull(searchHitArray)) {
+            throw new IllegalStateException(
+                "score normalization processor cannot produce final query result, fetch query phase returns empty results"
+            );
+        }
+        if (searchHitArray.length != docIds.size()) {
+            throw new IllegalStateException(
+                String.format(
+                    Locale.ROOT,
+                    "score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]",
+                    searchHitArray.length,
+                    docIds.size()
+                )
+            );
        }
return searchHitArray;
}
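The reworked check in getSearchHits now distinguishes an empty fetch result from a size mismatch. A small standalone sketch of the mismatch message, with hypothetical counts (7 fetched hits against 4 query-phase docs):

```java
import java.util.Locale;

// Hypothetical counts: 7 documents after the fetch phase, 4 from the query phase.
String message = String.format(
    Locale.ROOT,
    "score normalization processor cannot produce final query result, "
        + "the number of documents after fetch phase [%d] is different from "
        + "number of documents from query phase [%d]",
    7,
    4
);
System.out.println(message);
```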
@@ -6,9 +6,11 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.Objects;

@@ -37,7 +39,7 @@ public final class HybridQueryScorer extends Scorer {

private final float[] subScores;

-    private final Map<Query, Integer> queryToIndex;
+    private final Map<Query, List<Integer>> queryToIndex;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
super(weight);
@@ -111,24 +113,43 @@ public float[] hybridScores() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
            // check if this doc has a match in the sub-query; if not, leave the score as 0.0 and continue
-            if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
+            Scorer scorer = disiWrapper.scorer;
+            if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                continue;
            }
-            float subScore = disiWrapper.scorer.score();
-            scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore;
+            Query query = scorer.getWeight().getQuery();
+            List<Integer> indexes = queryToIndex.get(query);
+            // find the index of the first duplicate sub-query whose score hasn't been set yet; unset scores keep the initial value of 0.0
+            int index = indexes.stream()
+                .mapToInt(idx -> idx)
+                .filter(idx -> Float.compare(scores[idx], 0.0f) == 0)
+                .findFirst()
+                .orElseThrow(
+                    () -> new IllegalStateException(
+                        String.format(
+                            Locale.ROOT,
+                            "cannot set score for one of hybrid search subquery [%s] and document [%d]",
+                            query.toString(),
+                            scorer.docID()
+                        )
+                    )
+                );
+            scores[index] = scorer.score();
}
return scores;
}
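To make the duplicate handling in hybridScores concrete, here is a standalone sketch of the slot-selection idea, with plain floats standing in for Lucene scorers (all values hypothetical):

```java
import java.util.List;

// Hypothetical setup: one sub-query occurs twice, occupying score slots 0 and 2.
float[] scores = new float[3];
List<Integer> indexes = List.of(0, 2);
float subScore = 0.42f; // score produced by the current duplicate occurrence

// Pick the first slot still at its initial 0.0 value, so each occurrence
// of the duplicate sub-query keeps its own score.
int index = indexes.stream()
    .mapToInt(idx -> idx)
    .filter(idx -> Float.compare(scores[idx], 0.0f) == 0)
    .findFirst()
    .orElseThrow(() -> new IllegalStateException("no unset slot left for duplicate sub-query"));
scores[index] = subScore; // with slot 0 taken, a second occurrence lands in slot 2
```

The unset check works because a new float array defaults every slot to 0.0f.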

-    private Map<Query, Integer> mapQueryToIndex() {
-        Map<Query, Integer> queryToIndex = new HashMap<>();
+    private Map<Query, List<Integer>> mapQueryToIndex() {
+        Map<Query, List<Integer>> queryToIndex = new HashMap<>();
int idx = 0;
for (Scorer scorer : subScorers) {
if (scorer == null) {
idx++;
continue;
}
-            queryToIndex.put(scorer.getWeight().getQuery(), idx);
+            Query query = scorer.getWeight().getQuery();
+            queryToIndex.putIfAbsent(query, new ArrayList<>());
+            queryToIndex.get(query).add(idx);
idx++;
}
return queryToIndex;
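A small self-contained sketch of what the new mapping produces when duplicates are present, with string keys standing in for Lucene Query objects:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical: three sub-queries, the first and third identical.
List<String> subQueries = List.of("term:shoes", "match:red", "term:shoes");
Map<String, List<Integer>> queryToIndex = new HashMap<>();
for (int idx = 0; idx < subQueries.size(); idx++) {
    String query = subQueries.get(idx);
    queryToIndex.putIfAbsent(query, new ArrayList<>());
    queryToIndex.get(query).add(idx);
}
// queryToIndex -> {match:red=[1], term:shoes=[0, 2]}
// The old Map<Query, Integer> collapsed "term:shoes" into a single index.
```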
@@ -137,7 +158,9 @@ private Map<Query, Integer> mapQueryToIndex() {
private DisiPriorityQueue initializeSubScorersPQ() {
Objects.requireNonNull(queryToIndex, "should not be null");
Objects.requireNonNull(subScorers, "should not be null");
-        DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(queryToIndex.size());
+        // count this way so that all identical sub-queries are included
+        int numOfSubQueries = queryToIndex.values().stream().map(List::size).reduce(0, Integer::sum);
+        DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries);
for (Scorer scorer : subScorers) {
if (scorer == null) {
continue;
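The priority-queue sizing follows directly from that mapping; continuing the hypothetical map from the sketch above:

```java
// Summing the list sizes counts every occurrence, unlike queryToIndex.size(),
// which counts only distinct queries: {match:red=[1], term:shoes=[0, 2]} -> 3.
int numOfSubQueries = queryToIndex.values().stream().map(List::size).reduce(0, Integer::sum);
```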
@@ -24,9 +24,6 @@ public HitsThresholdChecker(int totalHitsThreshold) {
if (totalHitsThreshold < 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be >= 0, got %d", totalHitsThreshold));
}
-        if (totalHitsThreshold == Integer.MAX_VALUE) {
-            throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be less than max integer value"));
-        }
this.totalHitsThreshold = totalHitsThreshold;
}

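With the upper-bound guard gone, Integer.MAX_VALUE becomes a legal threshold; a hedged usage sketch:

```java
// Previously rejected, now valid: effectively "keep counting hits with no cap".
HitsThresholdChecker checker = new HitsThresholdChecker(Integer.MAX_VALUE);
```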
@@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -15,6 +16,7 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@@ -45,9 +47,13 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
@@ -325,4 +331,172 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul

verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
}

public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
PROCESSOR_TAG,
DESCRIPTION,
new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
normalizationProcessorWorkflow
);

SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
searchRequest.setBatchedReduceSize(4);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
searchRequest,
executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST),
searchPhaseController,
SearchProgressListener.NOOP,
writableRegistry(),
10,
e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
curr.addSuppressed(prev);
return curr;
})
);
CountDownLatch partialReduceLatch = new CountDownLatch(5);
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
        TopDocs topDocs = new TopDocs(
            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] {
                createStartStopElementForHybridSearchResults(4),
                createDelimiterElementForHybridSearchResults(4),
                new ScoreDoc(0, 0.5f),
                new ScoreDoc(2, 0.3f),
                new ScoreDoc(4, 0.25f),
                new ScoreDoc(10, 0.2f),
                createStartStopElementForHybridSearchResults(4) }
        );
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);

FetchSearchResult fetchSearchResult = new FetchSearchResult();
fetchSearchResult.setShardIndex(shardId);
fetchSearchResult.setSearchShardTarget(searchShardTarget);
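        // 7 hits, one per query-phase ScoreDoc (including the hybrid start/stop
        // and delimiter marker docs at id 4), so the size validation passes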
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(0, "10", Map.of(), Map.of()),
new SearchHit(2, "1", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(10, "3", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());

TestUtils.assertQueryResultScores(querySearchResults);
verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
}

public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
PROCESSOR_TAG,
DESCRIPTION,
new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
normalizationProcessorWorkflow
);

SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
searchRequest.setBatchedReduceSize(4);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
searchRequest,
executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST),
searchPhaseController,
SearchProgressListener.NOOP,
writableRegistry(),
10,
e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
curr.addSuppressed(prev);
return curr;
})
);
CountDownLatch partialReduceLatch = new CountDownLatch(5);
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
        TopDocs topDocs = new TopDocs(
            new TotalHits(4, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] {
                createStartStopElementForHybridSearchResults(4),
                createDelimiterElementForHybridSearchResults(4),
                new ScoreDoc(0, 0.5f),
                new ScoreDoc(2, 0.3f),
                new ScoreDoc(4, 0.25f),
                new ScoreDoc(10, 0.2f),
                createStartStopElementForHybridSearchResults(4) }
        );
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);

FetchSearchResult fetchSearchResult = new FetchSearchResult();
fetchSearchResult.setShardIndex(shardId);
fetchSearchResult.setSearchShardTarget(searchShardTarget);
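        // only 5 hits against 7 query-phase ScoreDocs: getSearchHits should
        // detect the size mismatch, so the processor is expected to fail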
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(0, "10", Map.of(), Map.of()),
new SearchHit(2, "1", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(10, "3", Map.of(), Map.of()),
new SearchHit(0, "10", Map.of(), Map.of()), };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
IllegalStateException exception = expectThrows(
IllegalStateException.class,
() -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext)
);
org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
startsWith("score normalization processor cannot produce final query result")
);
}
}