Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] ByFieldRerank Processor (ReRankProcessor enhancement) #960

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ merged to main, the workflow will create a backport PR to the `2.x` branch.

## Building On Lucene Version Updates
There may be a Lucene version update that can affect your workflow causing errors like
`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or
`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or
`Provider org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec could not be instantiated`. In this case
there may be an issue with the [K-NN](https://github.com/opensearch-project/k-NN) dependency,
which prevents `./gradlew run` or `./gradlew build` from completing successfully.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME);
createIndexWithConfiguration(
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
PIPELINE_NAME
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())),
PIPELINE_NAME
);
addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
createSearchPipeline(
SEARCH_PIPELINE_NAME,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
SEARCH_PIPELINE_NAME,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f }))
);
break;
case MIXED:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

Compatible with OpenSearch 2.18.0

### Features
- Introduces ByFieldRerankProcessor for second level reranking on documents ([#932](https://github.com/opensearch-project/neural-search/pull/932))

### Enhancements
- Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import com.google.common.collect.Sets;
import lombok.AllArgsConstructor;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
Expand All @@ -22,9 +18,17 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import lombok.AllArgsConstructor;
import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_KEEP_PREVIOUS_SCORE;
import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_REMOVE_TARGET_FIELD;
import static org.opensearch.neuralsearch.processor.rerank.RerankProcessor.processorRequiresContext;

/**
* Factory for rerank processors. Must:
Expand All @@ -51,22 +55,55 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
clusterService
);

// Currently the createFetchers method requires that you provide a context map, this branch makes sure we can ignore this on
// processors that don't need the context map
List<ContextSourceFetcher> contextFetchers = processorRequiresContext(type)
? ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag, clusterService)
: Collections.emptyList();

Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());

switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
case BY_FIELD:
String targetField = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.TARGET_FIELD
);
boolean removeTargetField = ConfigurationUtils.readBooleanProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.REMOVE_TARGET_FIELD,
DEFAULT_REMOVE_TARGET_FIELD
);
boolean keepPreviousScore = ConfigurationUtils.readBooleanProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE,
DEFAULT_KEEP_PREVIOUS_SCORE
);

return new ByFieldRerankProcessor(
description,
tag,
ignoreFailure,
targetField,
removeTargetField,
keepPreviousScore,
contextFetchers
);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
Expand Down Expand Up @@ -100,6 +137,7 @@ private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
*
* @param type the constructing RerankType
* @return does this RerankType depend on the QueryContextSourceFetcher?
*/
Expand All @@ -109,8 +147,8 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {

/**
* Create necessary queryContextFetchers for this processor
* @param config processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher should I include the queryContextFetcher?
* @param config Processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher Should I include the queryContextFetcher?
* @return list of contextFetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils.SearchHitValidator;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria;

/**
* A reranking processor that reorders search results based on the content of a specified field.
* <p>
* The ByFieldRerankProcessor allows for reordering of search results by considering the content of a
* designated target field within each document. This processor will update the <code>_score</code> field with what has been provided
* by {@code target_field}. When {@code keep_previous_score} is enabled a new field is appended called <code>previous_score</code> which was the score prior to reranking.
* <p>
* Key features:
* <ul>
* <li>Reranks search results based on a specified target field</li>
* <li>Optionally removes the target field from the final search results</li>
* <li>Supports nested field structures using dot notation</li>
* </ul>
* <p>
* The processor uses the following configuration parameters:
* <ul>
* <li>{@code target_field}: The field to be used for reranking (required)</li>
* <li>{@code remove_target_field}: Whether to remove the target field from the final results (optional, default: false)</li>
* <li>{@code keep_previous_score}: Whether to append the previous score in a field called <code>previous_score</code> (optional, default: false)</li>
* </ul>
* <p>
* Usage example:
* <pre>
* {
* "rerank": {
* "by_field": {
* "target_field": "document.relevance_score",
* "remove_target_field": true,
* "keep_previous_score": false
* }
* }
* }
* </pre>
* <p>
* This processor is useful in scenarios where additional, document-specific
* information stored in a field can be used to improve the relevance of search results
* beyond the initial scoring.
*/
@Log4j2
public class ByFieldRerankProcessor extends RescoringRerankProcessor {

public static final String TARGET_FIELD = "target_field";
public static final String REMOVE_TARGET_FIELD = "remove_target_field";
public static final String KEEP_PREVIOUS_SCORE = "keep_previous_score";

public static final boolean DEFAULT_REMOVE_TARGET_FIELD = false;
public static final boolean DEFAULT_KEEP_PREVIOUS_SCORE = false;

protected final String targetField;
protected final boolean removeTargetField;
protected final boolean keepPreviousScore;

/**
* Constructor to pass values to the RerankProcessor constructor.
*
* @param description The description of the processor
* @param tag The processor's identifier
* @param ignoreFailure If true, OpenSearch ignores any failure of this processor and
* continues to run the remaining processors in the search pipeline.
* @param targetField The field you want to replace your <code>_score</code> with
* @param removeTargetField A flag to let you delete the target_field for better visualization (i.e. removes a duplicate value)
* @param keepPreviousScore A flag to let you decide to stash your previous <code>_score</code> in a field called <code>previous_score</code> (i.e. for debugging purposes)
* @param contextSourceFetchers Context from some source and puts it in a map for a reranking processor to use <b> (Unused in ByFieldRerankProcessor)</b>
*/
public ByFieldRerankProcessor(
final String description,
final String tag,
final boolean ignoreFailure,
final String targetField,
final boolean removeTargetField,
final boolean keepPreviousScore,
final List<ContextSourceFetcher> contextSourceFetchers
) {
super(RerankType.BY_FIELD, description, tag, ignoreFailure, contextSourceFetchers);
this.targetField = targetField;
this.removeTargetField = removeTargetField;
this.keepPreviousScore = keepPreviousScore;
}

@Override
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
SearchHit[] searchHits = response.getHits().getHits();

SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator;

if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) {
return;
}

List<Float> scores = new ArrayList<>(searchHits.length);

for (SearchHit hit : searchHits) {
Map<String, Object> sourceAsMap = hit.getSourceAsMap();

float score = getScoreFromSourceMap(sourceAsMap, targetField);
scores.add(score);

if (keepPreviousScore) {
sourceAsMap.put("previous_score", hit.getScore());
}

if (removeTargetField) {
removeTargetFieldFromSource(sourceAsMap, targetField);
}

try {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
BytesReference sourceMapAsBytes = BytesReference.bytes(builder.map(sourceAsMap));
hit.sourceRef(sourceMapAsBytes);
} catch (IOException e) {
log.error(e.getMessage());
listener.onFailure(new RuntimeException(e));
return;
}
}

listener.onResponse(scores);
}

/**
* Implements the behavior of the SearchHit validator {@code SearchHitValidator}
* It checks all the following
* <ul>
* <li>Checks the search hit has a source mapping</li>
* <li>Checks that the mapping exists in the source mapping using the target_field</li>
* <li>Checks that the mapping has a numerical score for it to rerank</li>
* </ul>
* @param hit A search hit to validate
*/
public void byFieldSearchHitValidator(final SearchHit hit) {
if (!hit.hasSource()) {
log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())
);
}

Map<String, Object> sourceMap = hit.getSourceAsMap();
if (!mappingExistsInSource(sourceMap, targetField)) {
log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId()));

throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId()));
}

Optional<Object> val = getValueFromSource(sourceMap, targetField);

if (!(val.get() instanceof Number)) {
log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null)));

throw new IllegalArgumentException(
String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null))
);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public abstract class RerankProcessor implements SearchResponseProcessor {
@Getter
private final boolean ignoreFailure;
protected List<ContextSourceFetcher> contextSourceFetchers;
static final protected List<RerankType> processorsWithNoContext = List.of(RerankType.BY_FIELD);

/**
* Generate the information that this processor needs in order to rerank.
Expand All @@ -48,6 +49,11 @@ public void generateRerankingContext(
final SearchResponse searchResponse,
final ActionListener<Map<String, Object>> listener
) {
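// Without this check, a processor that doesn't require context would leave the listener waiting indefinitely for a response.
// NOTE(review): there is no `return` after onResponse below, so execution continues into the fetcher loop - confirm this is intended.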
if (!processorRequiresContext(subType)) {
listener.onResponse(Map.of());
}

Map<String, Object> overallContext = new ConcurrentHashMap<>();
AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size());
for (ContextSourceFetcher csf : contextSourceFetchers) {
Expand Down Expand Up @@ -102,4 +108,19 @@ public void processResponseAsync(
responseListener.onFailure(e);
}
}

/**
 * Tells whether a rerank processor of the given kind needs a reranking context in order to rescore a
 * search response.
 * <p>
 * Some processors rank without any context mapping. At present the context-free processors are:
 * <ul>
 *     <li>ByFieldRerankProcessor - Uses the search response to get value to rescore by</li>
 * </ul>
 *
 * @param subType The kind of rerank processor
 * @return {@code false} when the subtype is listed in {@code processorsWithNoContext}, {@code true} otherwise
 */
public static boolean processorRequiresContext(RerankType subType) {
    final boolean contextFree = processorsWithNoContext.contains(subType);
    return !contextFree;
}
}
Loading
Loading