From 1e09989fdf5d0e9aed73a5d342bf2b0492cf5b59 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Nov 2024 15:04:58 +0800 Subject: [PATCH 01/17] add impl Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 12 ++- .../SparseEncodingProcessorFactory.java | 29 +++++- .../processor/pruning/PruneUtils.java | 92 +++++++++++++++++++ .../processor/pruning/PruningType.java | 45 +++++++++ .../neuralsearch/util/TokenWeightUtil.java | 21 ++++- 5 files changed, 192 insertions(+), 7 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..d49a6a709 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,6 +14,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.pruning.PruningType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -27,6 +28,8 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; + private final PruningType pruningType; + private final float pruneRatio; public SparseEncodingProcessor( String tag, @@ -34,11 +37,15 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, + PruningType pruningType, + float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { 
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.pruningType = pruningType; + this.pruneRatio = pruneRatio; } @Override @@ -49,7 +56,8 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio); + setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @@ -59,7 +67,7 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult( this.modelId, inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) + ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), onException) ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 46055df16..8f9e42eea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -6,6 +6,8 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; import static 
org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; @@ -19,6 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.pruning.PruneUtils; +import org.opensearch.neuralsearch.processor.pruning.PruningType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. @@ -40,7 +44,30 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); + float pruneRatio = 0; + if (pruningType != PruningType.NONE) { + // if we have prune type, then prune ratio field must have value + // readDoubleProperty will throw exception if value is not present + pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); + } else { + // if we don't have prune type, then prune ratio field must not have value + if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { + throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); + } + } - return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService); + return new SparseEncodingProcessor( + tag, + description, + batchSize, + modelId, + fieldMap, + pruningType, + pruneRatio, + clientAccessor, + environment, + clusterService + ); } } diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java new file mode 100644 index 000000000..47aaaeac9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.pruning; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class PruneUtils { + public static final String PRUNE_TYPE_FIELD = "prune_type"; + public static final String PRUNE_RATIO_FIELD = "prune_ratio"; + + public static Map pruningByTopK(Map sparseVector, int k) { + List> list = new ArrayList<>(sparseVector.entrySet()); + list.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + Map result = new HashMap<>(); + for (int i = 0; i < k && i < list.size(); i++) { + Map.Entry entry = list.get(i); + result.put(entry.getKey(), entry.getValue()); + } + return result; + } + + public static Map pruningByMaxRatio(Map sparseVector, float ratio) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map result = new HashMap<>(); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentValue = entry.getValue(); + float currentRatio = currentValue / maxValue; + + if (currentRatio >= ratio) { + result.put(entry.getKey(), entry.getValue()); + } + } + + return result; + } + + public static Map pruningByValue(Map sparseVector, float thresh) { + Map result = new HashMap<>(sparseVector); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentValue = Math.abs(entry.getValue()); + if (currentValue < thresh) { + result.remove(entry.getKey()); + } + } + + return result; + } + + public static Map pruningByAlphaMass(Map sparseVector, float alpha) { + List> sortedEntries = new 
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

/**
 * Enum representing different types of pruning methods for sparse vectors.
 * The string value of each constant is the form accepted in pipeline configuration.
 */
public enum PruningType {
    NONE("none"),
    TOP_K("top_k"),
    ALPHA_MASS("alpha_mass"),
    MAX_RATIO("max_ratio"),
    ABS_VALUE("abs_value");

    private final String value;

    PruningType(String value) {
        this.value = value;
    }

    /**
     * @return the lowercase string form used in pipeline configuration
     */
    public String getValue() {
        return value;
    }

    /**
     * Get PruningType from string value.
     *
     * @param value string representation of pruning type; null or empty maps to NONE
     * @return corresponding PruningType enum
     * @throws IllegalArgumentException if value doesn't match any pruning type
     */
    public static PruningType fromString(String value) {
        // Plain null/empty check: avoids pulling in a third-party StringUtils
        // dependency for a trivial test.
        if (value == null || value.isEmpty()) {
            return NONE;
        }
        for (PruningType type : PruningType.values()) {
            if (type.value.equals(value)) {
                return type;
            }
        }
        throw new IllegalArgumentException("Unknown pruning type: " + value);
    }
}
} - private static Map buildTokenWeightMap(Object uncastedMap) { + private static Map buildTokenWeightMap(Object uncastedMap, PruningType pruningType, float pruneRatio) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -72,6 +85,6 @@ private static Map buildTokenWeightMap(Object uncastedMap) { } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return result; + return PruneUtils.pruningSparseVector(pruningType, pruneRatio, result); } } From adca9bed4594dcaf769a7bbbd510d3a9980d1394 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Nov 2024 15:59:51 +0800 Subject: [PATCH 02/17] add UT Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 7 +- .../SparseEncodingProcessorFactory.java | 8 +- .../processor/pruning/PruneUtils.java | 92 --------- .../neuralsearch/util/TokenWeightUtil.java | 4 +- .../neuralsearch/util/pruning/PruneUtils.java | 180 ++++++++++++++++++ .../pruning/PruningType.java | 2 +- .../util/pruning/PruneUtilsTests.java | 159 ++++++++++++++++ 7 files changed, 353 insertions(+), 99 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java create mode 100644 src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java rename src/main/java/org/opensearch/neuralsearch/{processor => util}/pruning/PruningType.java (95%) create mode 100644 src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index d49a6a709..61851c1d6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -14,7 +14,7 @@ import 
org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruningType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -67,7 +67,10 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult( this.modelId, inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), onException) + ActionListener.wrap( + resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), + onException + ) ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 8f9e42eea..40a31392c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -21,8 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.processor.pruning.PruneUtils; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.pruning.PruningType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
@@ -44,12 +44,16 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + // if the field is missing, fromString will return PruningType.NONE PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); float pruneRatio = 0; if (pruningType != PruningType.NONE) { // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); + if (!PruneUtils.isValidPruneRatio(pruningType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruningType.name() + ); } else { // if we don't have prune type, then prune ratio field must not have value if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java deleted file mode 100644 index 47aaaeac9..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruneUtils.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.processor.pruning; - -import java.util.ArrayList; - -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class PruneUtils { - public static final String PRUNE_TYPE_FIELD = "prune_type"; - public static final String PRUNE_RATIO_FIELD = "prune_ratio"; - - public static Map pruningByTopK(Map sparseVector, int k) { - List> list = new
ArrayList<>(sparseVector.entrySet()); - list.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); - - Map result = new HashMap<>(); - for (int i = 0; i < k && i < list.size(); i++) { - Map.Entry entry = list.get(i); - result.put(entry.getKey(), entry.getValue()); - } - return result; - } - - public static Map pruningByMaxRatio(Map sparseVector, float ratio) { - float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); - - Map result = new HashMap<>(); - for (Map.Entry entry : sparseVector.entrySet()) { - float currentValue = entry.getValue(); - float currentRatio = currentValue / maxValue; - - if (currentRatio >= ratio) { - result.put(entry.getKey(), entry.getValue()); - } - } - - return result; - } - - public static Map pruningByValue(Map sparseVector, float thresh) { - Map result = new HashMap<>(sparseVector); - for (Map.Entry entry : sparseVector.entrySet()) { - float currentValue = Math.abs(entry.getValue()); - if (currentValue < thresh) { - result.remove(entry.getKey()); - } - } - - return result; - } - - public static Map pruningByAlphaMass(Map sparseVector, float alpha) { - List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); - sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); - - float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); - float topSum = 0f; - - Map result = new HashMap<>(); - for (Map.Entry entry : sortedEntries) { - float value = entry.getValue(); - topSum += value; - result.put(entry.getKey(), value); - - if (topSum / sum >= alpha) { - break; - } - } - - return result; - } - - public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { - switch (pruningType) { - case TOP_K: - return pruningByTopK(sparseVector, (int) pruneRatio); - case ALPHA_MASS: - return pruningByAlphaMass(sparseVector, pruneRatio); - case MAX_RATIO: - return pruningByMaxRatio(sparseVector, pruneRatio); - case ABS_VALUE: - return 
pruningByValue(sparseVector, pruneRatio); - default: - return sparseVector; - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 3189706de..0ee48fa33 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.util; -import org.opensearch.neuralsearch.processor.pruning.PruneUtils; -import org.opensearch.neuralsearch.processor.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.pruning.PruningType; import java.util.ArrayList; import java.util.HashMap; diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java new file mode 100644 index 000000000..d7d2234cf --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** + * Utility class providing methods for pruning sparse vectors using different strategies. + * Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements + * based on various criteria. + */ +public class PruneUtils { + public static final String PRUNE_TYPE_FIELD = "prune_type"; + public static final String PRUNE_RATIO_FIELD = "prune_ratio"; + + /** + * Prunes a sparse vector by keeping only the top K elements with the highest values. 
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param k The number of top elements to keep + * @return A new map containing only the top K elements + */ + private static Map pruningByTopK(Map sparseVector, int k) { + PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (pq.size() < k) { + pq.offer(entry); + } else if (entry.getValue() > pq.peek().getValue()) { + pq.poll(); + pq.offer(entry); + } + } + + Map result = new HashMap<>(); + while (!pq.isEmpty()) { + Map.Entry entry = pq.poll(); + result.put(entry.getKey(), entry.getValue()); + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose values are within a certain ratio + * of the maximum value in the vector. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param ratio The minimum ratio relative to the maximum value for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByMaxRatio(Map sparseVector, float ratio) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map result = new HashMap<>(); + for (Map.Entry entry : sparseVector.entrySet()) { + float currentRatio = entry.getValue() / maxValue; + + if (currentRatio >= ratio) { + result.put(entry.getKey(), entry.getValue()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by removing elements with values below a certain threshold. 
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param thresh The minimum absolute value for elements to be kept + * @return A new map containing only elements meeting the threshold + */ + private static Map pruningByValue(Map sparseVector, float thresh) { + Map result = new HashMap<>(sparseVector); + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() < thresh) { + result.remove(entry.getKey()); + } + } + + return result; + } + + /** + * Prunes a sparse vector by keeping only elements whose cumulative sum of values + * is within a certain ratio of the total sum. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param alpha The minimum ratio relative to the total sum for elements to be kept + * @return A new map containing only elements meeting the ratio threshold + */ + private static Map pruningByAlphaMass(Map sparseVector, float alpha) { + List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); + sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); + float topSum = 0f; + + Map result = new HashMap<>(); + for (Map.Entry entry : sortedEntries) { + float value = entry.getValue(); + topSum += value; + result.put(entry.getKey(), value); + + if (topSum / sum >= alpha) { + break; + } + } + + return result; + } + + /** + * Prunes a sparse vector using the specified pruning type and ratio. 
+ * + * @param pruningType The type of pruning strategy to use + * @param pruneRatio The ratio or threshold for pruning + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A new map containing the pruned sparse vector + */ + public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruningType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( + "Prune type and prune ratio must be provided" + ); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruningType) { + case TOP_K: + return pruningByTopK(sparseVector, (int) pruneRatio); + case ALPHA_MASS: + return pruningByAlphaMass(sparseVector, pruneRatio); + case MAX_RATIO: + return pruningByMaxRatio(sparseVector, pruneRatio); + case ABS_VALUE: + return pruningByValue(sparseVector, pruneRatio); + default: + return sparseVector; + } + } + + /** + * Validates whether a prune ratio is valid for a given pruning type. 
+ * + * @param pruningType The type of pruning strategy + * @param pruneRatio The ratio or threshold to validate + * @return true if the ratio is valid for the given pruning type, false otherwise + * @throws IllegalArgumentException if pruning type is null + */ + public static boolean isValidPruneRatio(PruningType pruningType, float pruneRatio) { + if (pruningType == null) { + throw new IllegalArgumentException("Pruning type cannot be null"); + } + + switch (pruningType) { + case TOP_K: + return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); + case ALPHA_MASS: + case MAX_RATIO: + return pruneRatio > 0 && pruneRatio < 1; + case ABS_VALUE: + return pruneRatio > 0; + default: + return true; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java similarity index 95% rename from src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java rename to src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java index 5a26a1e53..6629bb937 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/pruning/PruningType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.pruning; +package org.opensearch.neuralsearch.util.pruning; import org.apache.commons.lang.StringUtils; diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java new file mode 100644 index 000000000..07a7f11eb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import 
org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class PruneUtilsTests extends OpenSearchTestCase { + @Test + public void testPruningByTopK() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("c")); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + } + + @Test + public void testPruningByMaxRatio() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 8.0f); + input.put("c", 5.0f); + input.put("d", 2.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.MAX_RATIO, 0.7f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 + assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 + } + + @Test + public void testPruningByValue() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 2.0f); + input.put("d", 1.0f); + + Map result = PruneUtils.pruningSparseVector(PruningType.ABS_VALUE, 3.0f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + @Test + public void testPruningByAlphaMass() { + Map input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 6.0f); + input.put("c", 3.0f); + input.put("d", 1.0f); + // Total sum = 20.0 + + Map result = PruneUtils.pruningSparseVector(PruningType.ALPHA_MASS, 0.8f, input); + + assertEquals(2, result.size()); + assertTrue(result.containsKey("a")); + assertTrue(result.containsKey("b")); + } + + @Test + public void testEmptyInput() { + Map input = new HashMap<>(); + + Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 5, 
input); + assertTrue(result.isEmpty()); + } + + @Test + public void testNegativeValues() { + Map input = new HashMap<>(); + input.put("a", -5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input) + ); + assertEquals(exception.getMessage(), "Pruned values must be positive"); + } + + @Test + public void testInvalidPruningType() { + Map input = new HashMap<>(); + input.put("a", 1.0f); + input.put("b", 2.0f); + + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); + + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(null, 2, input) + ); + assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); + } + + @Test + public void testIsValidPruneRatio() { + // Test TOP_K validation + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1.5f)); + + // Test ALPHA_MASS validation + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.1f)); + + // Test MAX_RATIO validation + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.0f)); + 
assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.1f)); + + // Test ABS_VALUE validation + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, -0.1f)); + + // Test with extreme cases + assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, Float.MIN_VALUE)); + } + + @Test + public void testIsValidPruneRatioWithNullType() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); + assertEquals("Pruning type cannot be null", exception.getMessage()); + } +} From 2cc0d103ae238cc7e91b6ea45a412faf5f638dd5 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 16:49:09 +0800 Subject: [PATCH 03/17] rename pruneType; UT Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 15 +- .../SparseEncodingProcessorFactory.java | 16 +- .../neuralsearch/util/TokenWeightUtil.java | 12 +- .../{PruningType.java => PruneType.java} | 12 +- .../neuralsearch/util/pruning/PruneUtils.java | 16 +- ...ncodingEmbeddingProcessorFactoryTests.java | 182 ++++++++++++++++++ .../util/TokenWeightUtilTests.java | 33 ++++ .../util/pruning/PruneTypeTests.java | 30 +++ .../util/pruning/PruneUtilsTests.java | 71 +++---- 9 files changed, 313 insertions(+), 74 deletions(-) rename 
src/main/java/org/opensearch/neuralsearch/util/pruning/{PruningType.java => PruneType.java} (77%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 61851c1d6..a3a6cacbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -9,12 +9,13 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import lombok.Getter; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.util.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; @@ -28,7 +29,9 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; - private final PruningType pruningType; + @Getter + private final PruneType pruneType; + @Getter private final float pruneRatio; public SparseEncodingProcessor( @@ -37,14 +40,14 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, - PruningType pruningType, + PruneType pruneType, float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, 
modelId, fieldMap, clientAccessor, environment, clusterService); - this.pruningType = pruningType; + this.pruneType = pruneType; this.pruneRatio = pruneRatio; } @@ -56,7 +59,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -68,7 +71,7 @@ public void doBatchExecute(List inferenceList, Consumer> handler this.modelId, inferenceList, ActionListener.wrap( - resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruningType, pruneRatio)), + resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio)), onException ) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 40a31392c..19cea9419 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -8,9 +8,9 @@ import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static 
org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; import java.util.Map; @@ -22,7 +22,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.util.pruning.PruneUtils; -import org.opensearch.neuralsearch.util.pruning.PruningType; +import org.opensearch.neuralsearch.util.pruning.PruneType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. @@ -44,15 +44,15 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - // if the field is miss, will return PruningType.None - PruningType pruningType = PruningType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); + // if the field is miss, will return PruneType.None + PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD)); float pruneRatio = 0; - if (pruningType != PruningType.NONE) { + if (pruneType != PruneType.NONE) { // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); - if (!PruneUtils.isValidPruneRatio(pruningType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruningType.name() + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() ); } else { // if we don't have prune type, then prune ratio 
field must not have value @@ -67,7 +67,7 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description, batchSize, modelId, fieldMap, - pruningType, + pruneType, pruneRatio, clientAccessor, environment, diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 0ee48fa33..0de3610f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,8 +4,8 @@ */ package org.opensearch.neuralsearch.util; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.neuralsearch.util.pruning.PruneUtils; -import org.opensearch.neuralsearch.util.pruning.PruningType; import java.util.ArrayList; import java.util.HashMap; @@ -49,7 +49,7 @@ public class TokenWeightUtil { */ public static List> fetchListOfTokenWeightMap( List> mapResultList, - PruningType pruningType, + PruneType pruneType, float pruneRatio ) { if (null == mapResultList || mapResultList.isEmpty()) { @@ -66,15 +66,15 @@ public static List> fetchListOfTokenWeightMap( results.addAll((List) map.get("response")); } return results.stream() - .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruningType, pruneRatio)) + .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruneType, pruneRatio)) .collect(Collectors.toList()); } public static List> fetchListOfTokenWeightMap(List> mapResultList) { - return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruningType.NONE, 0f); + return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruneType.NONE, 0f); } - private static Map buildTokenWeightMap(Object uncastedMap, PruningType pruningType, float pruneRatio) { + private static Map buildTokenWeightMap(Object uncastedMap, PruneType pruneType, float pruneRatio) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new 
IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -85,6 +85,6 @@ private static Map buildTokenWeightMap(Object uncastedMap, Prunin } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return PruneUtils.pruningSparseVector(pruningType, pruneRatio, result); + return PruneUtils.pruningSparseVector(pruneType, pruneRatio, result); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java similarity index 77% rename from src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java rename to src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java index 6629bb937..22376b7c5 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruningType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java @@ -9,7 +9,7 @@ /** * Enum representing different types of pruning methods for sparse vectors */ -public enum PruningType { +public enum PruneType { NONE("none"), TOP_K("top_k"), ALPHA_MASS("alpha_mass"), @@ -18,7 +18,7 @@ public enum PruningType { private final String value; - PruningType(String value) { + PruneType(String value) { this.value = value; } @@ -27,15 +27,15 @@ public String getValue() { } /** - * Get PruningType from string value + * Get PruneType from string value * * @param value string representation of pruning type - * @return corresponding PruningType enum + * @return corresponding PruneType enum * @throws IllegalArgumentException if value doesn't match any pruning type */ - public static PruningType fromString(String value) { + public static PruneType fromString(String value) { if (StringUtils.isEmpty(value)) return NONE; - for (PruningType type : PruningType.values()) { + for (PruneType type : PruneType.values()) { if (type.value.equals(value)) { return type; } diff --git 
a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java index d7d2234cf..87e87cffa 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -122,13 +122,13 @@ private static Map pruningByAlphaMass(Map sparseVe /** * Prunes a sparse vector using the specified pruning type and ratio. * - * @param pruningType The type of pruning strategy to use + * @param pruneType The type of pruning strategy to use * @param pruneRatio The ratio or threshold for pruning * @param sparseVector The input sparse vector as a map of string keys to float values * @return A new map containing the pruned sparse vector */ - public static Map pruningSparseVector(PruningType pruningType, float pruneRatio, Map sparseVector) { - if (Objects.isNull(pruningType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( + public static Map pruningSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( "Prune type and prune ratio must be provided" ); @@ -138,7 +138,7 @@ public static Map pruningSparseVector(PruningType pruningType, fl } } - switch (pruningType) { + switch (pruneType) { case TOP_K: return pruningByTopK(sparseVector, (int) pruneRatio); case ALPHA_MASS: @@ -155,17 +155,17 @@ public static Map pruningSparseVector(PruningType pruningType, fl /** * Validates whether a prune ratio is valid for a given pruning type. 
* - * @param pruningType The type of pruning strategy + * @param pruneType The type of pruning strategy * @param pruneRatio The ratio or threshold to validate * @return true if the ratio is valid for the given pruning type, false otherwise * @throws IllegalArgumentException if pruning type is null */ - public static boolean isValidPruneRatio(PruningType pruningType, float pruneRatio) { - if (pruningType == null) { + public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { + if (pruneType == null) { throw new IllegalArgumentException("Pruning type cannot be null"); } - switch (pruningType) { + switch (pruneType) { case TOP_K: return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); case ALPHA_MASS: diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..8b1fafe8b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_TYPE_FIELD; +import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_RATIO_FIELD; + +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; 
+import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class SparseEncodingEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String MODEL_ID = "testModelId"; + private static final int BATCH_SIZE = 1; + + private MLCommonsClientAccessor clientAccessor; + private Environment environment; + private ClusterService clusterService; + private SparseEncodingProcessorFactory sparseEncodingProcessorFactory; + + @Before + public void setup() { + clientAccessor = mock(MLCommonsClientAccessor.class); + environment = mock(Environment.class); + clusterService = mock(ClusterService.class); + sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(clientAccessor, environment, clusterService); + } + + @SneakyThrows + public void testCreateProcessor_whenAllRequiredParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.NONE, processor.getPruneType()); + assertEquals(0f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenPruneParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + 
config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 2f); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.TOP_K, processor.getPruneType()); + assertEquals(2f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenEmptyFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of()); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unable to create the processor as field_map has invalid key or value", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingModelIdField_thenFail() { + Map config = new HashMap<>(); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[model_id] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[field_map] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneType_thenFail() { 
+ Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "invalid_prune_type"); + config.put(PRUNE_RATIO_FIELD, 2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unknown pruning type: invalid_prune_type", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 0.2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneRatio_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "alpha_mass"); + + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[prune_ratio] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneType_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_RATIO_FIELD, 0.1); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, 
DESCRIPTION, config) + ); + assertEquals("prune_ratio field is not supported when prune_type is not provided", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java index 887d8fc17..234a70823 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.test.OpenSearchTestCase; public class TokenWeightUtilTests extends OpenSearchTestCase { @@ -104,4 +105,36 @@ public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_th List> inputData = List.of(Map.of("response", List.of(mockData))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } + + public void testFetchListOfTokenWeightMap_invokeWithPrune() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.MAX_RATIO, 0.8f), List.of(Map.of("world", 2f))); + } + + public void testFetchListOfTokenWeightMap_invokeWithPrune_MultipleObjectsInMultipleResponse() { + + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + },{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA))); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.TOP_K, 1f), + List.of(Map.of("world", 2f), Map.of("world", 2f)) + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java 
new file mode 100644 index 000000000..a1a823093 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.pruning; + +import org.opensearch.test.OpenSearchTestCase; + +public class PruneTypeTests extends OpenSearchTestCase { + public void testGetValue() { + assertEquals("none", PruneType.NONE.getValue()); + assertEquals("top_k", PruneType.TOP_K.getValue()); + assertEquals("alpha_mass", PruneType.ALPHA_MASS.getValue()); + assertEquals("max_ratio", PruneType.MAX_RATIO.getValue()); + assertEquals("abs_value", PruneType.ABS_VALUE.getValue()); + } + + public void testFromString() { + assertEquals(PruneType.NONE, PruneType.fromString("none")); + assertEquals(PruneType.NONE, PruneType.fromString(null)); + assertEquals(PruneType.NONE, PruneType.fromString("")); + assertEquals(PruneType.TOP_K, PruneType.fromString("top_k")); + assertEquals(PruneType.ALPHA_MASS, PruneType.fromString("alpha_mass")); + assertEquals(PruneType.MAX_RATIO, PruneType.fromString("max_ratio")); + assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value")); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value")); + assertEquals("Unknown pruning type: test_value", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java index 07a7f11eb..74aadf09f 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -4,14 +4,13 @@ */ package org.opensearch.neuralsearch.util.pruning; -import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; import java.util.HashMap; import java.util.Map; 
public class PruneUtilsTests extends OpenSearchTestCase { - @Test + public void testPruningByTopK() { Map input = new HashMap<>(); input.put("a", 5.0f); @@ -19,7 +18,7 @@ public void testPruningByTopK() { input.put("c", 4.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input); + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); @@ -28,7 +27,6 @@ public void testPruningByTopK() { assertEquals(4.0f, result.get("c"), 0.001); } - @Test public void testPruningByMaxRatio() { Map input = new HashMap<>(); input.put("a", 10.0f); @@ -36,14 +34,13 @@ public void testPruningByMaxRatio() { input.put("c", 5.0f); input.put("d", 2.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.MAX_RATIO, 0.7f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 } - @Test public void testPruningByValue() { Map input = new HashMap<>(); input.put("a", 5.0f); @@ -51,14 +48,13 @@ public void testPruningByValue() { input.put("c", 2.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruningType.ABS_VALUE, 3.0f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); assertTrue(result.containsKey("b")); } - @Test public void testPruningByAlphaMass() { Map input = new HashMap<>(); input.put("a", 10.0f); @@ -67,22 +63,20 @@ public void testPruningByAlphaMass() { input.put("d", 1.0f); // Total sum = 20.0 - Map result = PruneUtils.pruningSparseVector(PruningType.ALPHA_MASS, 0.8f, input); + Map result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input); assertEquals(2, result.size()); assertTrue(result.containsKey("a")); 
assertTrue(result.containsKey("b")); } - @Test public void testEmptyInput() { Map input = new HashMap<>(); - Map result = PruneUtils.pruningSparseVector(PruningType.TOP_K, 5, input); + Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input); assertTrue(result.isEmpty()); } - @Test public void testNegativeValues() { Map input = new HashMap<>(); input.put("a", -5.0f); @@ -91,12 +85,11 @@ public void testNegativeValues() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(PruningType.TOP_K, 2, input) + () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input) ); - assertEquals(exception.getMessage(), "Pruned values must be positive"); + assertEquals("Pruned values must be positive", exception.getMessage()); } - @Test public void testInvalidPruningType() { Map input = new HashMap<>(); input.put("a", 1.0f); @@ -115,43 +108,41 @@ public void testInvalidPruningType() { assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); } - @Test public void testIsValidPruneRatio() { // Test TOP_K validation - assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 100)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 0)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, -1)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.TOP_K, 1.5f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1.5f)); // Test ALPHA_MASS validation - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0.5f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.0f)); - 
assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 0)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, -0.1f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, 1.1f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.1f)); // Test MAX_RATIO validation - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 0.5f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, -0.1f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, 1.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.1f)); // Test ABS_VALUE validation - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 0.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 1.0f)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, 100.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, -0.1f)); // Test with extreme cases - 
assertTrue(PruneUtils.isValidPruneRatio(PruningType.TOP_K, Float.MAX_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ABS_VALUE, Float.MAX_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.ALPHA_MASS, Float.MIN_VALUE)); - assertTrue(PruneUtils.isValidPruneRatio(PruningType.MAX_RATIO, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, Float.MIN_VALUE)); } - @Test public void testIsValidPruneRatioWithNullType() { IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); assertEquals("Pruning type cannot be null", exception.getMessage()); From 26098cc425442b4c61c341f74f8a3b382b821023 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 16:53:10 +0800 Subject: [PATCH 04/17] changelog Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c127ef7d7..b76f1c39f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Enhancements ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) From 6af02b878c8a9efd1e4b68f5b9342f687bb5aff6 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Nov 2024 17:44:35 +0800 Subject: [PATCH 05/17] ut Signed-off-by: zhichao-aws --- 
.../SparseEncodingProcessorTests.java | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..d705616a9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -14,10 +14,12 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.Map; import java.util.ArrayList; import java.util.Collections; @@ -49,6 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.util.pruning.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) { return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put("prune_type", pruneType.getValue()); + config.put("prune_ratio", pruneRatio); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); 
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); @@ -260,9 +274,98 @@ public void test_batchExecute_exception() { } } + @SuppressWarnings("unchecked") + public void testExecute_withPruningConfig_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> dataAsMapList = Collections.singletonList( + Map.of("response", Arrays.asList(ImmutableMap.of("hello", 1.0f, "world", 0.1f), ImmutableMap.of("test", 0.8f, "low", 0.4f))) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + + ArgumentCaptor docCaptor = ArgumentCaptor.forClass(IngestDocument.class); + verify(handler).accept(docCaptor.capture(), isNull()); + + IngestDocument processedDoc = docCaptor.getValue(); + Map first = (Map) processedDoc.getFieldValue("key1Mapped", Map.class); + Map second = (Map) processedDoc.getFieldValue("key2Mapped", Map.class); + + assertNotNull(first); + assertNotNull(second); + + assertTrue(first.containsKey("hello")); + assertFalse(first.containsKey("world")); + assertEquals(1.0f, first.get("hello"), 0.001f); + + assertTrue(second.containsKey("test")); + assertTrue(second.containsKey("low")); + assertEquals(0.8f, second.get("test"), 0.001f); + assertEquals(0.4f, second.get("low"), 0.001f); + } + + public void test_batchExecute_withPruning_successful() { + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> mockMLResponse 
= Collections.singletonList( + Map.of( + "response", + Arrays.asList( + ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f), + ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f) + ) + ) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(mockMLResponse); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer> resultHandler = mock(Consumer.class); + Consumer exceptionHandler = mock(Consumer.class); + + List inferenceList = Arrays.asList("test1", "test2"); + processor.doBatchExecute(inferenceList, resultHandler, exceptionHandler); + + ArgumentCaptor>> resultCaptor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCaptor.capture()); + verify(exceptionHandler, never()).accept(any()); + + List> processedResults = resultCaptor.getValue(); + + assertEquals(2, processedResults.size()); + + Map firstResult = processedResults.get(0); + assertEquals(2, firstResult.size()); + assertTrue(firstResult.containsKey("token1")); + assertTrue(firstResult.containsKey("token3")); + assertFalse(firstResult.containsKey("token2")); + + Map secondResult = processedResults.get(1); + assertEquals(2, secondResult.size()); + assertTrue(secondResult.containsKey("token4")); + assertTrue(secondResult.containsKey("token6")); + assertFalse(secondResult.containsKey("token5")); + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult; From 2ac90de8da1f81a4d6dd29e2249bbb98357f7692 Mon Sep 17 00:00:00 2001 
From: zhichao-aws Date: Tue, 19 Nov 2024 15:58:35 +0800 Subject: [PATCH 06/17] add it Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessIT.java | 30 +++++++++++++++++++ ...ncodingPipelineConfigurationWithPrune.json | 21 +++++++++++++ .../UploadSparseEncodingModelRequestBody.json | 10 ++----- .../neuralsearch/BaseNeuralSearchIT.java | 7 +++-- 4 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 349da1033..83b680d19 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -18,6 +18,7 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class SparseEncodingProcessIT extends BaseNeuralSearchIT { @@ -39,6 +40,35 @@ public void testSparseEncodingProcessor() throws Exception { createSparseEncodingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("title_sparse"); + neuralSparseQueryBuilder.queryTokensSupplier(() -> Map.of("good", 1.0f, "a", 2.0f)); + Map searchResponse = search(INDEX_NAME, neuralSparseQueryBuilder, 2); + assertFalse(searchResponse.isEmpty()); + double maxScore = (Double) ((Map) searchResponse.get("hits")).get("max_score"); + assertEquals(4.4433594, maxScore, 1e-3); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + public void testSparseEncodingProcessorWithPrune() throws Exception { + String modelId = null; + try { + modelId = 
prepareSparseEncodingModel(); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.SPARSE_ENCODING_PRUNE); + createSparseEncodingIndex(); + ingestDocument(); + assertEquals(1, getDocCount(INDEX_NAME)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("title_sparse"); + neuralSparseQueryBuilder.queryTokensSupplier(() -> Map.of("good", 1.0f, "a", 2.0f)); + Map searchResponse = search(INDEX_NAME, neuralSparseQueryBuilder, 2); + assertFalse(searchResponse.isEmpty()); + double maxScore = (Double) ((Map) searchResponse.get("hits")).get("max_score"); + assertEquals(3.640625, maxScore, 1e-3); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } diff --git a/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json b/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json new file mode 100644 index 000000000..642228e06 --- /dev/null +++ b/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json @@ -0,0 +1,21 @@ +{ + "description": "An example sparse Encoding pipeline", + "processors" : [ + { + "sparse_encoding": { + "model_id": "%s", + "batch_size": "%d", + "prune_type": "max_ratio", + "prune_ratio": 0.8, + "field_map": { + "title": "title_sparse", + "favor_list": "favor_list_sparse", + "favorites": { + "game": "game_sparse", + "movie": "movie_sparse" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json index 5c2a73f6e..6bdac87c5 100644 --- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json @@ -1,10 +1,6 @@ { - "name": "tokenize-idf-0915", - "version": "1.0.0", - "function_name": "SPARSE_TOKENIZE", - "description": "test model", - "model_format": "TORCH_SCRIPT", + "name": 
"amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1", + "version": "1.0.1", "model_group_id": "%s", - "model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a", - "url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip" + "model_format": "TORCH_SCRIPT" } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index afc545447..4bfb9d8c0 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -85,7 +85,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ProcessorType.TEXT_IMAGE_EMBEDDING, "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json", ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, - "processor/PipelineConfigurationWithNestedFieldsMapping.json" + "processor/PipelineConfigurationWithNestedFieldsMapping.json", + ProcessorType.SPARSE_ENCODING_PRUNE, + "processor/SparseEncodingPipelineConfigurationWithPrune.json" ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; @@ -1439,6 +1441,7 @@ protected enum ProcessorType { TEXT_EMBEDDING, TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, TEXT_IMAGE_EMBEDDING, - SPARSE_ENCODING + SPARSE_ENCODING, + SPARSE_ENCODING_PRUNE } } From 97963f102142d53be52b0693df3b5a4f42d41919 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Nov 2024 11:48:27 +0800 Subject: [PATCH 07/17] change on 2-phase Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 86 +++++------- .../processor/SparseEncodingProcessor.java | 21 +-- 
.../query/NeuralSparseQueryBuilder.java | 22 +-- .../neuralsearch/util/TokenWeightUtil.java | 21 +-- .../neuralsearch/util/pruning/PruneUtils.java | 126 ++++++++++++------ .../NeuralSparseTwoPhaseProcessorTests.java | 31 +---- .../util/TokenWeightUtilTests.java | 33 ----- .../util/pruning/PruneUtilsTests.java | 119 +++++++++++++---- 8 files changed, 244 insertions(+), 215 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 8d386e615..1214057cb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -14,6 +14,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; @@ -37,41 +39,37 @@ public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements public static final String TYPE = "neural_sparse_two_phase_processor"; private boolean enabled; - private float ratio; + private float pruneRatio; + private PruneType pruneType; private float windowExpansion; private int maxWindowSize; private static final String PARAMETER_KEY = "two_phase_parameter"; - private static final String RATIO_KEY = "prune_ratio"; private static final String ENABLE_KEY = "enabled"; private static final String EXPANSION_KEY = "expansion_rate"; private static final String MAX_WINDOW_SIZE_KEY = "max_window_size"; private static final boolean DEFAULT_ENABLED = true; private static final float 
DEFAULT_RATIO = 0.4f; + private static final PruneType DEFAULT_PRUNE_TYPE = PruneType.MAX_RATIO; private static final float DEFAULT_WINDOW_EXPANSION = 5.0f; private static final int DEFAULT_MAX_WINDOW_SIZE = 10000; private static final int DEFAULT_BASE_QUERY_SIZE = 10; private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50; private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f; - private static final float RATIO_LOWER_BOUND = 0f; - private static final float RATIO_UPPER_BOUND = 1f; protected NeuralSparseTwoPhaseProcessor( String tag, String description, boolean ignoreFailure, boolean enabled, - float ratio, + float pruneRatio, + PruneType pruneType, float windowExpansion, int maxWindowSize ) { super(tag, description, ignoreFailure); this.enabled = enabled; - if (ratio < RATIO_LOWER_BOUND || ratio > RATIO_UPPER_BOUND) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio) - ); - } - this.ratio = ratio; + this.pruneRatio = pruneRatio; + this.pruneType = pruneType; if (windowExpansion < WINDOW_EXPANSION_LOWER_BOUND) { throw new IllegalArgumentException( String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f", windowExpansion) @@ -93,7 +91,7 @@ protected NeuralSparseTwoPhaseProcessor( */ @Override public SearchRequest processRequest(final SearchRequest request) { - if (!enabled || ratio == 0f) { + if (!enabled || pruneRatio == 0f) { return request; } QueryBuilder queryBuilder = request.source().query(); @@ -117,43 +115,6 @@ public String getType() { return TYPE; } - /** - * Based on ratio, split a Map into two map by the value. - * - * @param queryTokens the queryTokens map, key is the token String, value is the score. - * @param thresholdRatio The ratio that control how tokens map be split. 
- * @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold } - */ - public static Tuple, Map> splitQueryTokensByRatioedMaxScoreAsThreshold( - final Map queryTokens, - final float thresholdRatio - ) { - if (Objects.isNull(queryTokens)) { - throw new IllegalArgumentException("Query tokens cannot be null or empty."); - } - float max = 0f; - for (Float value : queryTokens.values()) { - max = Math.max(value, max); - } - float threshold = max * thresholdRatio; - - Map> queryTokensByScore = queryTokens.entrySet() - .stream() - .collect( - Collectors.partitioningBy(entry -> entry.getValue() >= threshold, Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) - ); - - Map highScoreTokens = queryTokensByScore.get(Boolean.TRUE); - Map lowScoreTokens = queryTokensByScore.get(Boolean.FALSE); - if (Objects.isNull(highScoreTokens)) { - highScoreTokens = Collections.emptyMap(); - } - if (Objects.isNull(lowScoreTokens)) { - lowScoreTokens = Collections.emptyMap(); - } - return Tuple.tuple(highScoreTokens, lowScoreTokens); - } - private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap( final Multimap queryBuilderFloatMap ) { @@ -201,7 +162,10 @@ private Multimap collectNeuralSparseQueryBuilde * - Docs besides TopDocs: Score = HighScoreToken's score * - Final TopDocs: Score = HighScoreToken's score + LowScoreToken's score */ - NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio); + NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase( + pruneRatio, + pruneType + ); result.put(modifiedQueryBuilder, updatedBoost); } // We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. 
For other compound query type which are not support now, will @@ -248,16 +212,32 @@ public NeuralSparseTwoPhaseProcessor create( boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED); Map twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY); - float ratio = DEFAULT_RATIO; + float pruneRatio = DEFAULT_RATIO; float windowExpansion = DEFAULT_WINDOW_EXPANSION; int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE; + PruneType pruneType = DEFAULT_PRUNE_TYPE; if (Objects.nonNull(twoPhaseConfigMap)) { - ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue(); + pruneRatio = ((Number) twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_RATIO_FIELD, pruneRatio)).floatValue(); windowExpansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, windowExpansion)).floatValue(); maxWindowSize = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, maxWindowSize)).intValue(); + pruneType = PruneType.fromString( + twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString() + ); } + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( + "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() + ); - return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize); + return new NeuralSparseTwoPhaseProcessor( + tag, + description, + ignoreFailure, + enabled, + pruneRatio, + pruneType, + windowExpansion, + maxWindowSize + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index a3a6cacbb..35a8f25f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -19,6 +19,7 
@@ import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; /** * This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use, @@ -59,7 +60,10 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); + sparseVectors = sparseVectors.stream() + .map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1()) + .toList(); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -67,13 +71,12 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - mlCommonsClientAccessor.inferenceSentencesWithMapResult( - this.modelId, - inferenceList, - ActionListener.wrap( - resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps, pruneType, pruneRatio)), - onException - ) - ); + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); + sparseVectors = sparseVectors.stream() + .map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1()) + .toList(); + handler.accept(sparseVectors); + }, onException)); } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index f46997d5e..ba86b8872 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -47,8 +47,8 @@ import lombok.NoArgsConstructor; import lombok.Setter; import lombok.experimental.Accessors; - -import static org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; /** * SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model @@ -90,6 +90,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder tokens = queryTokensSupplier.get(); // Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1, // while those less than or equal to the threshold are stored in v2. - Tuple, Map> splitTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(tokens, ratio); + Tuple, Map> splitTokens = PruneUtils.pruningSparseVector(pruneType, pruneRatio, tokens, true); this.queryTokensSupplier(() -> splitTokens.v1()); copy.queryTokensSupplier(() -> splitTokens.v2()); } else { @@ -346,9 +348,11 @@ private BiConsumer> getModelInferenceAsync(SetOnce { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { - Tuple, Map> splitQueryTokens = splitQueryTokensByRatioedMaxScoreAsThreshold( + Tuple, Map> splitQueryTokens = PruneUtils.pruningSparseVector( + twoPhasePruneType, + twoPhasePruneRatio, queryTokens, - twoPhasePruneRatio + true ); setOnce.set(splitQueryTokens.v1()); twoPhaseSharedQueryToken = splitQueryTokens.v2(); diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 0de3610f8..e36b42cd6 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -4,9 +4,6 @@ */ package org.opensearch.neuralsearch.util; -import org.opensearch.neuralsearch.util.pruning.PruneType; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -47,11 +44,7 @@ public class TokenWeightUtil { * * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ - public static List> fetchListOfTokenWeightMap( - List> mapResultList, - PruneType pruneType, - float pruneRatio - ) { + public static List> fetchListOfTokenWeightMap(List> mapResultList) { if (null == mapResultList || mapResultList.isEmpty()) { throw new IllegalArgumentException("The inference result can not be null or empty."); } @@ -65,16 +58,10 @@ public static List> fetchListOfTokenWeightMap( } results.addAll((List) map.get("response")); } - return results.stream() - .map(uncastedMap -> TokenWeightUtil.buildTokenWeightMap(uncastedMap, pruneType, pruneRatio)) - .collect(Collectors.toList()); - } - - public static List> fetchListOfTokenWeightMap(List> mapResultList) { - return TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList, PruneType.NONE, 0f); + return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList()); } - private static Map buildTokenWeightMap(Object uncastedMap, PruneType pruneType, float pruneRatio) { + private static Map buildTokenWeightMap(Object uncastedMap) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } @@ -85,6 +72,6 @@ private static Map buildTokenWeightMap(Object uncastedMap, PruneT } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } - return 
PruneUtils.pruningSparseVector(pruneType, pruneRatio, result); + return result; } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java index 87e87cffa..ed7ac7f03 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java @@ -4,6 +4,8 @@ */ package org.opensearch.neuralsearch.util.pruning; +import org.opensearch.common.collect.Tuple; + import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -26,9 +28,14 @@ public class PruneUtils { * * @param sparseVector The input sparse vector as a map of string keys to float values * @param k The number of top elements to keep - * @return A new map containing only the top K elements + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with top K elements, the second with remaining elements (or null) */ - private static Map pruningByTopK(Map sparseVector, int k) { + private static Tuple, Map> pruningByTopK( + Map sparseVector, + int k, + boolean requiresPrunedEntries + ) { PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); for (Map.Entry entry : sparseVector.entrySet()) { @@ -40,13 +47,18 @@ private static Map pruningByTopK(Map sparseVector, } } - Map result = new HashMap<>(); + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? 
new HashMap<>(sparseVector) : null; + while (!pq.isEmpty()) { Map.Entry entry = pq.poll(); - result.put(entry.getKey(), entry.getValue()); + highScores.put(entry.getKey(), entry.getValue()); + if (requiresPrunedEntries) { + lowScores.remove(entry.getKey()); + } } - return result; + return new Tuple<>(highScores, lowScores); } /** @@ -55,21 +67,29 @@ private static Map pruningByTopK(Map sparseVector, * * @param sparseVector The input sparse vector as a map of string keys to float values * @param ratio The minimum ratio relative to the maximum value for elements to be kept - * @return A new map containing only elements meeting the ratio threshold + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements meeting the ratio threshold, + * the second with elements below the threshold (or null) */ - private static Map pruningByMaxRatio(Map sparseVector, float ratio) { + private static Tuple, Map> pruningByMaxRatio( + Map sparseVector, + float ratio, + boolean requiresPrunedEntries + ) { float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); - Map result = new HashMap<>(); - for (Map.Entry entry : sparseVector.entrySet()) { - float currentRatio = entry.getValue() / maxValue; + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? 
new HashMap<>() : null; - if (currentRatio >= ratio) { - result.put(entry.getKey(), entry.getValue()); + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() >= ratio * maxValue) { + highScores.put(entry.getKey(), entry.getValue()); + } else if (requiresPrunedEntries) { + lowScores.put(entry.getKey(), entry.getValue()); } } - return result; + return new Tuple<>(highScores, lowScores); } /** @@ -77,17 +97,27 @@ private static Map pruningByMaxRatio(Map sparseVec * * @param sparseVector The input sparse vector as a map of string keys to float values * @param thresh The minimum absolute value for elements to be kept - * @return A new map containing only elements meeting the threshold + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements above the threshold, + * the second with elements below the threshold (or null) */ - private static Map pruningByValue(Map sparseVector, float thresh) { - Map result = new HashMap<>(sparseVector); + private static Tuple, Map> pruningByValue( + Map sparseVector, + float thresh, + boolean requiresPrunedEntries + ) { + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? 
new HashMap<>() : null; + for (Map.Entry entry : sparseVector.entrySet()) { - if (entry.getValue() < thresh) { - result.remove(entry.getKey()); + if (entry.getValue() >= thresh) { + highScores.put(entry.getKey(), entry.getValue()); + } else if (requiresPrunedEntries) { + lowScores.put(entry.getKey(), entry.getValue()); } } - return result; + return new Tuple<>(highScores, lowScores); } /** @@ -96,27 +126,36 @@ private static Map pruningByValue(Map sparseVector * * @param sparseVector The input sparse vector as a map of string keys to float values * @param alpha The minimum ratio relative to the total sum for elements to be kept - * @return A new map containing only elements meeting the ratio threshold + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements meeting the alpha mass threshold, + * the second with remaining elements (or null) */ - private static Map pruningByAlphaMass(Map sparseVector, float alpha) { + private static Tuple, Map> pruningByAlphaMass( + Map sparseVector, + float alpha, + boolean requiresPrunedEntries + ) { List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); float topSum = 0f; - Map result = new HashMap<>(); + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? 
new HashMap<>() : null; + for (Map.Entry entry : sortedEntries) { float value = entry.getValue(); topSum += value; - result.put(entry.getKey(), value); - if (topSum / sum >= alpha) { - break; + if (topSum <= alpha * sum) { + highScores.put(entry.getKey(), value); + } else if (requiresPrunedEntries) { + lowScores.put(entry.getKey(), value); } } - return result; + return new Tuple<>(highScores, lowScores); } /** @@ -125,12 +164,23 @@ private static Map pruningByAlphaMass(Map sparseVe * @param pruneType The type of pruning strategy to use * @param pruneRatio The ratio or threshold for pruning * @param sparseVector The input sparse vector as a map of string keys to float values - * @return A new map containing the pruned sparse vector + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with high-scoring elements, + * the second with low-scoring elements (or null if requiresPrunedEntries is false) */ - public static Map pruningSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { - if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) throw new IllegalArgumentException( - "Prune type and prune ratio must be provided" - ); + public static Tuple, Map> pruningSparseVector( + PruneType pruneType, + float pruneRatio, + Map sparseVector, + boolean requiresPrunedEntries + ) { + if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { + throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + } + + if (Objects.isNull(sparseVector)) { + throw new IllegalArgumentException("Sparse vector must be provided"); + } for (Map.Entry entry : sparseVector.entrySet()) { if (entry.getValue() <= 0) { @@ -140,15 +190,15 @@ public static Map pruningSparseVector(PruneType pruneType, float switch (pruneType) { case TOP_K: - return pruningByTopK(sparseVector, (int) pruneRatio); + return pruningByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries); case ALPHA_MASS: - 
return pruningByAlphaMass(sparseVector, pruneRatio); + return pruningByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries); case MAX_RATIO: - return pruningByMaxRatio(sparseVector, pruneRatio); + return pruningByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries); case ABS_VALUE: - return pruningByValue(sparseVector, pruneRatio); + return pruningByValue(sparseVector, pruneRatio, requiresPrunedEntries); default: - return sparseVector; + return new Tuple<>(new HashMap<>(sparseVector), requiresPrunedEntries ? new HashMap<>() : null); } } @@ -170,9 +220,9 @@ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); case ALPHA_MASS: case MAX_RATIO: - return pruneRatio > 0 && pruneRatio < 1; + return pruneRatio >= 0 && pruneRatio < 1; case ABS_VALUE: - return pruneRatio > 0; + return pruneRatio >= 0; default: return true; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java index 2ce7c7b96..655728c49 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java @@ -6,7 +6,6 @@ import lombok.SneakyThrows; import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.collect.Tuple; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; @@ -28,7 +27,7 @@ public class NeuralSparseTwoPhaseProcessorTests extends OpenSearchTestCase { public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseTwoPhaseProcessor processor = 
createTestProcessor(factory); - assertEquals(0.3f, processor.getRatio(), 1e-3); + assertEquals(0.3f, processor.getPruneRatio(), 1e-3); assertEquals(4.0f, processor.getWindowExpansion(), 1e-3); assertEquals(10000, processor.getMaxWindowSize()); @@ -40,7 +39,7 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception Collections.emptyMap(), null ); - assertEquals(0.4f, defaultProcessor.getRatio(), 1e-3); + assertEquals(0.4f, defaultProcessor.getPruneRatio(), 1e-3); assertEquals(5.0f, defaultProcessor.getWindowExpansion(), 1e-3); assertEquals(10000, defaultProcessor.getMaxWindowSize()); } @@ -140,32 +139,6 @@ public void testProcessRequest_whenTwoPhaseEnabledAndWithOutNeuralSparseQuery_th assertNull(returnRequest.source().rescores()); } - @SneakyThrows - public void testGetSplitSetOnceByScoreThreshold() { - Map queryTokens = new HashMap<>(); - for (int i = 0; i < 10; i++) { - queryTokens.put(String.valueOf(i), (float) i); - } - Tuple, Map> splitQueryTokens = NeuralSparseTwoPhaseProcessor - .splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f); - assertNotNull(splitQueryTokens); - Map upSet = splitQueryTokens.v1(); - Map downSet = splitQueryTokens.v2(); - assertNotNull(upSet); - assertEquals(6, upSet.size()); - assertNotNull(downSet); - assertEquals(4, downSet.size()); - } - - @SneakyThrows - public void testGetSplitSetOnceByScoreThreshold_whenNullQueryToken_thenThrowException() { - Map queryTokens = null; - expectThrows( - IllegalArgumentException.class, - () -> NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f) - ); - } - public void testType() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory); diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java index 
234a70823..887d8fc17 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -7,7 +7,6 @@ import java.util.List; import java.util.Map; -import org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.test.OpenSearchTestCase; public class TokenWeightUtilTests extends OpenSearchTestCase { @@ -105,36 +104,4 @@ public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_th List> inputData = List.of(Map.of("response", List.of(mockData))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } - - public void testFetchListOfTokenWeightMap_invokeWithPrune() { - /* - [{ - "response": [ - {"hello": 1.0, "world": 2.0} - ] - }] - */ - List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); - assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.MAX_RATIO, 0.8f), List.of(Map.of("world", 2f))); - } - - public void testFetchListOfTokenWeightMap_invokeWithPrune_MultipleObjectsInMultipleResponse() { - - /* - [{ - "response": [ - {"hello": 1.0, "world": 2.0} - ] - },{ - "response": [ - {"hello": 1.0, "world": 2.0} - ] - }] - */ - List> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA))); - assertEquals( - TokenWeightUtil.fetchListOfTokenWeightMap(inputData, PruneType.TOP_K, 1f), - List.of(Map.of("world", 2f), Map.of("world", 2f)) - ); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java index 74aadf09f..8dc31711a 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.util.pruning; +import 
org.opensearch.common.collect.Tuple; import org.opensearch.test.OpenSearchTestCase; import java.util.HashMap; @@ -18,13 +19,23 @@ public void testPruningByTopK() { input.put("c", 4.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input); + // Test without pruned entries + Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, false); - assertEquals(2, result.size()); - assertTrue(result.containsKey("a")); - assertTrue(result.containsKey("c")); - assertEquals(5.0f, result.get("a"), 0.001); - assertEquals(4.0f, result.get("c"), 0.001); + assertEquals(2, result.v1().size()); + assertNull(result.v2()); + assertTrue(result.v1().containsKey("a")); + assertTrue(result.v1().containsKey("c")); + assertEquals(5.0f, result.v1().get("a"), 0.001); + assertEquals(4.0f, result.v1().get("c"), 0.001); + + // Test with pruned entries + result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, true); + + assertEquals(2, result.v1().size()); + assertEquals(2, result.v2().size()); + assertTrue(result.v2().containsKey("b")); + assertTrue(result.v2().containsKey("d")); } public void testPruningByMaxRatio() { @@ -34,11 +45,21 @@ public void testPruningByMaxRatio() { input.put("c", 5.0f); input.put("d", 2.0f); - Map result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input); + // Test without pruned entries + Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input, false); + + assertEquals(2, result.v1().size()); + assertNull(result.v2()); + assertTrue(result.v1().containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 + assertTrue(result.v1().containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 - assertEquals(2, result.size()); - assertTrue(result.containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 - assertTrue(result.containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 + // Test with pruned entries + result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input, true); + + assertEquals(2, 
result.v1().size()); + assertEquals(2, result.v2().size()); + assertTrue(result.v2().containsKey("c")); + assertTrue(result.v2().containsKey("d")); } public void testPruningByValue() { @@ -48,11 +69,21 @@ public void testPruningByValue() { input.put("c", 2.0f); input.put("d", 1.0f); - Map result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input); + // Test without pruned entries + Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input, false); + + assertEquals(2, result.v1().size()); + assertNull(result.v2()); + assertTrue(result.v1().containsKey("a")); + assertTrue(result.v1().containsKey("b")); - assertEquals(2, result.size()); - assertTrue(result.containsKey("a")); - assertTrue(result.containsKey("b")); + // Test with pruned entries + result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input, true); + + assertEquals(2, result.v1().size()); + assertEquals(2, result.v2().size()); + assertTrue(result.v2().containsKey("c")); + assertTrue(result.v2().containsKey("d")); } public void testPruningByAlphaMass() { @@ -61,20 +92,46 @@ public void testPruningByAlphaMass() { input.put("b", 6.0f); input.put("c", 3.0f); input.put("d", 1.0f); - // Total sum = 20.0 - Map result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input); + // Test without pruned entries + Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input, false); + + assertEquals(2, result.v1().size()); + assertNull(result.v2()); + assertTrue(result.v1().containsKey("a")); + assertTrue(result.v1().containsKey("b")); - assertEquals(2, result.size()); - assertTrue(result.containsKey("a")); - assertTrue(result.containsKey("b")); + // Test with pruned entries + result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input, true); + + assertEquals(2, result.v1().size()); + assertEquals(2, result.v2().size()); + assertTrue(result.v2().containsKey("c")); + assertTrue(result.v2().containsKey("d")); } 
public void testEmptyInput() { Map input = new HashMap<>(); - Map result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input); - assertTrue(result.isEmpty()); + Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input, false); + assertTrue(result.v1().isEmpty()); + assertNull(result.v2()); + + result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.5f, input, false); + assertTrue(result.v1().isEmpty()); + assertNull(result.v2()); + + result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.5f, input, false); + assertTrue(result.v1().isEmpty()); + assertNull(result.v2()); + + result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 0.5f, input, false); + assertTrue(result.v1().isEmpty()); + assertNull(result.v2()); + + result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input, true); + assertTrue(result.v1().isEmpty()); + assertTrue(result.v2().isEmpty()); } public void testNegativeValues() { @@ -85,7 +142,7 @@ public void testNegativeValues() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input) + () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, false) ); assertEquals("Pruned values must be positive", exception.getMessage()); } @@ -97,17 +154,25 @@ public void testInvalidPruningType() { IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(null, 2, input) + () -> PruneUtils.pruningSparseVector(null, 2, input, false) ); assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(null, 2, input) + () -> PruneUtils.pruningSparseVector(null, 2, input, true) ); assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); } + public void testNullSparseVector() { + 
IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, null, false) + ); + assertEquals(exception.getMessage(), "Sparse vector must be provided"); + } + public void testIsValidPruneRatio() { // Test TOP_K validation assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1)); @@ -119,19 +184,19 @@ public void testIsValidPruneRatio() { // Test ALPHA_MASS validation assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0.5f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.0f)); - assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, -0.1f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.1f)); // Test MAX_RATIO validation - assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.5f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.0f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, -0.1f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.1f)); // Test ABS_VALUE validation - assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 1.0f)); assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 100.0f)); assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, -0.1f)); From c5dd6021f52fa2b315fa1e27f4a7cb4ec285e60e Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Nov 2024 13:11:42 +0800 Subject: [PATCH 08/17] UT Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 3 -- .../NeuralSparseTwoPhaseProcessorTests.java | 52 
+++++++++++++++++-- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 1214057cb..342889db5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -9,7 +9,6 @@ import lombok.Getter; import lombok.Setter; import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.collect.Tuple; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; @@ -23,11 +22,9 @@ import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.search.rescore.RescorerBuilder; -import java.util.Collections; import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; /** * A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java index 655728c49..24257127f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java @@ -9,6 +9,8 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.pruning.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.rescore.QueryRescorerBuilder; 
import org.opensearch.test.OpenSearchTestCase; @@ -19,7 +21,6 @@ public class NeuralSparseTwoPhaseProcessorTests extends OpenSearchTestCase { static final private String PARAMETER_KEY = "two_phase_parameter"; - static final private String RATIO_KEY = "prune_ratio"; static final private String ENABLE_KEY = "enabled"; static final private String EXPANSION_KEY = "expansion_rate"; static final private String MAX_WINDOW_SIZE_KEY = "max_window_size"; @@ -30,6 +31,7 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception assertEquals(0.3f, processor.getPruneRatio(), 1e-3); assertEquals(4.0f, processor.getWindowExpansion(), 1e-3); assertEquals(10000, processor.getMaxWindowSize()); + assertEquals(PruneType.MAX_RATIO, processor.getPruneType()); NeuralSparseTwoPhaseProcessor defaultProcessor = factory.create( Collections.emptyMap(), @@ -42,11 +44,23 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception assertEquals(0.4f, defaultProcessor.getPruneRatio(), 1e-3); assertEquals(5.0f, defaultProcessor.getWindowExpansion(), 1e-3); assertEquals(10000, defaultProcessor.getMaxWindowSize()); + assertEquals(PruneType.MAX_RATIO, processor.getPruneType()); + } + + public void testFactory_whenCreatePipelineWithCustomPruneType_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 5f, "top_k", true, 5f, 1000); + assertEquals(5f, processor.getPruneRatio(), 1e-6); + assertEquals(PruneType.TOP_K, processor.getPruneType()); } public void testFactory_whenRatioOutOfRange_thenThrowException() { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "max_ratio", true, 
5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 0f, "top_k", true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "alpha_mass", true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, -1f, "abs_value", true, 5.0f, 10000)); } public void testFactory_whenWindowExpansionOutOfRange_thenThrowException() { @@ -72,6 +86,19 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio assertNotNull(searchRequest.source().rescores()); } + public void testProcessRequest_whenUseCustomPruneType_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, "alpha_mass", true, 4.0f, 10000); + processor.processRequest(searchRequest); + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query(); + assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3); + assertEquals(queryBuilder.twoPhasePruneType(), PruneType.ALPHA_MASS); + assertNotNull(searchRequest.source().rescores()); + } + public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); @@ -155,9 +182,28 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor( Map configMap = new HashMap<>(); configMap.put(ENABLE_KEY, enabled); Map twoPhaseParaMap = new HashMap<>(); - twoPhaseParaMap.put(RATIO_KEY, ratio); + 
twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio); + twoPhaseParaMap.put(EXPANSION_KEY, expand); + twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window); + configMap.put(PARAMETER_KEY, twoPhaseParaMap); + return factory.create(Collections.emptyMap(), null, null, false, configMap, null); + } + + private NeuralSparseTwoPhaseProcessor createTestProcessor( + NeuralSparseTwoPhaseProcessor.Factory factory, + float ratio, + String type, + boolean enabled, + float expand, + int max_window + ) throws Exception { + Map configMap = new HashMap<>(); + configMap.put(ENABLE_KEY, enabled); + Map twoPhaseParaMap = new HashMap<>(); + twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio); twoPhaseParaMap.put(EXPANSION_KEY, expand); twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window); + twoPhaseParaMap.put(PruneUtils.PRUNE_TYPE_FIELD, type); configMap.put(PARAMETER_KEY, twoPhaseParaMap); return factory.create(Collections.emptyMap(), null, null, false, configMap, null); } @@ -166,7 +212,7 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor(NeuralSparseTwoPhasePr Map configMap = new HashMap<>(); configMap.put(ENABLE_KEY, true); Map twoPhaseParaMap = new HashMap<>(); - twoPhaseParaMap.put(RATIO_KEY, 0.3f); + twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, 0.3f); twoPhaseParaMap.put(EXPANSION_KEY, 4.0f); twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, 10000); configMap.put(PARAMETER_KEY, twoPhaseParaMap); From c7f0031ffe28081b3e44a8f48ea71d13de28abc8 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Nov 2024 13:37:50 +0800 Subject: [PATCH 09/17] it Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessorIT.java | 30 ++++---------- .../query/NeuralSparseQueryBuilderTests.java | 39 +++++++++++++++++++ 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java index 
3e4ed8844..5f921809a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java @@ -47,15 +47,8 @@ public class NeuralSparseTwoPhaseProcessorIT extends BaseNeuralSearchIT { private final Map testRankFeaturesDoc = createRandomTokenWeightMap(TEST_TOKENS); private static final List TWO_PHASE_TEST_TOKEN = List.of("hello", "world"); - private static final Map testFixedQueryTokens = new HashMap<>(); + private static final Map testFixedQueryTokens = Map.of("hello", 5.0f, "world", 4.0f, "a", 3.0f, "b", 2.0f, "c", 1.0f); private static final Supplier> testFixedQueryTokenSupplier = () -> testFixedQueryTokens; - static { - testFixedQueryTokens.put("hello", 5.0f); - testFixedQueryTokens.put("world", 4.0f); - testFixedQueryTokens.put("a", 3.0f); - testFixedQueryTokens.put("b", 2.0f); - testFixedQueryTokens.put("c", 1.0f); - } @Before public void setUp() throws Exception { @@ -82,7 +75,6 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries_whenTwoPhaseEnabl NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) .queryTokensSupplier(randomTokenWeightSupplier); NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(randomTokenWeightSupplier); boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); @@ -116,7 +108,7 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "TEST_QUERY_TEXT", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -127,13 +119,12 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * } */ @SneakyThrows - public void 
testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled_thenGetExpectedScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -148,14 +139,13 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScor } @SneakyThrows - public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -164,7 +154,6 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS float scoreWithoutTwoPhase = objectToFloat(firstInnerHit.get("_score")); sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -190,7 +179,7 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS * { * "neural_sparse": { * "field": 
"test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -209,7 +198,6 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); QueryBuilder queryBuilder = new MatchAllQueryBuilder(); @@ -232,7 +220,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -240,7 +228,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -316,7 +304,6 @@ public void testMultiNeuralSparseQuery_whenTwoPhaseAndFilter_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); boolQueryBuilder.should(sparseEncodingQueryBuilder); @@ -401,7 +388,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInConstantScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new 
NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(1.0f); ConstantScoreQueryBuilder constantScoreQueryBuilder = new ConstantScoreQueryBuilder(sparseEncodingQueryBuilder); @@ -421,7 +407,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInDisjunctionMaxQuery_the createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); DisMaxQueryBuilder disMaxQueryBuilder = new DisMaxQueryBuilder(); @@ -444,7 +429,6 @@ public void testNeuralSparseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(sparseEncodingQueryBuilder); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 7509efd42..c4d50ad55 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -52,6 +52,7 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; +import 
org.opensearch.neuralsearch.util.pruning.PruneType; import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -649,6 +650,44 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get()); } + @SneakyThrows + public void testRewrite_whenQueryTokensSupplierNull_andPruneSet_thenSuceessPrune() { + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .twoPhaseSharedQueryToken(Map.of()) + .twoPhasePruneRatio(3.0f) + .twoPhasePruneType(PruneType.ABS_VALUE); + Map expectedMap = Map.of("1", 1f, "2", 5f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.queryTokensSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertEquals(Map.of("2", 5f), queryBuilder.queryTokensSupplier().get()); + assertEquals(Map.of("1", 1f), queryBuilder.twoPhaseSharedQueryToken()); + } + 
@SneakyThrows public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) From 0fd25979babc44fef51b2de0678afcd46ae928ab Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Nov 2024 15:29:17 +0800 Subject: [PATCH 10/17] rename Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 4 +- .../processor/SparseEncodingProcessor.java | 8 ++-- .../SparseEncodingProcessorFactory.java | 4 +- .../query/NeuralSparseQueryBuilder.java | 8 ++-- .../util/{pruning => prune}/PruneType.java | 10 ++-- .../util/{pruning => prune}/PruneUtils.java | 40 ++++++++-------- .../NeuralSparseTwoPhaseProcessorTests.java | 4 +- .../SparseEncodingProcessorTests.java | 6 +-- ...ncodingEmbeddingProcessorFactoryTests.java | 8 ++-- .../query/NeuralSparseQueryBuilderTests.java | 2 +- .../{pruning => prune}/PruneTypeTests.java | 4 +- .../{pruning => prune}/PruneUtilsTests.java | 48 +++++++++---------- 12 files changed, 73 insertions(+), 73 deletions(-) rename src/main/java/org/opensearch/neuralsearch/util/{pruning => prune}/PruneType.java (76%) rename src/main/java/org/opensearch/neuralsearch/util/{pruning => prune}/PruneUtils.java (87%) rename src/test/java/org/opensearch/neuralsearch/util/{pruning => prune}/PruneTypeTests.java (90%) rename src/test/java/org/opensearch/neuralsearch/util/{pruning => prune}/PruneUtilsTests.java (80%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 342889db5..60c52836e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -13,8 +13,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import 
org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; -import org.opensearch.neuralsearch.util.pruning.PruneType; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 35a8f25f8..d383cfea2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -15,11 +15,11 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneUtils; /** * This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use, @@ -62,7 +62,7 @@ public void doExecute( mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); sparseVectors = sparseVectors.stream() - .map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1()) + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1()) .toList(); setVectorFieldsToDocument(ingestDocument, 
ProcessMap, sparseVectors); handler.accept(ingestDocument, null); @@ -74,7 +74,7 @@ public void doBatchExecute(List inferenceList, Consumer> handler mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); sparseVectors = sparseVectors.stream() - .map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1()) + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1()) .toList(); handler.accept(sparseVectors); }, onException)); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 19cea9419..6a9854e78 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -21,8 +21,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; -import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index ba86b8872..14ed3eb94 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -47,8 +47,8 @@ import lombok.NoArgsConstructor; import lombok.Setter; import lombok.experimental.Accessors; -import org.opensearch.neuralsearch.util.pruning.PruneType; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; /** * SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model @@ -146,7 +146,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float Map tokens = queryTokensSupplier.get(); // Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1, // while those less than or equal to the threshold are stored in v2. 
- Tuple, Map> splitTokens = PruneUtils.pruningSparseVector(pruneType, pruneRatio, tokens, true); + Tuple, Map> splitTokens = PruneUtils.pruneSparseVector(pruneType, pruneRatio, tokens, true); this.queryTokensSupplier(() -> splitTokens.v1()); copy.queryTokensSupplier(() -> splitTokens.v2()); } else { @@ -348,7 +348,7 @@ private BiConsumer> getModelInferenceAsync(SetOnce { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { - Tuple, Map> splitQueryTokens = PruneUtils.pruningSparseVector( + Tuple, Map> splitQueryTokens = PruneUtils.pruneSparseVector( twoPhasePruneType, twoPhasePruneRatio, queryTokens, diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java similarity index 76% rename from src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java rename to src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java index 22376b7c5..9e228ae27 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java @@ -2,12 +2,12 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.util.pruning; +package org.opensearch.neuralsearch.util.prune; import org.apache.commons.lang.StringUtils; /** - * Enum representing different types of pruning methods for sparse vectors + * Enum representing different types of prune methods for sparse vectors */ public enum PruneType { NONE("none"), @@ -29,9 +29,9 @@ public String getValue() { /** * Get PruneType from string value * - * @param value string representation of pruning type + * @param value string representation of prune type * @return corresponding PruneType enum - * @throws IllegalArgumentException if value doesn't match any pruning type + * @throws IllegalArgumentException if value 
doesn't match any prune type */ public static PruneType fromString(String value) { if (StringUtils.isEmpty(value)) return NONE; @@ -40,6 +40,6 @@ public static PruneType fromString(String value) { return type; } } - throw new IllegalArgumentException("Unknown pruning type: " + value); + throw new IllegalArgumentException("Unknown prune type: " + value); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java similarity index 87% rename from src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java rename to src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index ed7ac7f03..d1490d775 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/pruning/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.util.pruning; +package org.opensearch.neuralsearch.util.prune; import org.opensearch.common.collect.Tuple; @@ -15,8 +15,8 @@ import java.util.PriorityQueue; /** - * Utility class providing methods for pruning sparse vectors using different strategies. - * Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements + * Utility class providing methods for prune sparse vectors using different strategies. + * Prune helps reduce the dimensionality of sparse vectors by removing less significant elements * based on various criteria. 
*/ public class PruneUtils { @@ -31,7 +31,7 @@ public class PruneUtils { * @param requiresPrunedEntries Whether to return pruned entries * @return A tuple containing two maps: the first with top K elements, the second with remaining elements (or null) */ - private static Tuple, Map> pruningByTopK( + private static Tuple, Map> pruneByTopK( Map sparseVector, int k, boolean requiresPrunedEntries @@ -71,7 +71,7 @@ private static Tuple, Map> pruningByTopK( * @return A tuple containing two maps: the first with elements meeting the ratio threshold, * the second with elements below the threshold (or null) */ - private static Tuple, Map> pruningByMaxRatio( + private static Tuple, Map> pruneByMaxRatio( Map sparseVector, float ratio, boolean requiresPrunedEntries @@ -101,7 +101,7 @@ private static Tuple, Map> pruningByMaxRatio( * @return A tuple containing two maps: the first with elements above the threshold, * the second with elements below the threshold (or null) */ - private static Tuple, Map> pruningByValue( + private static Tuple, Map> pruneByValue( Map sparseVector, float thresh, boolean requiresPrunedEntries @@ -130,7 +130,7 @@ private static Tuple, Map> pruningByValue( * @return A tuple containing two maps: the first with elements meeting the alpha mass threshold, * the second with remaining elements (or null) */ - private static Tuple, Map> pruningByAlphaMass( + private static Tuple, Map> pruneByAlphaMass( Map sparseVector, float alpha, boolean requiresPrunedEntries @@ -159,16 +159,16 @@ private static Tuple, Map> pruningByAlphaMass( } /** - * Prunes a sparse vector using the specified pruning type and ratio. + * Prunes a sparse vector using the specified prune type and ratio. 
* - * @param pruneType The type of pruning strategy to use - * @param pruneRatio The ratio or threshold for pruning + * @param pruneType The type of prune strategy to use + * @param pruneRatio The ratio or threshold for prune * @param sparseVector The input sparse vector as a map of string keys to float values * @param requiresPrunedEntries Whether to return pruned entries * @return A tuple containing two maps: the first with high-scoring elements, * the second with low-scoring elements (or null if requiresPrunedEntries is false) */ - public static Tuple, Map> pruningSparseVector( + public static Tuple, Map> pruneSparseVector( PruneType pruneType, float pruneRatio, Map sparseVector, @@ -190,29 +190,29 @@ public static Tuple, Map> pruningSparseVector( switch (pruneType) { case TOP_K: - return pruningByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries); + return pruneByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries); case ALPHA_MASS: - return pruningByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries); case MAX_RATIO: - return pruningByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries); case ABS_VALUE: - return pruningByValue(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByValue(sparseVector, pruneRatio, requiresPrunedEntries); default: return new Tuple<>(new HashMap<>(sparseVector), requiresPrunedEntries ? new HashMap<>() : null); } } /** - * Validates whether a prune ratio is valid for a given pruning type. + * Validates whether a prune ratio is valid for a given prune type. 
* - * @param pruneType The type of pruning strategy + * @param pruneType The type of prune strategy * @param pruneRatio The ratio or threshold to validate - * @return true if the ratio is valid for the given pruning type, false otherwise - * @throws IllegalArgumentException if pruning type is null + * @return true if the ratio is valid for the given prune type, false otherwise + * @throws IllegalArgumentException if prune type is null */ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { if (pruneType == null) { - throw new IllegalArgumentException("Pruning type cannot be null"); + throw new IllegalArgumentException("Prune type cannot be null"); } switch (pruneType) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java index 24257127f..40230a618 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java @@ -9,8 +9,8 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; -import org.opensearch.neuralsearch.util.pruning.PruneType; -import org.opensearch.neuralsearch.util.pruning.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index d705616a9..8d512cc4c 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -51,7 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; -import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -275,7 +275,7 @@ public void test_batchExecute_exception() { } @SuppressWarnings("unchecked") - public void testExecute_withPruningConfig_successful() { + public void testExecute_withPruneConfig_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", "value1"); @@ -317,7 +317,7 @@ public void testExecute_withPruningConfig_successful() { assertEquals(0.4f, second.get("low"), 0.001f); } - public void test_batchExecute_withPruning_successful() { + public void test_batchExecute_withPrune_successful() { SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); List> mockMLResponse = Collections.singletonList( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java index 8b1fafe8b..23cf19c40 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -8,8 +8,8 @@ import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; -import static 
org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_TYPE_FIELD; -import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_RATIO_FIELD; +import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD; +import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD; import lombok.SneakyThrows; import org.junit.Before; @@ -18,7 +18,7 @@ import org.opensearch.env.Environment; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; -import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneType; import org.opensearch.test.OpenSearchTestCase; import java.util.HashMap; @@ -134,7 +134,7 @@ public void testCreateProcessor_whenInvalidPruneType_thenFail() { IllegalArgumentException.class, () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) ); - assertEquals("Unknown pruning type: invalid_prune_type", exception.getMessage()); + assertEquals("Unknown prune type: invalid_prune_type", exception.getMessage()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index c4d50ad55..2c4c88871 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -52,7 +52,7 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; -import org.opensearch.neuralsearch.util.pruning.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneType; import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; diff --git 
a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java similarity index 90% rename from src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java rename to src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java index a1a823093..f8ba5b604 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneTypeTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.util.pruning; +package org.opensearch.neuralsearch.util.prune; import org.opensearch.test.OpenSearchTestCase; @@ -25,6 +25,6 @@ public void testFromString() { assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value")); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value")); - assertEquals("Unknown pruning type: test_value", exception.getMessage()); + assertEquals("Unknown prune type: test_value", exception.getMessage()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java similarity index 80% rename from src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java rename to src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index 8dc31711a..febefa80d 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/pruning/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.util.pruning; +package org.opensearch.neuralsearch.util.prune; import org.opensearch.common.collect.Tuple; import 
org.opensearch.test.OpenSearchTestCase; @@ -12,7 +12,7 @@ public class PruneUtilsTests extends OpenSearchTestCase { - public void testPruningByTopK() { + public void testPruneByTopK() { Map input = new HashMap<>(); input.put("a", 5.0f); input.put("b", 3.0f); @@ -20,7 +20,7 @@ public void testPruningByTopK() { input.put("d", 1.0f); // Test without pruned entries - Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, false); + Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, false); assertEquals(2, result.v1().size()); assertNull(result.v2()); @@ -30,7 +30,7 @@ public void testPruningByTopK() { assertEquals(4.0f, result.v1().get("c"), 0.001); // Test with pruned entries - result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, true); + result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, true); assertEquals(2, result.v1().size()); assertEquals(2, result.v2().size()); @@ -38,7 +38,7 @@ public void testPruningByTopK() { assertTrue(result.v2().containsKey("d")); } - public void testPruningByMaxRatio() { + public void testPruneByMaxRatio() { Map input = new HashMap<>(); input.put("a", 10.0f); input.put("b", 8.0f); @@ -46,7 +46,7 @@ public void testPruningByMaxRatio() { input.put("d", 2.0f); // Test without pruned entries - Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input, false); + Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input, false); assertEquals(2, result.v1().size()); assertNull(result.v2()); @@ -54,7 +54,7 @@ public void testPruningByMaxRatio() { assertTrue(result.v1().containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 // Test with pruned entries - result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.7f, input, true); + result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input, true); assertEquals(2, result.v1().size()); assertEquals(2, result.v2().size()); @@ -62,7 +62,7 @@ public void 
testPruningByMaxRatio() { assertTrue(result.v2().containsKey("d")); } - public void testPruningByValue() { + public void testPruneByValue() { Map input = new HashMap<>(); input.put("a", 5.0f); input.put("b", 3.0f); @@ -70,7 +70,7 @@ public void testPruningByValue() { input.put("d", 1.0f); // Test without pruned entries - Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input, false); + Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input, false); assertEquals(2, result.v1().size()); assertNull(result.v2()); @@ -78,7 +78,7 @@ public void testPruningByValue() { assertTrue(result.v1().containsKey("b")); // Test with pruned entries - result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 3.0f, input, true); + result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input, true); assertEquals(2, result.v1().size()); assertEquals(2, result.v2().size()); @@ -86,7 +86,7 @@ public void testPruningByValue() { assertTrue(result.v2().containsKey("d")); } - public void testPruningByAlphaMass() { + public void testPruneByAlphaMass() { Map input = new HashMap<>(); input.put("a", 10.0f); input.put("b", 6.0f); @@ -94,7 +94,7 @@ public void testPruningByAlphaMass() { input.put("d", 1.0f); // Test without pruned entries - Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input, false); + Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input, false); assertEquals(2, result.v1().size()); assertNull(result.v2()); @@ -102,7 +102,7 @@ public void testPruningByAlphaMass() { assertTrue(result.v1().containsKey("b")); // Test with pruned entries - result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.8f, input, true); + result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input, true); assertEquals(2, result.v1().size()); assertEquals(2, result.v2().size()); @@ -113,23 +113,23 @@ public void testPruningByAlphaMass() { public void 
testEmptyInput() { Map input = new HashMap<>(); - Tuple, Map> result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input, false); + Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input, false); assertTrue(result.v1().isEmpty()); assertNull(result.v2()); - result = PruneUtils.pruningSparseVector(PruneType.MAX_RATIO, 0.5f, input, false); + result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.5f, input, false); assertTrue(result.v1().isEmpty()); assertNull(result.v2()); - result = PruneUtils.pruningSparseVector(PruneType.ALPHA_MASS, 0.5f, input, false); + result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.5f, input, false); assertTrue(result.v1().isEmpty()); assertNull(result.v2()); - result = PruneUtils.pruningSparseVector(PruneType.ABS_VALUE, 0.5f, input, false); + result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 0.5f, input, false); assertTrue(result.v1().isEmpty()); assertNull(result.v2()); - result = PruneUtils.pruningSparseVector(PruneType.TOP_K, 5, input, true); + result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input, true); assertTrue(result.v1().isEmpty()); assertTrue(result.v2().isEmpty()); } @@ -142,25 +142,25 @@ public void testNegativeValues() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, input, false) + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, false) ); assertEquals("Pruned values must be positive", exception.getMessage()); } - public void testInvalidPruningType() { + public void testInvalidPruneType() { Map input = new HashMap<>(); input.put("a", 1.0f); input.put("b", 2.0f); IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(null, 2, input, false) + () -> PruneUtils.pruneSparseVector(null, 2, input, false) ); assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); 
IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(null, 2, input, true) + () -> PruneUtils.pruneSparseVector(null, 2, input, true) ); assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); } @@ -168,7 +168,7 @@ public void testInvalidPruningType() { public void testNullSparseVector() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruningSparseVector(PruneType.TOP_K, 2, null, false) + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null, false) ); assertEquals(exception.getMessage(), "Sparse vector must be provided"); } @@ -210,6 +210,6 @@ public void testIsValidPruneRatio() { public void testIsValidPruneRatioWithNullType() { IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); - assertEquals("Pruning type cannot be null", exception.getMessage()); + assertEquals("Prune type cannot be null", exception.getMessage()); } } From 09e476541ea592168f3080222f5ae277365ad8d5 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 21 Nov 2024 14:14:56 +0800 Subject: [PATCH 11/17] enhance: more detailed error message Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 7 +++++- .../SparseEncodingProcessorFactory.java | 7 +++++- .../neuralsearch/util/prune/PruneUtils.java | 24 +++++++++++++++++++ ...ncodingEmbeddingProcessorFactoryTests.java | 2 +- .../util/prune/PruneUtilsTests.java | 14 +++++++++++ 5 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 60c52836e..9e0ecefd4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -222,7 +222,12 @@ public NeuralSparseTwoPhaseProcessor create( ); } if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() + "Illegal prune_ratio " + + pruneRatio + + " for prune_type: " + + pruneType.getValue() + + ". " + + PruneUtils.getValidPruneRatioDescription(pruneType) ); return new NeuralSparseTwoPhaseProcessor( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 6a9854e78..f5066369f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -52,7 +52,12 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description, // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " + pruneRatio + " for prune_type: " + pruneType.getValue() + "Illegal prune_ratio " + + pruneRatio + + " for prune_type: " + + pruneType.getValue() + + ". 
" + + PruneUtils.getValidPruneRatioDescription(pruneType) ); } else { // if we don't have prune type, then prune ratio field must not have value diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index d1490d775..36fc55a23 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -227,4 +227,28 @@ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { return true; } } + + /** + * Get description of valid prune ratio for a given prune type. + * + * @param pruneType The type of prune strategy + * @throws IllegalArgumentException if prune type is null + */ + public static String getValidPruneRatioDescription(PruneType pruneType) { + if (pruneType == null) { + throw new IllegalArgumentException("Prune type cannot be null"); + } + + switch (pruneType) { + case TOP_K: + return "prune_ratio should be positive integer."; + case MAX_RATIO: + case ALPHA_MASS: + return "prune_ratio should be in the range [0, 1)."; + case ABS_VALUE: + return "prune_ratio should be non-negative."; + default: + return ""; + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java index 23cf19c40..9d1b45866 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -149,7 +149,7 @@ public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { IllegalArgumentException.class, () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) ); - assertEquals("Illegal prune_ratio 0.2 
for prune_type: top_k", exception.getMessage()); + assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index febefa80d..4835b9f08 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -212,4 +212,18 @@ public void testIsValidPruneRatioWithNullType() { IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); assertEquals("Prune type cannot be null", exception.getMessage()); } + + public void testGetValidPruneRatioDescription() { + assertEquals("prune_ratio should be positive integer.", PruneUtils.getValidPruneRatioDescription(PruneType.TOP_K)); + assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.MAX_RATIO)); + assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.ALPHA_MASS)); + assertEquals("prune_ratio should be non-negative.", PruneUtils.getValidPruneRatioDescription(PruneType.ABS_VALUE)); + assertEquals("", PruneUtils.getValidPruneRatioDescription(PruneType.NONE)); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.getValidPruneRatioDescription(null) + ); + assertEquals(exception.getMessage(), "Prune type cannot be null"); + } } From 5b8ab70cdc07f98af2f424dc1973ef9483266db5 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 21 Nov 2024 14:29:17 +0800 Subject: [PATCH 12/17] refactor to prune and split Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 8 +- .../query/NeuralSparseQueryBuilder.java | 7 +- 
.../neuralsearch/util/prune/PruneUtils.java | 61 ++++-- .../util/prune/PruneUtilsTests.java | 177 ++++++++++-------- 4 files changed, 151 insertions(+), 102 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index d383cfea2..1f06e81e2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -61,9 +61,7 @@ public void doExecute( ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream() - .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1()) - .toList(); + sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -73,9 +71,7 @@ public void doExecute( public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream() - .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1()) - .toList(); + sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); handler.accept(sparseVectors); }, onException)); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java 
b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 14ed3eb94..be9719452 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -146,7 +146,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float Map tokens = queryTokensSupplier.get(); // Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1, // while those less than or equal to the threshold are stored in v2. - Tuple, Map> splitTokens = PruneUtils.pruneSparseVector(pruneType, pruneRatio, tokens, true); + Tuple, Map> splitTokens = PruneUtils.splitSparseVector(pruneType, pruneRatio, tokens); this.queryTokensSupplier(() -> splitTokens.v1()); copy.queryTokensSupplier(() -> splitTokens.v2()); } else { @@ -348,11 +348,10 @@ private BiConsumer> getModelInferenceAsync(SetOnce { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0); if (Objects.nonNull(twoPhaseSharedQueryToken)) { - Tuple, Map> splitQueryTokens = PruneUtils.pruneSparseVector( + Tuple, Map> splitQueryTokens = PruneUtils.splitSparseVector( twoPhasePruneType, twoPhasePruneRatio, - queryTokens, - true + queryTokens ); setOnce.set(splitQueryTokens.v1()); twoPhaseSharedQueryToken = splitQueryTokens.v2(); diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 36fc55a23..85415cb63 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -159,20 +159,18 @@ private static Tuple, Map> pruneByAlphaMass( } /** - * Prunes a sparse vector using the specified prune type and ratio. + * Split a sparse vector using the specified prune type and ratio. 
* - * @param pruneType The type of prune strategy to use - * @param pruneRatio The ratio or threshold for prune + * @param pruneType The type of prune strategy to use + * @param pruneRatio The ratio or threshold for prune * @param sparseVector The input sparse vector as a map of string keys to float values - * @param requiresPrunedEntries Whether to return pruned entries * @return A tuple containing two maps: the first with high-scoring elements, - * the second with low-scoring elements (or null if requiresPrunedEntries is false) + * the second with low-scoring elements (or null if requiresPrunedEntries is false) */ - public static Tuple, Map> pruneSparseVector( + public static Tuple, Map> splitSparseVector( PruneType pruneType, float pruneRatio, - Map sparseVector, - boolean requiresPrunedEntries + Map sparseVector ) { if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { throw new IllegalArgumentException("Prune type and prune ratio must be provided"); @@ -190,15 +188,52 @@ public static Tuple, Map> pruneSparseVector( switch (pruneType) { case TOP_K: - return pruneByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries); + return pruneByTopK(sparseVector, (int) pruneRatio, true); + case ALPHA_MASS: + return pruneByAlphaMass(sparseVector, pruneRatio, true); + case MAX_RATIO: + return pruneByMaxRatio(sparseVector, pruneRatio, true); + case ABS_VALUE: + return pruneByValue(sparseVector, pruneRatio, true); + default: + return new Tuple<>(new HashMap<>(sparseVector), new HashMap<>()); + } + } + + /** + * Prune a sparse vector using the specified prune type and ratio. 
+ * + * @param pruneType The type of prune strategy to use + * @param pruneRatio The ratio or threshold for prune + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A map with high-scoring elements + */ + public static Map pruneSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { + if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { + throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + } + + if (Objects.isNull(sparseVector)) { + throw new IllegalArgumentException("Sparse vector must be provided"); + } + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruneType) { + case TOP_K: + return pruneByTopK(sparseVector, (int) pruneRatio, false).v1(); case ALPHA_MASS: - return pruneByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByAlphaMass(sparseVector, pruneRatio, false).v1(); case MAX_RATIO: - return pruneByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByMaxRatio(sparseVector, pruneRatio, false).v1(); case ABS_VALUE: - return pruneByValue(sparseVector, pruneRatio, requiresPrunedEntries); + return pruneByValue(sparseVector, pruneRatio, false).v1(); default: - return new Tuple<>(new HashMap<>(sparseVector), requiresPrunedEntries ? 
new HashMap<>() : null); + return sparseVector; } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index 4835b9f08..99bbe3c92 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -19,23 +19,22 @@ public void testPruneByTopK() { input.put("c", 4.0f); input.put("d", 1.0f); - // Test without pruned entries - Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, false); - - assertEquals(2, result.v1().size()); - assertNull(result.v2()); - assertTrue(result.v1().containsKey("a")); - assertTrue(result.v1().containsKey("c")); - assertEquals(5.0f, result.v1().get("a"), 0.001); - assertEquals(4.0f, result.v1().get("c"), 0.001); - - // Test with pruned entries - result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, true); - - assertEquals(2, result.v1().size()); - assertEquals(2, result.v2().size()); - assertTrue(result.v2().containsKey("b")); - assertTrue(result.v2().containsKey("d")); + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input); + + assertEquals(2, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.TOP_K, 2, input); + + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(4.0f, tupleResult.v1().get("c"), 0.001); + assertEquals(3.0f, tupleResult.v2().get("b"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); } public void testPruneByMaxRatio() { @@ -45,21 +44,22 @@ public void testPruneByMaxRatio() { input.put("c", 5.0f); input.put("d", 2.0f); - // Test without pruned entries - Tuple, Map> result = 
PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input, false); + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input); - assertEquals(2, result.v1().size()); - assertNull(result.v2()); - assertTrue(result.v1().containsKey("a")); // 10.0/10.0 = 1.0 >= 0.7 - assertTrue(result.v1().containsKey("b")); // 8.0/10.0 = 0.8 >= 0.7 + assertEquals(2, result.size()); + assertEquals(10.0f, result.get("a"), 0.001); + assertEquals(8.0f, result.get("b"), 0.001); - // Test with pruned entries - result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input, true); + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.MAX_RATIO, 0.7f, input); - assertEquals(2, result.v1().size()); - assertEquals(2, result.v2().size()); - assertTrue(result.v2().containsKey("c")); - assertTrue(result.v2().containsKey("d")); + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(10.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(8.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(5.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(2.0f, tupleResult.v2().get("d"), 0.001); } public void testPruneByValue() { @@ -69,21 +69,22 @@ public void testPruneByValue() { input.put("c", 2.0f); input.put("d", 1.0f); - // Test without pruned entries - Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input, false); + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input); - assertEquals(2, result.v1().size()); - assertNull(result.v2()); - assertTrue(result.v1().containsKey("a")); - assertTrue(result.v1().containsKey("b")); + assertEquals(2, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(3.0f, result.get("b"), 0.001); - // Test with pruned entries - result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input, true); + // Test split + Tuple, Map> tupleResult = 
PruneUtils.splitSparseVector(PruneType.ABS_VALUE, 3.0f, input); - assertEquals(2, result.v1().size()); - assertEquals(2, result.v2().size()); - assertTrue(result.v2().containsKey("c")); - assertTrue(result.v2().containsKey("d")); + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(3.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(2.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); } public void testPruneByAlphaMass() { @@ -93,45 +94,35 @@ public void testPruneByAlphaMass() { input.put("c", 3.0f); input.put("d", 1.0f); - // Test without pruned entries - Tuple, Map> result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input, false); + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input); - assertEquals(2, result.v1().size()); - assertNull(result.v2()); - assertTrue(result.v1().containsKey("a")); - assertTrue(result.v1().containsKey("b")); + assertEquals(2, result.size()); + assertEquals(10.0f, result.get("a"), 0.001); + assertEquals(6.0f, result.get("b"), 0.001); - // Test with pruned entries - result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input, true); + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.ALPHA_MASS, 0.8f, input); - assertEquals(2, result.v1().size()); - assertEquals(2, result.v2().size()); - assertTrue(result.v2().containsKey("c")); - assertTrue(result.v2().containsKey("d")); + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(10.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(6.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(3.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); } public void testEmptyInput() { Map input = new HashMap<>(); - Tuple, Map> result = 
PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input, false); - assertTrue(result.v1().isEmpty()); - assertNull(result.v2()); + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input); + assertTrue(result.isEmpty()); - result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.5f, input, false); - assertTrue(result.v1().isEmpty()); - assertNull(result.v2()); - - result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.5f, input, false); - assertTrue(result.v1().isEmpty()); - assertNull(result.v2()); - - result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 0.5f, input, false); - assertTrue(result.v1().isEmpty()); - assertNull(result.v2()); - - result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input, true); - assertTrue(result.v1().isEmpty()); - assertTrue(result.v2().isEmpty()); + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.TOP_K, 5, input); + assertTrue(tupleResult.v1().isEmpty()); + assertTrue(tupleResult.v2().isEmpty()); } public void testNegativeValues() { @@ -140,11 +131,19 @@ public void testNegativeValues() { input.put("b", 3.0f); input.put("c", 4.0f); - IllegalArgumentException exception = assertThrows( + // Test prune + IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input, false) + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input) ); - assertEquals("Pruned values must be positive", exception.getMessage()); + assertEquals("Pruned values must be positive", exception1.getMessage()); + + // Test split + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, input) + ); + assertEquals("Pruned values must be positive", exception2.getMessage()); } public void testInvalidPruneType() { @@ -152,25 +151,45 @@ public void testInvalidPruneType() { input.put("a", 1.0f); input.put("b", 2.0f); + 
// Test prune IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(null, 2, input, false) + () -> PruneUtils.splitSparseVector(null, 2, input) ); assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(null, 2, input, true) + () -> PruneUtils.splitSparseVector(null, 2, input) ); assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); + + // Test split + IllegalArgumentException exception3 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(null, 2, input) + ); + assertEquals(exception3.getMessage(), "Prune type and prune ratio must be provided"); + + IllegalArgumentException exception4 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(null, 2, input) + ); + assertEquals(exception4.getMessage(), "Prune type and prune ratio must be provided"); } public void testNullSparseVector() { - IllegalArgumentException exception = assertThrows( + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null) + ); + assertEquals(exception1.getMessage(), "Sparse vector must be provided"); + + IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null, false) + () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null) ); - assertEquals(exception.getMessage(), "Sparse vector must be provided"); + assertEquals(exception2.getMessage(), "Sparse vector must be provided"); } public void testIsValidPruneRatio() { From 6cabbc0edb101017f6928a5505e36e5107b80324 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 22 Nov 2024 14:36:07 +0800 Subject: [PATCH 13/17] changelog Signed-off-by: zhichao-aws --- 
CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b76f1c39f..eaa01039f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,8 +16,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features -- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Enhancements +- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) ### Infrastructure From cffd829100bf77ba195bcab6a88cac4ce299f376 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 22 Nov 2024 15:14:40 +0800 Subject: [PATCH 14/17] fix UT cov Signed-off-by: zhichao-aws --- .../neuralsearch/util/prune/PruneUtils.java | 2 +- .../util/prune/PruneUtilsTests.java | 33 +++++++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 85415cb63..34a9cff2d 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -165,7 +165,7 @@ private static Tuple, Map> pruneByAlphaMass( * @param pruneRatio The ratio or threshold for prune * @param sparseVector The input sparse vector as a map of string keys to float values * @return A tuple containing two maps: the first with high-scoring elements, - * the second with low-scoring elements (or null if requiresPrunedEntries is false) + * the second with low-scoring elements */ public static 
Tuple, Map> splitSparseVector( PruneType pruneType, diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index 99bbe3c92..f0869ac53 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -112,6 +112,33 @@ public void testPruneByAlphaMass() { assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); } + public void testNonePrune() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(3.0f, result.get("b"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + assertEquals(1.0f, result.get("d"), 0.001); + + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, tupleResult.v1().size()); + assertEquals(0, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(3.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(4.0f, tupleResult.v1().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v1().get("d"), 0.001); + } + public void testEmptyInput() { Map input = new HashMap<>(); @@ -154,13 +181,13 @@ public void testInvalidPruneType() { // Test prune IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) + () -> PruneUtils.pruneSparseVector(null, 2, input) ); assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) + () -> 
PruneUtils.pruneSparseVector(null, 2, input) ); assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); @@ -181,7 +208,7 @@ public void testInvalidPruneType() { public void testNullSparseVector() { IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null) + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null) ); assertEquals(exception1.getMessage(), "Sparse vector must be provided"); From 0d928a9abf40427fb4f057185e5477725b7cf796 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 10 Dec 2024 16:43:31 +0800 Subject: [PATCH 15/17] address review comments Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessor.java | 19 +++++++------ .../processor/SparseEncodingProcessor.java | 12 +++++--- .../SparseEncodingProcessorFactory.java | 26 +++++++++-------- .../neuralsearch/util/prune/PruneType.java | 6 ++-- .../neuralsearch/util/prune/PruneUtils.java | 28 +++++++++++-------- ...ncodingEmbeddingProcessorFactoryTests.java | 2 +- .../util/prune/PruneUtilsTests.java | 23 +++++---------- 7 files changed, 61 insertions(+), 55 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java index 9e0ecefd4..bc5971e3f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -221,14 +221,17 @@ public NeuralSparseTwoPhaseProcessor create( twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString() ); } - if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " - + pruneRatio - + " for prune_type: " - + pruneType.getValue() - + ". 
" - + PruneUtils.getValidPruneRatioDescription(pruneType) - ); + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Illegal prune_ratio %f for prune_type: %s. %s", + pruneRatio, + pruneType.getValue(), + PruneUtils.getValidPruneRatioDescription(pruneType) + ) + ); + } return new NeuralSparseTwoPhaseProcessor( tag, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 1f06e81e2..9250c8d64 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -60,8 +60,10 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -70,8 +72,10 @@ public void doExecute( @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps); - sparseVectors = sparseVectors.stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList(); + List> sparseVectors = 
TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); handler.accept(sparseVectors); }, onException)); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index f5066369f..7a7d7dfde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -12,6 +12,7 @@ import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import java.util.Locale; import java.util.Map; import org.opensearch.cluster.service.ClusterService; @@ -51,19 +52,20 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description, // if we have prune type, then prune ratio field must have value // readDoubleProperty will throw exception if value is not present pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue(); - if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException( - "Illegal prune_ratio " - + pruneRatio - + " for prune_type: " - + pruneType.getValue() - + ". " - + PruneUtils.getValidPruneRatioDescription(pruneType) - ); - } else { - // if we don't have prune type, then prune ratio field must not have value - if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { - throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Illegal prune_ratio %f for prune_type: %s. 
%s", + pruneRatio, + pruneType.getValue(), + PruneUtils.getValidPruneRatioDescription(pruneType) + ) + ); } + } else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) { + // if we don't have prune type, then prune ratio field must not have value + throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided"); } return new SparseEncodingProcessor( diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java index 9e228ae27..5f8e62b7c 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java @@ -6,6 +6,8 @@ import org.apache.commons.lang.StringUtils; +import java.util.Locale; + /** * Enum representing different types of prune methods for sparse vectors */ @@ -33,13 +35,13 @@ public String getValue() { * @return corresponding PruneType enum * @throws IllegalArgumentException if value doesn't match any prune type */ - public static PruneType fromString(String value) { + public static PruneType fromString(final String value) { if (StringUtils.isEmpty(value)) return NONE; for (PruneType type : PruneType.values()) { if (type.value.equals(value)) { return type; } } - throw new IllegalArgumentException("Unknown prune type: " + value); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value)); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 34a9cff2d..77836972e 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -33,13 +33,13 @@ public class PruneUtils { */ private static Tuple, Map> pruneByTopK( Map sparseVector, - int k, + float k, boolean requiresPrunedEntries ) { PriorityQueue> pq = 
new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); for (Map.Entry entry : sparseVector.entrySet()) { - if (pq.size() < k) { + if (pq.size() < (int) k) { pq.offer(entry); } else if (entry.getValue() > pq.peek().getValue()) { pq.poll(); @@ -172,8 +172,8 @@ public static Tuple, Map> splitSparseVector( float pruneRatio, Map sparseVector ) { - if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { - throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); } if (Objects.isNull(sparseVector)) { @@ -188,7 +188,7 @@ public static Tuple, Map> splitSparseVector( switch (pruneType) { case TOP_K: - return pruneByTopK(sparseVector, (int) pruneRatio, true); + return pruneByTopK(sparseVector, pruneRatio, true); case ALPHA_MASS: return pruneByAlphaMass(sparseVector, pruneRatio, true); case MAX_RATIO: @@ -208,9 +208,13 @@ public static Tuple, Map> splitSparseVector( * @param sparseVector The input sparse vector as a map of string keys to float values * @return A map with high-scoring elements */ - public static Map pruneSparseVector(PruneType pruneType, float pruneRatio, Map sparseVector) { - if (Objects.isNull(pruneType) || Objects.isNull(pruneRatio)) { - throw new IllegalArgumentException("Prune type and prune ratio must be provided"); + public static Map pruneSparseVector( + final PruneType pruneType, + final float pruneRatio, + final Map sparseVector + ) { + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); } if (Objects.isNull(sparseVector)) { @@ -225,7 +229,7 @@ public static Map pruneSparseVector(PruneType pruneType, float pr switch (pruneType) { case TOP_K: - return pruneByTopK(sparseVector, (int) pruneRatio, false).v1(); + return pruneByTopK(sparseVector, pruneRatio, false).v1(); case ALPHA_MASS: return pruneByAlphaMass(sparseVector, pruneRatio, false).v1(); case 
MAX_RATIO: @@ -245,7 +249,7 @@ public static Map pruneSparseVector(PruneType pruneType, float pr * @return true if the ratio is valid for the given prune type, false otherwise * @throws IllegalArgumentException if prune type is null */ - public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { + public static boolean isValidPruneRatio(final PruneType pruneType, final float pruneRatio) { if (pruneType == null) { throw new IllegalArgumentException("Prune type cannot be null"); } @@ -269,7 +273,7 @@ public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) { * @param pruneType The type of prune strategy * @throws IllegalArgumentException if prune type is null */ - public static String getValidPruneRatioDescription(PruneType pruneType) { + public static String getValidPruneRatioDescription(final PruneType pruneType) { if (pruneType == null) { throw new IllegalArgumentException("Prune type cannot be null"); } @@ -283,7 +287,7 @@ public static String getValidPruneRatioDescription(PruneType pruneType) { case ABS_VALUE: return "prune_ratio should be non-negative."; default: - return ""; + return "prune_ratio field is not supported when prune_type is none"; } } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java index 9d1b45866..5d098e77e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -149,7 +149,7 @@ public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { IllegalArgumentException.class, () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) ); - assertEquals("Illegal prune_ratio 0.2 for prune_type: top_k. 
prune_ratio should be positive integer.", exception.getMessage()); + assertEquals("Illegal prune_ratio 0.200000 for prune_type: top_k. prune_ratio should be positive integer.", exception.getMessage()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index f0869ac53..536125152 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -183,26 +183,14 @@ public void testInvalidPruneType() { IllegalArgumentException.class, () -> PruneUtils.pruneSparseVector(null, 2, input) ); - assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); - - IllegalArgumentException exception2 = assertThrows( - IllegalArgumentException.class, - () -> PruneUtils.pruneSparseVector(null, 2, input) - ); - assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); + assertEquals(exception1.getMessage(), "Prune type must be provided"); // Test split - IllegalArgumentException exception3 = assertThrows( - IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) - ); - assertEquals(exception3.getMessage(), "Prune type and prune ratio must be provided"); - - IllegalArgumentException exception4 = assertThrows( + IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, () -> PruneUtils.splitSparseVector(null, 2, input) ); - assertEquals(exception4.getMessage(), "Prune type and prune ratio must be provided"); + assertEquals(exception2.getMessage(), "Prune type must be provided"); } public void testNullSparseVector() { @@ -264,7 +252,10 @@ public void testGetValidPruneRatioDescription() { assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.MAX_RATIO)); assertEquals("prune_ratio should be in the 
range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.ALPHA_MASS)); assertEquals("prune_ratio should be non-negative.", PruneUtils.getValidPruneRatioDescription(PruneType.ABS_VALUE)); - assertEquals("", PruneUtils.getValidPruneRatioDescription(PruneType.NONE)); + assertEquals( + "prune_ratio field is not supported when prune_type is none", + PruneUtils.getValidPruneRatioDescription(PruneType.NONE) + ); IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, From 7486ee82c998d4b6f3028cea5ccdcfc686c82e5e Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 10 Dec 2024 17:13:43 +0800 Subject: [PATCH 16/17] enlarge score diff range Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java index 5f921809a..bc61f7c29 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java @@ -43,7 +43,7 @@ public class NeuralSparseTwoPhaseProcessorIT extends BaseNeuralSearchIT { private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); - private static final Float DELTA = 1e-5f; + private static final Float DELTA = 1e-4f; private final Map testRankFeaturesDoc = createRandomTokenWeightMap(TEST_TOKENS); private static final List TWO_PHASE_TEST_TOKEN = List.of("hello", "world"); From 185fc5159dc5dfab54ba825eea5c5f1754cb0a78 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 18 Dec 2024 10:29:00 +0800 Subject: [PATCH 17/17] address comments: check lowScores non null instead of flag Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/util/prune/PruneUtils.java | 8 ++++---- 1 file changed, 4 
insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 77836972e..a4c35adcc 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -53,7 +53,7 @@ private static Tuple, Map> pruneByTopK( while (!pq.isEmpty()) { Map.Entry entry = pq.poll(); highScores.put(entry.getKey(), entry.getValue()); - if (requiresPrunedEntries) { + if (Objects.nonNull(lowScores)) { lowScores.remove(entry.getKey()); } } @@ -84,7 +84,7 @@ private static Tuple, Map> pruneByMaxRatio( for (Map.Entry entry : sparseVector.entrySet()) { if (entry.getValue() >= ratio * maxValue) { highScores.put(entry.getKey(), entry.getValue()); - } else if (requiresPrunedEntries) { + } else if (Objects.nonNull(lowScores)) { lowScores.put(entry.getKey(), entry.getValue()); } } @@ -112,7 +112,7 @@ private static Tuple, Map> pruneByValue( for (Map.Entry entry : sparseVector.entrySet()) { if (entry.getValue() >= thresh) { highScores.put(entry.getKey(), entry.getValue()); - } else if (requiresPrunedEntries) { + } else if (Objects.nonNull(lowScores)) { lowScores.put(entry.getKey(), entry.getValue()); } } @@ -150,7 +150,7 @@ private static Tuple, Map> pruneByAlphaMass( if (topSum <= alpha * sum) { highScores.put(entry.getKey(), value); - } else if (requiresPrunedEntries) { + } else if (Objects.nonNull(lowScores)) { lowScores.put(entry.getKey(), value); } }