Supporting sparse semantic retrieval in neural search (#333)
* sparse mapper field and query builder

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix typo

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* Add map result support in neural search for non text embedding models

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix compilation failure issue

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add more UTs

Signed-off-by: zane-neo <zaniu@amazon.com>

* add sparse encoding processor

Signed-off-by: xinyual <xinyual@amazon.com>

* add sparse encoding processor

Signed-off-by: xinyual <xinyual@amazon.com>

* remove guava in gradle

Signed-off-by: xinyual <xinyual@amazon.com>

* modify access control

Signed-off-by: xinyual <xinyual@amazon.com>

* Add map result support in neural search for non text embedding models

Signed-off-by: zane-neo <zaniu@amazon.com>

* Fix compilation failure issue

Signed-off-by: zane-neo <zaniu@amazon.com>

* change output logic

Signed-off-by: xinyual <xinyual@amazon.com>

* create abstract

Signed-off-by: xinyual <xinyual@amazon.com>

* create abstract processor

Signed-off-by: xinyual <xinyual@amazon.com>

* add abstract class

Signed-off-by: xinyual <xinyual@amazon.com>

* remove duplicate code

Signed-off-by: xinyual <xinyual@amazon.com>

* remove duplicate code

Signed-off-by: xinyual <xinyual@amazon.com>

* remove dl process

Signed-off-by: xinyual <xinyual@amazon.com>

* move static to abstract class

Signed-off-by: xinyual <xinyual@amazon.com>

* update query rewrite logic

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* modify header

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* merge conflict

Signed-off-by: xinyual <xinyual@amazon.com>

* delete index mapper, change to rank_features

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* remove unused import

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* list return result

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* refactor type and listTypeNestedMapKey, tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* forbid nested input. tidy.

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* enable nested

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* Add ut it to sparse encoding processor (#6)

* fix original UT problem

Signed-off-by: xinyual <xinyual@amazon.com>

* add UT IT

Signed-off-by: xinyual <xinyual@amazon.com>

* add more UT

Signed-off-by: xinyual <xinyual@amazon.com>

* add more ut

Signed-off-by: xinyual <xinyual@amazon.com>

* fix typo error

Signed-off-by: xinyual <xinyual@amazon.com>

---------

Signed-off-by: xinyual <xinyual@amazon.com>

* utils, tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rename to sparse_encoding query

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add validation and ut

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* sparse encoding query builder ut

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rename

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* UT for utils

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* enrich sparse encoding IT mappings

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add it

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add it

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add integ test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* rename resource file

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* remove BoundedLinearQuery and TokenScoreUpperBound

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add delta to loosen the equality check

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* move SparseEncodingQueryBuilder to upper level path

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add it

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* Update src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java

Co-authored-by: zane-neo <zaniu@amazon.com>
Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* Update src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java

Co-authored-by: zane-neo <zaniu@amazon.com>
Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* restore gradle.properties

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add release notes

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* change field modifier to private for NLPProcessor

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add comments

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* use StringUtils to check

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* null check

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* modify changelog

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* nit

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* nit

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* remove query tokens from user interface

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix test

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* tidy

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* update function name

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add javadoc

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* remove debug log including inference result

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* make query text and model id required

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* minor changes based on comments

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add locale to String.format

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* update mock model url

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

---------

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
Signed-off-by: zane-neo <zaniu@amazon.com>
Signed-off-by: xinyual <xinyual@amazon.com>
Co-authored-by: zane-neo <zaniu@amazon.com>
Co-authored-by: xinyual <xinyual@amazon.com>
3 people authored Sep 27, 2023
1 parent 8484be9 commit 7bef7a0
Showing 24 changed files with 2,273 additions and 315 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

 ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x)
 ### Features
+Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
 ### Enhancements
 ### Bug Fixes
 ### Infrastructure
2 changes: 2 additions & 0 deletions build.gradle
@@ -151,6 +151,8 @@ dependencies {
     runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
     runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
     runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
+    runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
+    runtimeOnly group: 'org.json', name: 'json', version: '20230227'
 }

 // In order to add the jar to the classpath, we need to unzip the
@@ -8,13 +8,15 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;

 import lombok.NonNull;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.log4j.Log4j2;

 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -100,10 +102,38 @@ public void inferenceSentences(
         @NonNull final List<String> inputText,
         @NonNull final ActionListener<List<List<Float>>> listener
     ) {
-        inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
+        retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
     }

-    private void inferenceSentencesWithRetry(
+    public void inferenceSentencesWithMapResult(
+        @NonNull final String modelId,
+        @NonNull final List<String> inputText,
+        @NonNull final ActionListener<List<Map<String, ?>>> listener
+    ) {
+        retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
+    }
+
+    private void retryableInferenceSentencesWithMapResult(
+        final String modelId,
+        final List<String> inputText,
+        final int retryTime,
+        final ActionListener<List<Map<String, ?>>> listener
+    ) {
+        MLInput mlInput = createMLInput(null, inputText);
+        mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+            final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
+            listener.onResponse(result);
+        }, e -> {
+            if (RetryUtil.shouldRetry(e, retryTime)) {
+                final int retryTimeAdd = retryTime + 1;
+                retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
+            } else {
+                listener.onFailure(e);
+            }
+        }));
+    }
+
+    private void retryableInferenceSentencesWithVectorResult(
         final List<String> targetResponseFilters,
         final String modelId,
         final List<String> inputText,
@@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry(
         MLInput mlInput = createMLInput(targetResponseFilters, inputText);
         mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
             final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
-            log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
             listener.onResponse(vector);
         }, e -> {
             if (RetryUtil.shouldRetry(e, retryTime)) {
                 final int retryTimeAdd = retryTime + 1;
-                inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
+                retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
             } else {
                 listener.onFailure(e);
             }
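
Both retryable methods in this file follow the same shape: call the model, and on failure either re-invoke the same method with the retry counter incremented or hand the error to the caller's listener. The sketch below isolates that pattern so it can be run without OpenSearch on the classpath. `Listener`, `Predictor`, and `MAX_RETRY` are hypothetical stand-ins for `ActionListener`, `mlClient.predict`, and whatever policy `RetryUtil.shouldRetry` applies; the real policy is not visible in this diff and may also inspect the exception type.

```java
import java.util.List;

final class RetrySketch {
    /** Hypothetical stand-in for OpenSearch's ActionListener. */
    interface Listener<T> {
        void onResponse(T result);

        void onFailure(Exception e);
    }

    /** Hypothetical stand-in for the asynchronous mlClient.predict call. */
    interface Predictor<T> {
        void predict(List<String> input, Listener<T> listener);
    }

    // Assumed retry budget; the real limit lives in RetryUtil and is not shown in the diff.
    static final int MAX_RETRY = 3;

    static <T> void inferWithRetry(Predictor<T> predictor, List<String> input, int retryTime, Listener<T> listener) {
        predictor.predict(input, new Listener<T>() {
            @Override
            public void onResponse(T result) {
                // Success: pass the result straight through to the caller.
                listener.onResponse(result);
            }

            @Override
            public void onFailure(Exception e) {
                if (retryTime < MAX_RETRY) {
                    // Retry by calling the same method with the counter incremented.
                    inferWithRetry(predictor, input, retryTime + 1, listener);
                } else {
                    // Budget exhausted: surface the last error to the caller.
                    listener.onFailure(e);
                }
            }
        });
    }
}
```

Because the counter is passed as an argument rather than stored in a field, concurrent inference calls do not share retry state.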
@@ -144,4 +173,22 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
         return vector;
     }

+    private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
+        final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
+        final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
+        if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
+            throw new IllegalStateException(
+                "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
+            );
+        }
+        List<Map<String, ?>> resultMaps = new ArrayList<>();
+        for (ModelTensors tensors : tensorOutputList) {
+            List<ModelTensor> tensorList = tensors.getMlModelTensors();
+            for (ModelTensor tensor : tensorList) {
+                resultMaps.add(tensor.getDataAsMap());
+            }
+        }
+        return resultMaps;
+    }
+
 }
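
`buildMapResultFromResponse` validates that the first output group is non-empty and then flattens every tensor's data map into a single list. A rough, dependency-free sketch of that flattening follows. It models the response as a plain nested list, where the outer list plays the role of `getMlModelOutputs()` and each inner list the role of `getMlModelTensors()`; that modelling, the class name, and the method name are assumptions made for illustration. As in the diff, only the first group is checked for emptiness.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class MapResultSketch {
    /**
     * Flattens groups of per-tensor maps into one list, preserving order.
     * Outer list ~ model outputs, inner list ~ tensors of one output (assumed modelling).
     */
    static List<Map<String, ?>> flatten(List<? extends List<? extends Map<String, ?>>> groups) {
        // Mirror the guard in the diff: require at least one group with at least one tensor.
        if (groups == null || groups.isEmpty() || groups.get(0) == null || groups.get(0).isEmpty()) {
            throw new IllegalStateException("Empty model result produced");
        }
        List<Map<String, ?>> result = new ArrayList<>();
        for (List<? extends Map<String, ?>> group : groups) {
            for (Map<String, ?> tensorData : group) {
                result.add(tensorData);
            }
        }
        return result;
    }
}
```

For a sparse encoding model, each map would presumably carry token-to-weight entries, though the exact payload shape depends on the model and is not shown in this diff.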
17 changes: 13 additions & 4 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -9,7 +9,6 @@

 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -31,15 +30,18 @@
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
+import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
 import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
+import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
 import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
 import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.plugins.ActionPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
@@ -62,7 +64,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
     private MLCommonsClientAccessor clientAccessor;
     private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
     private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
-    private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;
+    private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();

     @Override
     public Collection<Object> createComponents(
@@ -79,6 +81,7 @@ public Collection<Object> createComponents(
         final Supplier<RepositoriesService> repositoriesServiceSupplier
     ) {
         NeuralQueryBuilder.initialize(clientAccessor);
+        SparseEncodingQueryBuilder.initialize(clientAccessor);
         normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
         return List.of(clientAccessor);
     }
@@ -87,14 +90,20 @@ public Collection<Object> createComponents(
     public List<QuerySpec<?>> getQueries() {
         return Arrays.asList(
             new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
-            new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent)
+            new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
+            new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
         );
     }

     @Override
     public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
-        return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
+        return Map.of(
+            TextEmbeddingProcessor.TYPE,
+            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
+            SparseEncodingProcessor.TYPE,
+            new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
+        );
     }

     @Override
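
The switch from `Collections.singletonMap` to `Map.of` in `getProcessors` is what lets the plugin register a second processor type. The sketch below shows the same registration shape with plain functions in place of the real factories. The type strings `"text_embedding"` and `"sparse_encoding"` are assumed values for `TextEmbeddingProcessor.TYPE` and `SparseEncodingProcessor.TYPE` (the CHANGELOG entry names `sparse_encoding`, but the constants themselves are not shown in this diff), and the lambdas are placeholders, not the plugin's factory logic.

```java
import java.util.Map;
import java.util.function.UnaryOperator;

final class ProcessorRegistrySketch {
    // Assumed values of the TYPE constants; not confirmed by the diff.
    static final String TEXT_EMBEDDING_TYPE = "text_embedding";
    static final String SPARSE_ENCODING_TYPE = "sparse_encoding";

    /** Returns an immutable registry keyed by processor type, mirroring the Map.of call in getProcessors. */
    static Map<String, UnaryOperator<String>> processors() {
        return Map.of(
            TEXT_EMBEDDING_TYPE,
            field -> "dense vector for " + field,
            SPARSE_ENCODING_TYPE,
            field -> "token weights for " + field
        );
    }
}
```

`Map.of` returns an unmodifiable map and throws `IllegalArgumentException` on duplicate keys, so an accidental clash between two processor type names would fail when the map is built rather than silently overwriting an entry.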