Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Implement pruning for neural sparse search #988

Merged
merged 19 commits into from
Dec 18, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@
import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescorerBuilder;

import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder,
Expand All @@ -37,41 +36,37 @@ public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements

public static final String TYPE = "neural_sparse_two_phase_processor";
private boolean enabled;
private float ratio;
private float pruneRatio;
private PruneType pruneType;
private float windowExpansion;
private int maxWindowSize;
private static final String PARAMETER_KEY = "two_phase_parameter";
private static final String RATIO_KEY = "prune_ratio";
private static final String ENABLE_KEY = "enabled";
private static final String EXPANSION_KEY = "expansion_rate";
private static final String MAX_WINDOW_SIZE_KEY = "max_window_size";
private static final boolean DEFAULT_ENABLED = true;
private static final float DEFAULT_RATIO = 0.4f;
private static final PruneType DEFAULT_PRUNE_TYPE = PruneType.MAX_RATIO;
private static final float DEFAULT_WINDOW_EXPANSION = 5.0f;
private static final int DEFAULT_MAX_WINDOW_SIZE = 10000;
private static final int DEFAULT_BASE_QUERY_SIZE = 10;
private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50;
private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f;
private static final float RATIO_LOWER_BOUND = 0f;
private static final float RATIO_UPPER_BOUND = 1f;

protected NeuralSparseTwoPhaseProcessor(
String tag,
String description,
boolean ignoreFailure,
boolean enabled,
float ratio,
float pruneRatio,
PruneType pruneType,
float windowExpansion,
int maxWindowSize
) {
super(tag, description, ignoreFailure);
this.enabled = enabled;
if (ratio < RATIO_LOWER_BOUND || ratio > RATIO_UPPER_BOUND) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio)
);
}
this.ratio = ratio;
this.pruneRatio = pruneRatio;
this.pruneType = pruneType;
if (windowExpansion < WINDOW_EXPANSION_LOWER_BOUND) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f", windowExpansion)
Expand All @@ -93,7 +88,7 @@ protected NeuralSparseTwoPhaseProcessor(
*/
@Override
public SearchRequest processRequest(final SearchRequest request) {
if (!enabled || ratio == 0f) {
if (!enabled || pruneRatio == 0f) {
return request;
}
QueryBuilder queryBuilder = request.source().query();
Expand All @@ -117,43 +112,6 @@ public String getType() {
return TYPE;
}

/**
* Based on ratio, split a Map into two map by the value.
*
* @param queryTokens the queryTokens map, key is the token String, value is the score.
* @param thresholdRatio The ratio that control how tokens map be split.
* @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold }
*/
public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRatioedMaxScoreAsThreshold(
final Map<String, Float> queryTokens,
final float thresholdRatio
) {
if (Objects.isNull(queryTokens)) {
throw new IllegalArgumentException("Query tokens cannot be null or empty.");
}
float max = 0f;
for (Float value : queryTokens.values()) {
max = Math.max(value, max);
}
float threshold = max * thresholdRatio;

Map<Boolean, Map<String, Float>> queryTokensByScore = queryTokens.entrySet()
.stream()
.collect(
Collectors.partitioningBy(entry -> entry.getValue() >= threshold, Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
);

Map<String, Float> highScoreTokens = queryTokensByScore.get(Boolean.TRUE);
Map<String, Float> lowScoreTokens = queryTokensByScore.get(Boolean.FALSE);
if (Objects.isNull(highScoreTokens)) {
highScoreTokens = Collections.emptyMap();
}
if (Objects.isNull(lowScoreTokens)) {
lowScoreTokens = Collections.emptyMap();
}
return Tuple.tuple(highScoreTokens, lowScoreTokens);
}

private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(
final Multimap<NeuralSparseQueryBuilder, Float> queryBuilderFloatMap
) {
Expand Down Expand Up @@ -201,7 +159,10 @@ private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilde
* - Docs besides TopDocs: Score = HighScoreToken's score
* - Final TopDocs: Score = HighScoreToken's score + LowScoreToken's score
*/
NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio);
NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(
pruneRatio,
pruneType
);
result.put(modifiedQueryBuilder, updatedBoost);
}
// We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. For other compound query type which are not support now, will
Expand Down Expand Up @@ -248,16 +209,40 @@ public NeuralSparseTwoPhaseProcessor create(
boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED);
Map<String, Object> twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);

float ratio = DEFAULT_RATIO;
float pruneRatio = DEFAULT_RATIO;
float windowExpansion = DEFAULT_WINDOW_EXPANSION;
int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE;
PruneType pruneType = DEFAULT_PRUNE_TYPE;
if (Objects.nonNull(twoPhaseConfigMap)) {
ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue();
pruneRatio = ((Number) twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_RATIO_FIELD, pruneRatio)).floatValue();
windowExpansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, windowExpansion)).floatValue();
maxWindowSize = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, maxWindowSize)).intValue();
pruneType = PruneType.fromString(
twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString()
);
}
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}

return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize);
return new NeuralSparseTwoPhaseProcessor(
tag,
description,
ignoreFailure,
enabled,
pruneRatio,
pruneType,
windowExpansion,
maxWindowSize
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import lombok.Getter;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use,
Expand All @@ -27,18 +30,26 @@ public final class SparseEncodingProcessor extends InferenceProcessor {

public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
@Getter
private final PruneType pruneType;
@Getter
private final float pruneRatio;

public SparseEncodingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
PruneType pruneType,
float pruneRatio,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
this.pruneType = pruneType;
this.pruneRatio = pruneRatio;
}

@Override
Expand All @@ -49,17 +60,23 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
this.modelId,
inferenceList,
ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
);
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
handler.accept(sparseVectors);
}, onException));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;
import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE;

import java.util.Locale;
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
Expand All @@ -19,6 +22,8 @@
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.neuralsearch.util.prune.PruneType;

/**
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
Expand All @@ -40,7 +45,40 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
// if the field is miss, will return PruneType.None
PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD));
float pruneRatio = 0;
if (pruneType != PruneType.NONE) {
// if we have prune type, then prune ratio field must have value
// readDoubleProperty will throw exception if value is not present
pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}
} else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
// if we don't have prune type, then prune ratio field must not have value
throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
}

return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
return new SparseEncodingProcessor(
tag,
description,
batchSize,
modelId,
fieldMap,
pruneType,
pruneRatio,
clientAccessor,
environment,
clusterService
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;

import static org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
Expand Down Expand Up @@ -90,6 +90,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
// 2. If it's the sub query only build for two-phase, the value will be set to -1 * ratio of processor.
// Then in the DoToQuery, we can use this to determine which type are this queryBuilder.
private float twoPhasePruneRatio = 0F;
private PruneType twoPhasePruneType = PruneType.NONE;

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

Expand Down Expand Up @@ -129,22 +130,23 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {

/**
* Copy this QueryBuilder for two phase rescorer, set the copy one's twoPhasePruneRatio to -1.
* @param ratio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
* @return A copy NeuralSparseQueryBuilder for twoPhase, it will be added to the rescorer.
*/
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float ratio) {
this.twoPhasePruneRatio(ratio);
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float pruneRatio, PruneType pruneType) {
this.twoPhasePruneRatio(pruneRatio);
this.twoPhasePruneType(pruneType);
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder().fieldName(this.fieldName)
.queryName(this.queryName)
.queryText(this.queryText)
.modelId(this.modelId)
.maxTokenScore(this.maxTokenScore)
.twoPhasePruneRatio(-1f * ratio);
.twoPhasePruneRatio(-1f * pruneRatio);
if (Objects.nonNull(this.queryTokensSupplier)) {
Map<String, Float> tokens = queryTokensSupplier.get();
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
// while those less than or equal to the threshold are stored in v2.
Tuple<Map<String, Float>, Map<String, Float>> splitTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(tokens, ratio);
Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.splitSparseVector(pruneType, pruneRatio, tokens);
this.queryTokensSupplier(() -> splitTokens.v1());
copy.queryTokensSupplier(() -> splitTokens.v2());
} else {
Expand Down Expand Up @@ -346,9 +348,10 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(
queryTokens,
twoPhasePruneRatio
Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.splitSparseVector(
twoPhasePruneType,
twoPhasePruneRatio,
queryTokens
);
setOnce.set(splitQueryTokens.v1());
twoPhaseSharedQueryToken = splitQueryTokens.v2();
Expand Down
Loading
Loading