diff --git a/CHANGELOG.md b/CHANGELOG.md index 675ea5983..ed5459585 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x) ### Features +- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874)) ### Enhancements - Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907)) - Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917)) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8b173ba81..8b9016323 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -30,20 +30,22 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; -import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import 
org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -154,7 +156,9 @@ public Map querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..231749f33 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -58,7 +58,15 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + // Builds data transfer object to pass into execute, DTO has nullable field for rankConstant which + 
// is only used for RRF technique + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .build(); + normalizationWorkflow.execute(normalizationExecuteDTO); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64f1c1f4..6507e3bd9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -47,16 +47,15 @@ public class NormalizationProcessorWorkflow { /** * Start execution of this workflow - * @param querySearchResults input data with QuerySearchResult from multiple shards - * @param normalizationTechnique technique for score normalization - * @param combinationTechnique technique for score combination + * @param normalizationExecuteDTO contains querySearchResults input data with QuerySearchResult + * from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization + * combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique */ - public void execute( - final List querySearchResults, - final Optional fetchSearchResultOptional, - final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique - ) { + public void execute(final NormalizationExecuteDTO normalizationExecuteDTO) { + final List querySearchResults = normalizationExecuteDTO.getQuerySearchResults(); + final Optional fetchSearchResultOptional = normalizationExecuteDTO.getFetchSearchResultOptional(); + final ScoreNormalizationTechnique normalizationTechnique = 
normalizationExecuteDTO.getNormalizationTechnique(); + final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDTO.getCombinationTechnique(); // save original state List unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -64,9 +63,15 @@ public void execute( log.debug("Pre-process query results"); List queryTopDocs = getQueryTopDocs(querySearchResults); + // Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + // normalize log.debug("Do score normalization"); - scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + scoreNormalizer.normalizeScores(normalizeScoresDTO); CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java new file mode 100644 index 000000000..c932a157d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; + +import java.util.List; + +/** + * DTO object to hold data required for score normalization. 
+ */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizeScoresDTO { + @NonNull + private List queryTopDocs; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java new file mode 100644 index 000000000..207af156c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; + +import java.util.stream.Collectors; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Processor for implementing reciprocal rank fusion technique on post + * query search results. 
Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + * by using ranks from individual subqueries to calculate 'normalized' + * scores before combining results from subqueries into final results + */ +@Log4j2 +@AllArgsConstructor +public class RRFProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "score-ranker-processor"; + + @Getter + private final String tag; + @Getter + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + if (shouldSkipProcessor(searchPhaseResult)) { + log.debug("Query results are not compatible with RRF processor"); + return; + } + List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); + Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); + + // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending + // on coming from NormalizationProcessor or RRFProcessor + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .build(); + normalizationWorkflow.execute(normalizationExecuteDTO); + } + + 
@Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + + private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { + return true; + } + + return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery); + } + + /** + * Return true if results are from hybrid query. + * @param searchPhaseResult + * @return true if results are from hybrid query + */ + private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + // check for delimiter at the end of the score docs. + return Objects.nonNull(searchPhaseResult.queryResult()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs) + && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0 + && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); + } + + private List getQueryPhaseSearchResults( + final SearchPhaseResults results + ) { + return results.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? 
null : result.queryResult()) + .collect(Collectors.toList()); + } + + private Optional getFetchSearchResults( + final SearchPhaseResults searchPhaseResults + ) { + Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); + return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java new file mode 100644 index 000000000..befe14dda --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import lombok.ToString; +import lombok.extern.log4j.Log4j2; + +import java.util.Map; + +@Log4j2 +/** + * Abstracts combination of scores based on reciprocal rank fusion algorithm + */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + + // Not currently using weights for RRF, no need to modify or verify these params + public RRFScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) {} + + @Override + public float combine(final float[] scores) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 23d8e01be..1e560342a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -25,7 +25,9 @@ public class ScoreCombinationFactory { HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), + RRFScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index a915057df..b7a07395f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -25,7 +25,7 @@ @Log4j2 class ScoreCombinationUtil { private static final String PARAM_NAME_WEIGHTS = "weights"; - private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + private static final float DELTA_FOR_WEIGHTS_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -117,7 +117,7 @@ protected void validateIfWeightsMatchScores(final float[] scores, final List weightsList) { - boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.of(0.0f, 1.0f).contains(weight)); if (isOutOfRange) { throw new IllegalArgumentException( String.format( @@ -128,7 +128,7 @@ private void validateWeights(final List weightsList) { ); } float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); - if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { 
+ if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_WEIGHTS_ASSERTION)) { throw new IllegalArgumentException( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java new file mode 100644 index 000000000..fa4f39942 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; +import java.util.Objects; + +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.RRFScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.RRFNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +/** + * Factory class to instantiate RRF processor based on user provided input. 
+ */ +@AllArgsConstructor +@Log4j2 +public class RRFProcessorFactory implements Processor.Factory { + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + public static final String PARAMETERS = "parameters"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + // assign defaults + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization( + RRFNormalizationTechnique.TECHNIQUE_NAME + ); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = readStringProperty( + RRFProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + // check for optional combination params + Map params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + } + log.info( + "Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]", + RRFProcessor.TYPE, + normalizationTechnique, + scoreCombinationTechnique + ); + return new RRFProcessor(tag, description, normalizationTechnique, 
scoreCombinationTechnique, normalizationProcessorWorkflow); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 2bb6bbed7..4acaf9626 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import lombok.ToString; @@ -31,7 +32,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); // get l2 norms for each sub-query List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 4fdf3c0a6..dcaae402e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -11,10 +11,12 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; 
import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** * Abstracts normalization of scores based on min-max method @@ -34,7 +36,8 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); int numOfSubqueries = queryTopDocs.stream() .filter(Objects::nonNull) .filter(topDocs -> topDocs.getTopDocs().size() > 0) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java new file mode 100644 index 000000000..16ef83d05 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Locale; +import java.util.Set; + +import org.apache.commons.lang3.Range; +import org.apache.commons.lang3.math.NumberUtils; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; + +import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; + +/** + * Abstracts calculation of rank scores for each document returned as part of + * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. 
+ */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + public static final int DEFAULT_RANK_CONSTANT = 60; + public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT); + private static final int MIN_RANK_CONSTANT = 1; + private static final int MAX_RANK_CONSTANT = 10_000; + private static final Range RANK_CONSTANT_RANGE = Range.of(MIN_RANK_CONSTANT, MAX_RANK_CONSTANT); + @ToString.Include + private final int rankConstant; + + public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); + rankConstant = getRankConstant(params); + } + + /** + * Reciprocal Rank Fusion normalization technique + * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant, which has a default value of 60 if not specified by user + * algorithm as follows, where document_n_score is the new score for each document in queryTopDocs + * and subquery_result_rank is the position in the array of documents returned for each subquery + * (j + 1 is used to adjust for 0 indexing) + * document_n_score = 1 / (rankConstant + subquery_result_rank) + * document scores are summed in combination step + */ + @Override + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (TopDocs topDocs : topDocsPerSubQuery) { + int docsCountPerSubQuery = 
topDocs.scoreDocs.length; + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + for (int j = 0; j < docsCountPerSubQuery; j++) { + // using big decimal approach to minimize error caused by floating point ops + // score = 1.f / (float) (rankConstant + j + 1)) + scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP) + .floatValue(); + } + } + } + } + + private int getRankConstant(final Map params) { + if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) { + return DEFAULT_RANK_CONSTANT; + } + int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT); + validateRankConstant(rankConstant); + return rankConstant; + } + + private void validateRankConstant(final int rankConstant) { + if (!RANK_CONSTANT_RANGE.contains(rankConstant)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "rank constant must be in the interval between 1 and 10000, submitted rank constant: %d", + rankConstant + ) + ); + } + } + + public static int getParamAsInteger(final Map parameters, final String fieldName) { + try { + return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName)); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index ca6ad20d6..7c62893a5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -6,19 +6,24 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Function; /** * Abstracts creation of exact score normalization method based on technique name */ 
public class ScoreNormalizationFactory { + private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); - private final Map scoreNormalizationMethodsMap = Map.of( + private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - new MinMaxScoreNormalizationTechnique(), + params -> new MinMaxScoreNormalizationTechnique(), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - new L2ScoreNormalizationTechnique() + params -> new L2ScoreNormalizationTechnique(), + RRFNormalizationTechnique.TECHNIQUE_NAME, + params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); /** @@ -27,7 +32,12 @@ public class ScoreNormalizationFactory { * @return instance of ScoreNormalizationMethod for technique name */ public ScoreNormalizationTechnique createNormalization(final String technique) { + return createNormalization(technique, Map.of()); + } + + public ScoreNormalizationTechnique createNormalization(final String technique, final Map params) { return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) - .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")) + .apply(params); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index 0b784c678..f8190a728 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -4,9 +4,7 @@ */ package 
org.opensearch.neuralsearch.processor.normalization; -import java.util.List; - -import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** * Abstracts normalization of scores in query search results. @@ -14,8 +12,12 @@ public interface ScoreNormalizationTechnique { /** - * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. - * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * Performs score normalization based on input normalization technique. + * Mutates input object by updating normalized scores. + * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant that is only used in RRF technique */ - void normalize(final List queryTopDocs); + void normalize(final NormalizeScoresDTO normalizeScoresDTO); + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java new file mode 100644 index 000000000..ad24b0aaa --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import lombok.extern.log4j.Log4j2; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * Collection of utility methods for score normalization technique classes + */ +@Log4j2 +class ScoreNormalizationUtil { + private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float
DELTA_FOR_SCORE_ASSERTION = 0.01f; + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for normalization technique is not supported. supported parameters are [%s]", + String.join(",", supportedParams) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 263115f8f..2ce131bf9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -8,17 +8,22 @@ import java.util.Objects; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; public class ScoreNormalizer { /** - * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores.
- * @param queryTopDocs original query results from multiple shards and multiple sub-queries - * @param scoreNormalizationTechnique exact normalization technique that should be applied + * Performs score normalization based on input normalization technique. + * Mutates input object by updating normalized scores. + * @param normalizeScoresDTO used as data transfer object to pass in queryTopDocs, original query results + * from multiple shards and multiple sub-queries, scoreNormalizationTechnique exact normalization technique + * that should be applied, and nullable rankConstant that is only used in RRF technique */ - public void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) { + public void normalizeScores(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + final ScoreNormalizationTechnique scoreNormalizationTechnique = normalizeScoresDTO.getNormalizationTechnique(); if (canQueryResultsBeNormalized(queryTopDocs)) { - scoreNormalizationTechnique.normalize(queryTopDocs); + scoreNormalizationTechnique.normalize(normalizeScoresDTO); } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 9a969e71b..a4ad9f2d4 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -27,8 +27,10 @@ import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import 
org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -143,12 +145,19 @@ public void testSearchPhaseResultsProcessors() { Map> searchPhaseResultsProcessors = plugin .getSearchPhaseResultsProcessors(searchParameters); assertNotNull(searchPhaseResultsProcessors); - assertEquals(1, searchPhaseResultsProcessors.size()); + assertEquals(2, searchPhaseResultsProcessors.size()); + // assert normalization processor conditions assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( NormalizationProcessor.TYPE ); assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + // assert rrf processor conditions + assertTrue(searchPhaseResultsProcessors.containsKey("score-ranker-processor")); + org.opensearch.search.pipeline.Processor.Factory rankingProcessor = searchPhaseResultsProcessors.get( + RRFProcessor.TYPE + ); + assertTrue(rankingProcessor instanceof RRFProcessorFactory); } public void testGetSettings() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..4b34f7fe1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -271,8 +271,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl ); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + 
verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -327,8 +326,8 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + verify(normalizationProcessorWorkflow, never()).execute(any()); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -417,7 +416,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..9f7e7e785 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -71,13 +71,14 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + 
.normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); } @@ -114,12 +115,14 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } @@ -173,12 +176,14 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + 
normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -233,12 +238,14 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -284,16 +291,14 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ) - ); + expectThrows(IllegalStateException.class, () -> 
normalizationProcessorWorkflow.execute(normalizationExecuteDTO)); } public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -336,13 +341,14 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); + NormalizationExecuteDTO normalizationExecuteDTO = NormalizationExecuteDTO.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .build(); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ); + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 67abd552f..9f0be0300 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; + import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -20,7 +21,11 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { 
ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); - scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(List.of()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); } @SneakyThrows @@ -33,7 +38,11 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -64,7 +73,11 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -101,7 +114,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + 
scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); @@ -169,7 +186,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn false ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); // shard one diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..daed466d3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; +import java.util.Map; + +public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + public RRFScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new 
RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + private float RRF(List scores, List weights) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java index b36a6b492..5ca534dac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -34,6 +34,14 @@ public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstan assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); } + public void testRRF_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("rrf"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof RRFScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java rename to 
src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java index 9e00e3833..009681116 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { +public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase { public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java new file mode 100644 index 000000000..3097402a0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java @@ -0,0 +1,214 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import 
org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +import static org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE; + +public class RRFProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoValuesPassed_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testCombinationParams_whenValidValues_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new 
HashMap<>(Map.of("rank_constant", 100))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNegative_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", -1))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: -1") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", 
PARAMETERS, new HashMap<>(Map.of("rank_constant", 50_000))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", "string")))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("parameter [rank_constant] must be an integer")); + } + + @SneakyThrows + public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + 
Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "my_function", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100)))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("provided combination technique is not supported")); + } + + @SneakyThrows + public void testInvalidTechniqueType_whenPassingNormalization_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + config.put( + NORMALIZATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, PARAMETERS, new HashMap<>(Map.of()))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + private static void assertRRFProcessor(SearchPhaseResultsProcessor searchPhaseResultsProcessor) { + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof RRFProcessor); + RRFProcessor rrfProcessor = (RRFProcessor) searchPhaseResultsProcessor; + 
assertEquals("score-ranker-processor", rrfProcessor.getType()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index ba4bfee0d..29fdb735f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -12,9 +12,10 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** - * Abstracts normalization of scores based on min-max method + * Abstracts normalization of scores based on L2 method */ public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; @@ -34,7 +35,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -81,7 +86,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new 
CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -155,7 +164,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index d0445f0ca..239496355 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -32,7 +33,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -72,7 +77,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce false ) ); - 
normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -127,7 +136,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the false ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java new file mode 100644 index 000000000..00ec13b73 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -0,0 +1,232 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.List; +import java.util.Map; + +/** + * Abstracts testing of normalization of scores based on RRF method + */ +public class 
RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { + static final int RANK_CONSTANT = 60; + private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scores = { 0.5f, 0.2f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ) + ), + false + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresQuery1 = { 0.5f, 0.2f }; + float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, 
TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresQuery2[0]), + new ScoreDoc(4, scoresQuery2[1]), + new ScoreDoc(2, scoresQuery2[2]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresShard1Query1 = { 0.5f, 0.2f }; + float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; + float[] scoresShard2Query2 = { 2.9f, 0.7f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, 
TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[0]), + new ScoreDoc(4, scoresShard1and2Query3[1]), + new ScoreDoc(2, scoresShard1and2Query3[2]) } + ) + ), + false + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[3]), + new ScoreDoc(9, scoresShard1and2Query3[4]), + new ScoreDoc(10, scoresShard1and2Query3[5]), + new ScoreDoc(15, scoresShard1and2Query3[6]) } + ) + ), + false + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new 
CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, rrfNorm(0)), new ScoreDoc(9, rrfNorm(1)) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, rrfNorm(3)), + new ScoreDoc(9, rrfNorm(4)), + new ScoreDoc(10, rrfNorm(5)), + new ScoreDoc(15, rrfNorm(6)) } + ) + ), + false + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); + } + } + + private float rrfNorm(int rank) { + // 1.0f / (float) (rank + RANK_CONSTANT + 1); + return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue(); + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } +} diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index d9dcd5540..cecdf8779 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -26,6 +26,14 @@ public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() { assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique); } + public void testRRFNorm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("rrf"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof RRFNormalizationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index e0c51e106..6d8e810f3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -16,6 +16,7 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static 
org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; @@ -183,7 +184,11 @@ public void testFromXContent_withRescoreContext_thenBuildSuccessfully() { assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(K, neuralQueryBuilder.k()); - assertEquals(RescoreContext.getDefault(), neuralQueryBuilder.rescoreContext()); + assertEquals( + RescoreContext.getDefault().getOversampleFactor(), + neuralQueryBuilder.rescoreContext().getOversampleFactor(), + DELTA_FOR_FLOATS_ASSERTION + ); assertNull(neuralQueryBuilder.methodParameters()); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index a1e8210e6..9c162ce11 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -53,6 +53,8 @@ public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase { + protected static final float DELTA_FOR_ASSERTION = 0.001f; + protected final MapperService createMapperService(Version version, XContentBuilder mapping) throws IOException { IndexMetadata meta = IndexMetadata.builder("index") .settings(Settings.builder().put("index.version.created", version)) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1d3bc29e9..9527fe9fd 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -81,7 +81,6 @@ public class 
HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; private static final String QUERY2 = "hi"; - private static final float DELTA_FOR_ASSERTION = 0.001f; protected static final String QUERY3 = "everyone"; @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 196014220..f91dae327 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -21,8 +21,6 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 9c2718687..2e064913f 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -27,8 +27,6 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index 
bc016aae2..ab041c440 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -39,6 +39,7 @@ public class TestUtils { public static final String RELATION_EQUAL_TO = "eq"; public static final float DELTA_FOR_SCORE_ASSERTION = 0.001f; + public static final float DELTA_FOR_FLOATS_ASSERTION = 0.001f; public static final String RESTART_UPGRADE_OLD_CLUSTER = "tests.is_old_cluster"; public static final String BWC_VERSION = "tests.plugin_bwc_version"; public static final String NEURAL_SEARCH_BWC_PREFIX = "neuralsearch-bwc-";