Skip to content

Commit

Permalink
fine tune connector process function (opensearch-project#1954)
Browse files Browse the repository at this point in the history
* fine tune connector process function

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add unit test for process function

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add license header

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored and austintlee committed Mar 18, 2024
1 parent 8097562 commit cbaa991
Show file tree
Hide file tree
Showing 35 changed files with 1,311 additions and 149 deletions.
1 change: 1 addition & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compileOnly "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"

compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

package org.opensearch.ml.common.connector;

import com.google.common.collect.ImmutableList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -20,58 +20,41 @@ public class MLPostProcessFunction {
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

private static final Map<String, Function<Object, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

static {
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
}

public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
if (embeddings.get(0) instanceof Number) {
embeddings = ImmutableList.of(embeddings);
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{eachEmbedding.size()})
.data(eachEmbedding.toArray(new Number[0]))
.build()
);
});
return modelTensors;
};
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
}

/**
 * Looks up the JSON path expression registered for a built-in post-process function.
 *
 * @param postProcessFunction name of the post-process function, e.g. {@link #OPENAI_EMBEDDING}
 * @return the registered JSON path expression, or {@code null} if none is registered under that name
 */
public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
/**
 * Looks up the built-in post-process function registered under the given name.
 *
 * @param postProcessFunction name of the post-process function, e.g. {@link #COHERE_RERANK}
 * @return the registered function, or {@code null} if none is registered under that name
 */
public static Function<Object, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

/**
 * Tells whether a built-in post-process function is registered under the given name.
 *
 * @param postProcessFunction name to look up
 * @return {@code true} if {@code POST_PROCESS_FUNCTIONS} has an entry for the name
 */
public static boolean contains(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,48 @@

package org.opensearch.ml.common.connector;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPreProcessFunction {

private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
}

private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
}
public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";

static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
}

/**
 * Tells whether a built-in pre-process function is registered under the given name.
 *
 * @param functionName name to look up
 * @return {@code true} if {@code PRE_PROCESS_FUNCTIONS} has an entry for the name
 */
public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
/**
 * Looks up the built-in pre-process function registered under the given name.
 *
 * @param preProcessFunction name of the pre-process function,
 *                           e.g. {@link #TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT}
 * @return the registered function, or {@code null} if none is registered under that name
 */
public static Function<MLInput, RemoteInferenceInputDataSet> get(String preProcessFunction) {
    return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

/**
 * Post-process function that wraps one flat embedding vector into a single
 * {@code sentence_embedding} tensor.
 */
public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {

    /**
     * Checks that the input is a {@link List} whose first element, if present, is a {@link Number}.
     * An empty list is accepted, and elements after the first are not inspected.
     *
     * @param input raw model output to check
     * @throws IllegalArgumentException if the input is not a list, or its first element is not a number
     */
    @Override
    public void validate(Object input) {
        if (!(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input is not a List.");
        }

        List<?> values = (List<?>) input;
        if (values.isEmpty()) {
            // Nothing to inspect; an empty embedding is let through.
            return;
        }
        if (!(values.get(0) instanceof Number)) {
            throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
        }
    }

    /**
     * Builds the single-tensor result for one embedding vector.
     *
     * @param embedding the embedding values; its size becomes the tensor shape
     * @return a new mutable list holding exactly one FLOAT32 tensor named {@code sentence_embedding}
     */
    @Override
    public List<ModelTensor> process(List<Float> embedding) {
        ModelTensor tensor = ModelTensor
            .builder()
            .name("sentence_embedding")
            .dataType(MLResultDataType.FLOAT32)
            .shape(new long[]{embedding.size()})
            .data(embedding.toArray(new Number[0]))
            .build();

        List<ModelTensor> result = new ArrayList<>();
        result.add(tensor);
        return result;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Post-process function that turns a list of rerank results (maps carrying an
 * {@code index} and a {@code relevance_score}) into one {@code similarity} tensor
 * per result, ordered by {@code index}.
 */
public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {

    private static final String INDEX_FIELD = "index";
    private static final String SCORE_FIELD = "relevance_score";

    /**
     * Checks that the input is a {@link List} and, when it is not empty, that its first
     * element is a {@link Map} containing both {@code index} and {@code relevance_score}.
     * Only the first element is inspected here; {@link #process(List)} checks the
     * field values of every element.
     *
     * @param input raw model output to check
     * @throws IllegalArgumentException if the input is not a list, its first element is not
     *                                  a map, or that map lacks one of the two required keys
     */
    @Override
    public void validate(Object input) {
        if (!(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input is not a List.");
        }
        List<?> outerList = (List<?>) input;
        if (outerList.isEmpty()) {
            return;
        }
        if (!(outerList.get(0) instanceof Map)) {
            throw new IllegalArgumentException("Post process function input is not a List of Map.");
        }
        Map<?, ?> innerMap = (Map<?, ?>) outerList.get(0);
        // An empty map fails the containsKey checks, so no separate isEmpty() test is needed.
        if (!innerMap.containsKey(INDEX_FIELD) || !innerMap.containsKey(SCORE_FIELD)) {
            throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
        }
    }

    /**
     * Builds one FLOAT32 {@code similarity} tensor of shape {@code [1]} per rerank result.
     * The tensor at position {@code i} holds the score of the result whose {@code index} is {@code i}.
     *
     * <p>{@code index} and {@code relevance_score} are read as {@link Number} so that any numeric
     * type produced by the JSON layer (Integer/Long, Float/Double) is accepted.
     *
     * @param rerankResults rerank results, each carrying {@code index} and {@code relevance_score}
     * @return a new mutable list of tensors; empty when {@code rerankResults} is empty
     * @throws IllegalArgumentException if a result's {@code index} or {@code relevance_score} is
     *                                  missing or not numeric, or if an {@code index} is outside
     *                                  {@code [0, rerankResults.size())}
     */
    @Override
    public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
        List<ModelTensor> modelTensors = new ArrayList<>();
        int resultCount = rerankResults.size();
        if (resultCount == 0) {
            return modelTensors;
        }

        Double[] scores = new Double[resultCount];
        for (Map<String, Object> rerankResult : rerankResults) {
            Object index = rerankResult.get(INDEX_FIELD);
            Object score = rerankResult.get(SCORE_FIELD);
            if (!(index instanceof Number) || !(score instanceof Number)) {
                throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
            }
            int position = ((Number) index).intValue();
            // Fail with a descriptive message instead of a raw ArrayIndexOutOfBoundsException.
            if (position < 0 || position >= resultCount) {
                throw new IllegalArgumentException(
                    "The rerank result index " + position + " is out of range for " + resultCount + " results.");
            }
            scores[position] = ((Number) score).doubleValue();
        }

        // NOTE(review): if two results share an index, some slot stays null and its tensor
        // carries a null value, as before this change — confirm whether that should be rejected.
        for (Double score : scores) {
            modelTensors.add(ModelTensor.builder()
                .name("similarity")
                .shape(new long[]{1})
                .data(new Number[]{score})
                .dataType(MLResultDataType.FLOAT32)
                .build());
        }
        return modelTensors;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.List;
import java.util.function.Function;

/**
 * Template for connector post-process functions: {@link #apply(Object)} rejects
 * {@code null}, runs {@link #validate(Object)}, then hands the input to
 * {@link #process(Object)} as type {@code T}.
 *
 * @param <T> the shape of model output the concrete function works on
 */
public abstract class ConnectorPostProcessFunction<T> implements Function<Object, List<ModelTensor>> {

    /**
     * Checks that the raw model output has the structure this function expects.
     *
     * @param input non-null raw model output
     */
    public abstract void validate(Object input);

    /**
     * Converts validated model output into tensors.
     *
     * @param input model output, already passed through {@link #validate(Object)} when
     *              reached via {@link #apply(Object)}
     * @return the resulting tensors
     */
    public abstract List<ModelTensor> process(T input);

    /**
     * Validates the model output and converts it into tensors.
     *
     * @param input raw model output
     * @return the tensors produced by {@link #process(Object)}
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    @Override
    public List<ModelTensor> apply(Object input) {
        if (input == null) {
            throw new IllegalArgumentException("Can't run post process function as model output is null");
        }
        validate(input);
        // Unchecked by nature: validate() is what vouches for the runtime shape of the input.
        @SuppressWarnings("unchecked")
        T typedInput = (T) input;
        return process(typedInput);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

/**
 * Post-process function that converts a list of embedding vectors into one
 * {@code sentence_embedding} tensor per vector.
 */
public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<List<Float>>> {

    /**
     * Checks that the input is a {@link List} and, when it is not empty, that its first
     * element is a non-empty {@link List} starting with a {@link Number}.
     * An empty outer list is accepted; elements after the first are not inspected.
     *
     * @param input raw model output to check
     * @throws IllegalArgumentException if any of the checks above fails
     */
    @Override
    public void validate(Object input) {
        if (!(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input is not a List.");
        }

        List<?> embeddings = (List<?>) input;
        if (embeddings.isEmpty()) {
            return;
        }

        Object first = embeddings.get(0);
        if (!(first instanceof List)) {
            throw new IllegalArgumentException("The embedding should be a non-empty List containing List of Float values.");
        }

        List<?> firstEmbedding = (List<?>) first;
        boolean startsWithNumber = !firstEmbedding.isEmpty() && firstEmbedding.get(0) instanceof Number;
        if (!startsWithNumber) {
            throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
        }
    }

    /**
     * Builds one FLOAT32 tensor per embedding, in input order.
     *
     * @param embeddings the embedding vectors; each vector's size becomes its tensor shape
     * @return a new mutable list with one {@code sentence_embedding} tensor per vector
     */
    @Override
    public List<ModelTensor> process(List<List<Float>> embeddings) {
        List<ModelTensor> tensors = new ArrayList<>();
        for (List<Float> vector : embeddings) {
            ModelTensor tensor = ModelTensor
                .builder()
                .name("sentence_embedding")
                .dataType(MLResultDataType.FLOAT32)
                .shape(new long[]{vector.size()})
                .data(vector.toArray(new Number[0]))
                .build();
            tensors.add(tensor);
        }
        return tensors;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;


/**
 * Pre-process function that builds a remote-inference input whose parameters carry the
 * first processed text document under the {@code inputText} key.
 */
public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

    /**
     * Creates the function with {@code returnDirectlyForRemoteInferenceInput} switched on.
     * NOTE(review): presumably this lets the base class pass an existing remote-inference
     * input through untouched — confirm in ConnectorPreProcessFunction.
     */
    public BedrockEmbeddingPreProcessFunction() {
        returnDirectlyForRemoteInferenceInput = true;
    }

    /**
     * Delegates to the inherited text-docs input validation.
     *
     * @param mlInput the input to check
     */
    @Override
    public void validate(MLInput mlInput) {
        validateTextDocsInput(mlInput);
    }

    /**
     * Wraps the first processed document as {@code {"parameters": {"inputText": <doc>}}} and
     * converts that map into the parameters of a {@link RemoteInferenceInputDataSet}.
     * Only the first document is used; any further documents are ignored.
     *
     * @param mlInput input whose dataset is a {@link TextDocsInputDataSet}
     * @return the remote-inference input built from the first document
     */
    @Override
    public RemoteInferenceInputDataSet process(MLInput mlInput) {
        TextDocsInputDataSet textDocs = (TextDocsInputDataSet) mlInput.getInputDataset();
        Map<String, Object> parameters = Map.of("inputText", processTextDocs(textDocs).get(0));
        Map<String, Object> payload = Map.of("parameters", parameters);
        return RemoteInferenceInputDataSet
            .builder()
            .parameters(convertScriptStringToJsonString(payload))
            .build();
    }
}
Loading

0 comments on commit cbaa991

Please sign in to comment.