Skip to content

Commit

Permalink
Explainability in hybrid query (#970)
Browse files Browse the repository at this point in the history
* Added Explainability support for hybrid query

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Nov 15, 2024
1 parent 0119907 commit a3fd3a2
Show file tree
Hide file tree
Showing 41 changed files with 2,461 additions and 95 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
Expand Down Expand Up @@ -80,6 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";

@Override
public Collection<Object> createComponents(
Expand Down Expand Up @@ -181,7 +184,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
) {
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
ExplanationResponseProcessor.TYPE,
new ExplanationResponseProcessorFactory()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import lombok.Setter;
import lombok.ToString;
import lombok.extern.log4j.Log4j2;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.query.QuerySearchResult;

/**
* Class stores collection of TopDocs for each sub query from hybrid query. Collection of results is at shard level. We do store
Expand All @@ -37,15 +39,23 @@ public class CompoundTopDocs {
private List<TopDocs> topDocs;
@Setter
private List<ScoreDoc> scoreDocs;
@Getter
private SearchShard searchShard;

public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs, final boolean isSortEnabled) {
initialize(totalHits, topDocs, isSortEnabled);
public CompoundTopDocs(
final TotalHits totalHits,
final List<TopDocs> topDocs,
final boolean isSortEnabled,
final SearchShard searchShard
) {
initialize(totalHits, topDocs, isSortEnabled, searchShard);
}

private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled) {
private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSortEnabled, SearchShard searchShard) {
this.totalHits = totalHits;
this.topDocs = topDocs;
scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled);
this.searchShard = searchShard;
}

/**
Expand All @@ -72,14 +82,17 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
* 6, 0.15
* 0, 9549511920.4881596047
*/
public CompoundTopDocs(final TopDocs topDocs) {
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
}
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled);
initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard);
return;
}
// skipping first two elements, it's a start-stop element and delimiter for first series
Expand All @@ -103,7 +116,7 @@ public CompoundTopDocs(final TopDocs topDocs) {
scoreDocList.add(scoreDoc);
}
}
initialize(topDocs.totalHits, topDocsList, isSortEnabled);
initialize(topDocs.totalHits, topDocsList, isSortEnabled, searchShard);
}

private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs, boolean isSortEnabled) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR;

/**
* Processor to add explanation details to search response
*/
@Getter
@AllArgsConstructor
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "explanation_response_processor";

private final String description;
private final String tag;
private final boolean ignoreFailure;

/**
* Add explanation details to search response if it is present in request context
*/
@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
return processResponse(request, response, null);
}

/**
* Combines explanation from processor with search hits level explanations and adds it to search response
*/
@Override
public SearchResponse processResponse(
final SearchRequest request,
final SearchResponse response,
final PipelineProcessingContext requestContext
) {
if (Objects.isNull(requestContext)
|| (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY)))
|| requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
return response;
}
// Extract explanation payload from context
ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
// for score normalization, processor level explanations will be sorted in scope of each shard,
// and we are merging both into a single sorted list
SearchHits searchHits = response.getHits();
SearchHit[] searchHitsArray = searchHits.getHits();
// create a map of searchShard and list of indexes of search hit objects in search hits array
// the list will keep original order of sorting as per final search results
Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>();
// we keep index for each shard, where index is a position in searchHitsByShard list
Map<SearchShard, Integer> explainsByShardCount = new HashMap<>();
// Build initial shard mappings
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
// Process normalization details if available in correct format
if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) {
@SuppressWarnings("unchecked")
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
SearchShard,
List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
// Process each search hit to add processor level explanations
for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
// Extract various explanation components
Explanation queryLevelExplanation = searchHit.getExplanation();
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
// Create normalized explanations for each detail
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.getScoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.getScoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
);
}
// Create and set final explanation combining all components
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
);
searchHit.explanation(finalExplanation);
explainsByShardCount.put(searchShard, explanationIndexByShard);
}
}
}
return response;
}

@Override
public String getType() {
return TYPE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

Expand All @@ -43,22 +44,57 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with normalization processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.build();
normalizationWorkflow.execute(request);
}

@Override
Expand Down
Loading

0 comments on commit a3fd3a2

Please sign in to comment.