From 8519cc4737c8ed2c15e931e67108a7c9c73764b7 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 29 Jan 2024 14:53:12 -0800 Subject: [PATCH 1/3] fine tune connector process function Signed-off-by: Yaliang Wu --- .../ml/common/connector/HttpConnector.java | 2 +- .../connector/MLPostProcessFunction.java | 55 +++----- .../connector/MLPreProcessFunction.java | 43 +++--- .../BedrockEmbeddingPostProcessFunction.java | 42 ++++++ .../CohereRerankPostProcessFunction.java | 57 ++++++++ .../ConnectorPostProcessFunction.java | 27 ++++ .../EmbeddingPostProcessFunction.java | 50 +++++++ .../BedrockEmbeddingPreProcessFunction.java | 34 +++++ .../CohereEmbeddingPreProcessFunction.java | 34 +++++ .../CohereRerankPreProcessFunction.java | 40 ++++++ .../ConnectorPreProcessFunction.java | 58 ++++++++ .../preprocess/DefaultPreProcessFunction.java | 72 ++++++++++ .../OpenAIEmbeddingPreProcessFunction.java | 34 +++++ .../RemoteInferencePreProcessFunction.java | 62 +++++++++ .../opensearch/ml/common/input/MLInput.java | 18 ++- .../ml/common/utils/StringUtils.java | 23 +++ .../connector/MLPostProcessFunctionTest.java | 6 +- .../ml/common/utils/StringUtilsTest.java | 6 +- .../algorithms/remote/ConnectorUtils.java | 131 +++++++++--------- .../remote/RemoteConnectorExecutor.java | 7 +- .../ml/engine/utils/ScriptUtils.java | 6 - .../algorithms/remote/ConnectorUtilsTest.java | 8 +- .../algorithms/remote/RemoteModelTest.java | 2 +- .../ml/engine/utils/ScriptUtilsTest.java | 8 +- 24 files changed, 677 insertions(+), 148 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ef0e4bf4a1..d5c148f5e1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -291,7 +291,7 @@ public T createPredictPayload(Map parameters) { payload = substitutor.replace(payload); if (!isJson(payload)) { - throw new IllegalArgumentException("Invalid JSON in payload"); + throw new IllegalArgumentException("Invalid payload: " + payload); } return (T) payload; } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index f0b51233fa..4fb3f75412 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.connector; -import com.google.common.collect.ImmutableList; -import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; import org.opensearch.ml.common.output.model.ModelTensor; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -20,58 +20,41 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; + public static final String COHERE_RERANK = "connector.post_process.cohere.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; + public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); - private static final Map, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); - + private static final Map>> POST_PROCESS_FUNCTIONS = new HashMap<>(); static { + EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction(); + BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction(); + CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); - POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList()); - } - - public static Function, List> buildModelTensorList() { - return embeddings -> { - List modelTensors = new ArrayList<>(); - if (embeddings == null) { - throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); - } - if (embeddings.get(0) instanceof Number) { - embeddings = ImmutableList.of(embeddings); - } - embeddings.forEach(embedding -> { - List eachEmbedding = (List) embedding; - modelTensors.add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{eachEmbedding.size()}) - .data(eachEmbedding.toArray(new Number[0])) - .build() - ); - }); - return modelTensors; - }; + JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); + JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); + POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); } public static String getResponseFilter(String postProcessFunction) { return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static Function, List> get(String postProcessFunction) { + public static Function> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } public static boolean contains(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); } -} +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 4021769806..d2d65ebdfd 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -5,43 +5,48 @@ package org.opensearch.ml.common.connector; +import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.function.Function; public class MLPreProcessFunction { - private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; + public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; + public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; - private static Function, Map> cohereTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("texts", inputs)); - } - - private static Function, Map> openAiTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("input", inputs)); - } - - private static Function, Map> bedrockTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0))); - } + public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input"; + public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string"; static { - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess()); + CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction(); + OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); + BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); + CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static Function, Map> get(String preProcessFunction) { - return PRE_PROCESS_FUNCTIONS.get(preProcessFunction); + public static Function get(String postProcessFunction) { + return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java new file mode 100644 index 0000000000..eb55253c01 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; + +public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + + List outerList = (List) input; + + if (!outerList.isEmpty() && !(((List)input).get(0) instanceof Number)) { + throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); + } + } + + @Override + public List process(List embedding) { + List modelTensors = new ArrayList<>(); + modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build()); + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java new file mode 100644 index 0000000000..216fcc9d0a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction>> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + List outerList = (List) input; + if (!outerList.isEmpty()) { + if (!(outerList.get(0) instanceof Map)) { + throw new IllegalArgumentException("Post process function input is not a List of Map."); + } + Map innerMap = (Map) outerList.get(0); + + if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevance_score")) { + throw new IllegalArgumentException("The rerank result should contain index and relevance_score."); + } + } + } + + @Override + public List process(List> rerankResults) { + List modelTensors = new ArrayList<>(); + + if (rerankResults.size() > 0) { + Double[] scores = new Double[rerankResults.size()]; + for (int i = 0; i < rerankResults.size(); i++) { + Integer index = (Integer) rerankResults.get(i).get("index"); + scores[index] = (Double) rerankResults.get(i).get("relevance_score"); + } + + for (int i = 0; i < scores.length; i++) { + modelTensors.add(ModelTensor.builder() + .name("similarity") + .shape(new long[]{1}) + .data(new Number[]{scores[i]}) + .dataType(MLResultDataType.FLOAT32) + .build()); + } + } + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java new file mode 100644 index 0000000000..9cb81099c4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.List; +import java.util.function.Function; + +public abstract class ConnectorPostProcessFunction implements Function> { + + @Override + public List apply(Object input) { + if (input == null) { + throw new IllegalArgumentException("Can't run post process function as model output is null"); + } + validate(input); + return process((T)input); + } + + public abstract void validate(Object input); + + public abstract List process(T input); +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java new file mode 100644 index 0000000000..b03c791295 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; + +public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction>> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + + List outerList = (List) input; + + if (!outerList.isEmpty()) { + if (!(outerList.get(0) instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List of List."); + } + List innerList = (List) outerList.get(0); + + if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) { + throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); + } + } + } + + @Override + public List process(List> embeddings) { + List modelTensors = new ArrayList<>(); + embeddings.forEach(embedding -> modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build() + )); + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..dae61b6c6c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public BedrockEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..d82210f4a3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class CohereEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public CohereEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java new file mode 100644 index 0000000000..c975f7f329 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class CohereRerankPreProcessFunction extends ConnectorPreProcessFunction { + + public CohereRerankPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of( + "query", inputData.getQueryText(), + "documents", inputData.getTextDocs(), + "top_n", inputData.getTextDocs().size() + )); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java new file mode 100644 index 0000000000..72ca6ce112 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +public abstract class ConnectorPreProcessFunction implements Function { + + protected boolean returnDirectlyForRemoteInferenceInput; + + @Override + public RemoteInferenceInputDataSet apply(MLInput mlInput) { + if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet)mlInput.getInputDataset(); + } else { + validate(mlInput); + return process(mlInput); + } + } + + public abstract void validate(MLInput mlInput); + + public abstract RemoteInferenceInputDataSet process(MLInput mlInput); + + List processTextDocs(TextDocsInputDataSet inputDataSet) { + List docs = new ArrayList<>(); + for (String doc : inputDataSet.getDocs()) { + if (doc != null) { + String gsonString = gson.toJson(doc); + // in 2.9, user will add " before and after string + // gson.toString(string) will add extra " before after string, so need to remove + docs.add(gsonString.substring(1, gsonString.length() - 1)); + } else { + docs.add(null); + } + } + return docs; + } + + public void validateTextDocsInput(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java new file mode 100644 index 0000000000..6f128fdd51 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class DefaultPreProcessFunction extends ConnectorPreProcessFunction { + + ScriptService scriptService; + String preProcessFunction; + boolean convertInputToJsonString; + + @Builder + public DefaultPreProcessFunction(ScriptService scriptService, String preProcessFunction, boolean convertInputToJsonString) { + this.returnDirectlyForRemoteInferenceInput = false; + this.scriptService = scriptService; + this.preProcessFunction = preProcessFunction; + this.convertInputToJsonString = convertInputToJsonString; + } + + @Override + public void validate(MLInput mlInput) { + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + mlInput.toXContent(builder, EMPTY_PARAMS); + String inputStr = builder.toString(); + Map inputParams = gson.fromJson(inputStr, Map.class); + if (convertInputToJsonString) { + inputParams = convertScriptStringToJsonString(Map.of("parameters", gson.fromJson(inputStr, Map.class))); + } + String processedInput = executeScript(scriptService, preProcessFunction, inputParams); + if (processedInput == null) { + throw new IllegalArgumentException("Pre-process function output is null"); + } + Map map = gson.fromJson(processedInput, Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } catch (IOException e) { + throw new IllegalArgumentException("Failed to run pre-process function: Wrong input"); + } + } + + private String executeScript(ScriptService scriptService, String painlessScript, Map params) { + Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); + return templateScript.execute(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..32f294fdcc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class OpenAIEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public OpenAIEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java new file mode 100644 index 0000000000..73cf91bee7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { + + ScriptService scriptService; + String preProcessFunction; + + @Builder + public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) { + this.returnDirectlyForRemoteInferenceInput = false; + this.scriptService = scriptService; + this.preProcessFunction = preProcessFunction; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support RemoteInferenceInputDataSet"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + Map inputParams = new HashMap<>(); + inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters()); + String processedInput = executeScript(scriptService, preProcessFunction, inputParams); + if (processedInput == null) { + throw new IllegalArgumentException("Input is null after processed by preprocess function"); + } + Map map = gson.fromJson(processedInput, Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } + + String executeScript(ScriptService scriptService, String painlessScript, Map params) { + Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); + return templateScript.execute(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index acd1522736..f2d74bf8c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -17,6 +17,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; @@ -30,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -59,6 +61,7 @@ public class MLInput implements Input { public static final String TEXT_DOCS_FIELD = "text_docs"; // Input query text to compare against for text similarity model public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String PARAMETERS_FIELD = "parameters"; // Algorithm name protected FunctionName algorithm; @@ -163,18 +166,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } break; case TEXT_SIMILARITY: - TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset; - List tdocs = ds.getTextDocs(); - String queryText = ds.getQueryText(); + TextSimilarityInputDataSet inputDataSet = (TextSimilarityInputDataSet) this.inputDataset; + List documents = inputDataSet.getTextDocs(); + String queryText = inputDataSet.getQueryText(); builder.field(QUERY_TEXT_FIELD, queryText); - if (tdocs != null && !tdocs.isEmpty()) { + if (documents != null && !documents.isEmpty()) { builder.startArray(TEXT_DOCS_FIELD); - for(String d : tdocs) { + for(String d : documents) { builder.value(d); } builder.endArray(); } break; + case REMOTE: + RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset; + Map parameters = remoteInferenceInputDataSet.getParameters(); + builder.field(PARAMETERS_FIELD, parameters); + break; default: break; } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 43aa3c76ae..f66d1e58c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -8,6 +8,7 @@ import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; +import lombok.extern.log4j.Log4j2; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; @@ -21,6 +22,7 @@ import java.util.List; import java.util.Map; +@Log4j2 public class StringUtils { public static final Gson gson; @@ -97,4 +99,25 @@ public static String toJson(Object value) { throw new RuntimeException(e); } } + + public static Map convertScriptStringToJsonString(Map processedInput) { + Map parameterStringMap = new HashMap<>(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map parametersMap = (Map) processedInput.get("parameters"); + for (String key : parametersMap.keySet()) { + if (parametersMap.get(key) instanceof String) { + parameterStringMap.put(key, (String) parametersMap.get(key)); + } else { + parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); + } + } + return null; + }); + } catch (PrivilegedActionException e) { + log.error("Error processing parameters", e); + throw new RuntimeException(e); + } + return parameterStringMap; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index c004c93a31..1c60ee8b16 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING; import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING; import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; public class MLPostProcessFunctionTest { @@ -43,15 +44,14 @@ public void test_getResponseFilter() { @Test public void test_buildModelTensorList() { - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); List> numbersList = new ArrayList<>(); numbersList.add(Collections.singletonList(1.0f)); - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + Assert.assertNotNull(MLPostProcessFunction.get(DEFAULT_EMBEDDING).apply(numbersList)); } @Test public void test_buildModelTensorList_exception() { exceptionRule.expect(IllegalArgumentException.class); - MLPostProcessFunction.buildModelTensorList().apply(null); + MLPostProcessFunction.get(DEFAULT_EMBEDDING).apply(null); } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index a4b34d75b5..3022c97e0a 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.utils; import org.junit.Assert; @@ -87,7 +92,6 @@ public void getParameterMap() { parameters.put("key4", new int[]{10, 20}); parameters.put("key5", new Object[]{1.01, "abc"}); Map parameterMap = StringUtils.getParameterMap(parameters); - System.out.println(parameterMap); Assert.assertEquals(5, parameterMap.size()); Assert.assertEquals("value1", parameterMap.get("key1")); Assert.assertEquals("2", parameterMap.get("key2")); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 6a63d68ff5..c3e385ca4e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -7,20 +7,18 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; -import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -28,6 +26,8 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -63,14 +63,11 @@ public static RemoteInferenceInputDataSet processInput( if (mlInput == null) { throw new IllegalArgumentException("Input is null"); } - RemoteInferenceInputDataSet inputData; - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); - } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { - inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); - } else { - throw new IllegalArgumentException("Wrong input type"); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); } + RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService); if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().forEach((key, value) -> { @@ -88,65 +85,56 @@ public static RemoteInferenceInputDataSet processInput( return inputData; } - private static RemoteInferenceInputDataSet processTextDocsInput( - TextDocsInputDataSet inputDataSet, + private static RemoteInferenceInputDataSet processMLInput( + MLInput mlInput, Connector connector, Map parameters, ScriptService scriptService ) { - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); - } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; - if (MLPreProcessFunction.contains(preProcessFunction)) { - Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs()); - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); + String preProcessFunction = getPreprocessFunction(mlInput, connector); + if (preProcessFunction == null) { + if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + } else { + throw new IllegalArgumentException("pre_process_function not defined in connector"); + } } else { - List docs = new ArrayList<>(); - for (String doc : inputDataSet.getDocs()) { - if (doc != null) { - String gsonString = gson.toJson(doc); - // in 2.9, user will add " before and after string - // gson.toString(string) will add extra " before after string, so need to remove - docs.add(gsonString.substring(1, gsonString.length() - 1)); + preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction); + if (MLPreProcessFunction.contains(preProcessFunction)) { + Function function = MLPreProcessFunction.get(preProcessFunction); + return function.apply(mlInput); + } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) + && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) { + RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + return function.apply(mlInput); } else { - docs.add(null); + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); } + } else { + boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING) + && Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING)); + DefaultPreProcessFunction function = DefaultPreProcessFunction + .builder() + .scriptService(scriptService) + .preProcessFunction(preProcessFunction) + .convertInputToJsonString(convertInputToJsonString) + .build(); + return function.apply(mlInput); } - if (preProcessFunction.contains("${parameters.")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - preProcessFunction = substitutor.replace(preProcessFunction); - } - Optional processedInput = executePreprocessFunction(scriptService, preProcessFunction, docs); - if (processedInput.isEmpty()) { - throw new IllegalArgumentException("Wrong input"); - } - Map map = gson.fromJson(processedInput.get(), Map.class); - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); } } - private static Map convertScriptStringToJsonString(Map processedInput) { - Map parameterStringMap = new HashMap<>(); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - Map parametersMap = (Map) processedInput.get("parameters"); - for (String key : parametersMap.keySet()) { - if (parametersMap.get(key) instanceof String) { - parameterStringMap.put(key, (String) parametersMap.get(key)); - } else { - parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); - } - } - return null; - }); - } catch (PrivilegedActionException e) { - log.error("Error processing parameters", e); - throw new RuntimeException(e); + private static String getPreprocessFunction(MLInput mlInput, Connector connector) { + Optional predictAction = connector.findPredictAction(); + String preProcessFunction = predictAction.get().getPreProcessFunction(); + if (preProcessFunction != null) { + return preProcessFunction; + } + if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { + return MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT; } - return parameterStringMap; + return null; } public static ModelTensors processOutput( @@ -165,21 +153,16 @@ public static ModelTensors processOutput( } ConnectorAction connectorAction = predictAction.get(); String postProcessFunction = connectorAction.getPostProcessFunction(); - if (postProcessFunction != null && postProcessFunction.contains("${parameters")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - postProcessFunction = substitutor.replace(postProcessFunction); - } + postProcessFunction = fillProcessFunctionParameter(parameters, postProcessFunction); String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); if (MLPostProcessFunction.contains(postProcessFunction)) { // in this case, we can use jsonpath to build a List> result from model response. if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); - List vectors = JsonPath.read(modelResponse, responseFilter); - List processedResponse = executeBuildInPostProcessFunction( - vectors, - MLPostProcessFunction.get(postProcessFunction) - ); + + Object filteredOutput = JsonPath.read(modelResponse, responseFilter); + List processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput); return ModelTensors.builder().mlModelTensors(processedResponse).build(); } @@ -198,6 +181,18 @@ public static ModelTensors processOutput( return ModelTensors.builder().mlModelTensors(modelTensors).build(); } + private static String fillProcessFunctionParameter(Map parameters, String processFunction) { + if (processFunction != null && processFunction.contains("${parameters.")) { + Map tmpParameters = new HashMap<>(); + for (String key : parameters.keySet()) { + tmpParameters.put(key, gson.toJson(parameters.get(key))); + } + StringSubstitutor substitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); + processFunction = substitutor.replace(processFunction); + } + return processFunction; + } + public static SdkHttpFullRequest signRequest( SdkHttpFullRequest request, String accessKey, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index be50af3aff..4f46c67906 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -106,14 +106,17 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List inputParameters = new HashMap<>(); if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) { - parameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); + inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); } - + parameters.putAll(inputParameters); RemoteInferenceInputDataSet inputData = processInput(mlInput, connector, parameters, getScriptService()); if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); } + // override again to always prioritize the input parameter + parameters.putAll(inputParameters); String payload = connector.createPredictPayload(parameters); connector.validatePayload(payload); String userStr = getClient() diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index cc721e9129..46d7794c6c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -9,9 +9,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.Script; import org.opensearch.script.ScriptService; @@ -30,10 +28,6 @@ public static Optional executePreprocessFunction( return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static List executeBuildInPostProcessFunction(List vectors, Function, List> function) { - return function.apply(vectors); - } - public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 8ad745340b..178d9aa722 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -190,12 +190,12 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .parameters(parameters) .actions(Arrays.asList(predictAction)) .build(); - ModelTensors tensors = ConnectorUtils - .processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of()); + String modelResponse = + "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); + Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index ea3d883ddd..149848327b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -69,7 +69,7 @@ public void predict_NullConnectorExecutor() { @Test public void predict_ModelDeployed_WrongInput() { exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("Wrong input type"); + exceptionRule.expectMessage("pre_process_function not defined in connector"); Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java index bea44ebf48..b9faeafafb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.engine.utils; import static org.junit.Assert.assertEquals; @@ -40,8 +45,7 @@ public void test_executePreprocessFunction() { @Test public void test_executeBuildInPostProcessFunction() { List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); - List modelTensors = ScriptUtils - .executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); + List modelTensors = MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING).apply(input); assertNotNull(modelTensors); assertEquals(2, modelTensors.size()); } From 664d7159de79616e28b8617e8b59b10d1d1ab4c2 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 29 Jan 2024 17:25:10 -0800 Subject: [PATCH 2/3] add unit test for process function Signed-off-by: Yaliang Wu --- common/build.gradle | 1 + .../EmbeddingPostProcessFunction.java | 2 +- .../ConnectorPreProcessFunction.java | 3 + .../preprocess/DefaultPreProcessFunction.java | 2 +- .../RemoteInferencePreProcessFunction.java | 2 +- ...drockEmbeddingPostProcessFunctionTest.java | 48 +++++++++++ .../CohereRerankPostProcessFunctionTest.java | 60 ++++++++++++++ .../EmbeddingPostProcessFunctionTest.java | 59 +++++++++++++ ...edrockEmbeddingPreProcessFunctionTest.java | 55 ++++++++++++ ...CohereEmbeddingPreProcessFunctionTest.java | 55 ++++++++++++ .../CohereRerankPreProcessFunctionTest.java | 57 +++++++++++++ .../DefaultPreProcessFunctionTest.java | 83 +++++++++++++++++++ ...OpenAIEmbeddingPreProcessFunctionTest.java | 55 ++++++++++++ ...RemoteInferencePreProcessFunctionTest.java | 82 ++++++++++++++++++ 14 files changed, 561 insertions(+), 3 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java diff --git a/common/build.gradle b/common/build.gradle index 03de5526ac..619f8dffe5 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -20,6 +20,7 @@ dependencies { compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}" compileOnly "org.opensearch:common-utils:${common_utils_version}" testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' + testImplementation "org.opensearch.test:framework:${opensearch_version}" compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java index b03c791295..6e6d373302 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java @@ -23,7 +23,7 @@ public void validate(Object input) { if (!outerList.isEmpty()) { if (!(outerList.get(0) instanceof List)) { - throw new IllegalArgumentException("Post process function input is not a List of List."); + throw new IllegalArgumentException("The embedding should be a non-empty List containing List of Float values."); } List innerList = (List) outerList.get(0); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 72ca6ce112..d29c70048e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -23,6 +23,9 @@ public abstract class ConnectorPreProcessFunction implements Function map = gson.fromJson(processedInput, Map.class); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java index 73cf91bee7..a8c549ea3b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -48,7 +48,7 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) { inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters()); String processedInput = executeScript(scriptService, preProcessFunction, inputParams); if (processedInput == null) { - throw new IllegalArgumentException("Input is null after processed by preprocess function"); + throw new IllegalArgumentException("Preprocess function output is null"); } Map map = gson.fromJson(processedInput, Map.class); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java new file mode 100644 index 0000000000..9a97ac0374 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java @@ -0,0 +1,48 @@ +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class BedrockEmbeddingPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BedrockEmbeddingPostProcessFunction function; + + @Before + public void setUp() { + function = new BedrockEmbeddingPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List."); + function.apply("abc"); + } + + @Test + public void process_WrongInput_NotNumberList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The embedding should be a non-empty List containing Float values."); + function.apply(Arrays.asList("abc")); + } + + @Test + public void process_CorrectInput() { + List result = function.apply(Arrays.asList(1.1, 1.2, 1.3)); + assertEquals(1, result.size()); + assertEquals(3, result.get(0).getData().length); + assertEquals(1.1, result.get(0).getData()[0]); + assertEquals(1.2, result.get(0).getData()[1]); + assertEquals(1.3, result.get(0).getData()[2]); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java new file mode 100644 index 0000000000..8a17188169 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java @@ -0,0 +1,60 @@ +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class CohereRerankPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CohereRerankPostProcessFunction function; + + @Before + public void setUp() { + function = new CohereRerankPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List."); + function.apply("abc"); + } + + @Test + public void process_WrongInput_NotCorrectList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List of Map."); + function.apply(Arrays.asList("abc")); + } + + @Test + public void process_WrongInput_NotCorrectMap() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The rerank result should contain index and relevance_score."); + function.apply(Arrays.asList(Map.of("test1", "value1"))); + } + + @Test + public void process_CorrectInput() { + List> rerankResults = List.of( + Map.of("index", 2, "relevance_score", 0.5), + Map.of("index", 1, "relevance_score", 0.4), + Map.of("index", 0, "relevance_score", 0.3)); + List result = function.apply(rerankResults); + assertEquals(3, result.size()); + assertEquals(1, result.get(0).getData().length); + assertEquals(0.3, result.get(0).getData()[0]); + assertEquals(0.4, result.get(1).getData()[0]); + assertEquals(0.5, result.get(2).getData()[0]); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java new file mode 100644 index 0000000000..bc3403eb1b --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java @@ -0,0 +1,59 @@ +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class EmbeddingPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + EmbeddingPostProcessFunction function; + + @Before + public void setUp() { + function = new EmbeddingPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List."); + function.apply("abc"); + } + + @Test + public void process_WrongInput_NotListOfList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The embedding should be a non-empty List containing List of Float values."); + function.apply(Arrays.asList("abc")); + } + + @Test + public void process_WrongInput_NotListOfNumber() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The embedding should be a non-empty List containing Float values."); + function.apply(List.of(Arrays.asList("abc"))); + } + + @Test + public void process_CorrectInput() { + List result = function.apply(List.of(List.of(1.1, 1.2, 1.3), List.of(2.1, 2.2, 2.3))); + assertEquals(2, result.size()); + assertEquals(3, result.get(0).getData().length); + assertEquals(1.1, result.get(0).getData()[0]); + assertEquals(1.2, result.get(0).getData()[1]); + assertEquals(1.3, result.get(0).getData()[2]); + assertEquals(3, result.get(1).getData().length); + assertEquals(2.1, result.get(1).getData()[0]); + assertEquals(2.2, result.get(1).getData()[1]); + assertEquals(2.3, result.get(1).getData()[2]); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..40d3e0effe --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,55 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +public class BedrockEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BedrockEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + function = new BedrockEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("hello", dataSet.getParameters().get("inputText")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..2782681431 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,55 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +public class CohereEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CohereEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + function = new CohereEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("[\"hello\",\"world\"]", dataSet.getParameters().get("texts")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java new file mode 100644 index 0000000000..5fd7edb015 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java @@ -0,0 +1,57 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +public class CohereRerankPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + CohereRerankPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + function = new CohereRerankPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextSimilarityInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(3, dataSet.getParameters().size()); + assertEquals("test", dataSet.getParameters().get("query")); + assertEquals("[\"hello\"]", dataSet.getParameters().get("documents")); + assertEquals("1", dataSet.getParameters().get("top_n")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java new file mode 100644 index 0000000000..2b50827225 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java @@ -0,0 +1,83 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.ScriptService; + +import java.util.Arrays; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class DefaultPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + DefaultPreProcessFunction functionWithConvertToJsonString; + DefaultPreProcessFunction functionWithoutConvertToJsonString; + + @Mock + ScriptService scriptService; + + String preProcessFunction; + + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + preProcessFunction = ""; + functionWithConvertToJsonString = new DefaultPreProcessFunction(scriptService, preProcessFunction, true); + functionWithoutConvertToJsonString = new DefaultPreProcessFunction(scriptService, preProcessFunction, false); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + functionWithConvertToJsonString.apply(null); + } + + @Test + public void process_CorrectInput_WrongProcessedResult() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function output is null"); + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + functionWithConvertToJsonString.apply(mlInput); + } + + @Test + public void process_CorrectInput_WrongProcessedResult_WithoutConvertToJsonString() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function output is null"); + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + functionWithoutConvertToJsonString.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + String preprocessResult = "{\"parameters\": { \"input\": \"test doc1\" } }"; + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = functionWithConvertToJsonString.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("test doc1", dataSet.getParameters().get("input")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java new file mode 100644 index 0000000000..6085e9e644 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java @@ -0,0 +1,55 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +public class OpenAIEmbeddingPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + OpenAIEmbeddingPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + function = new OpenAIEmbeddingPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("[\"hello\",\"world\"]", dataSet.getParameters().get("input")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java new file mode 100644 index 0000000000..fa7f116030 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java @@ -0,0 +1,82 @@ +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.ScriptService; + +import java.util.Arrays; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class RemoteInferencePreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + RemoteInferencePreProcessFunction function; + + @Mock + ScriptService scriptService; + + String preProcessFunction; + + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + preProcessFunction = ""; + function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support RemoteInferenceInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput_WrongProcessedResult() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function output is null"); + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + String preprocessResult = "{\"parameters\": { \"input\": \"test doc1\" } }"; + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(1, dataSet.getParameters().size()); + assertEquals("test doc1", dataSet.getParameters().get("input")); + } +} From 8746b23bd6350ef9b205f6c10b2e21c90f5941f4 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 30 Jan 2024 03:44:34 -0800 Subject: [PATCH 3/3] add license header Signed-off-by: Yaliang Wu --- .../common/connector/HttpConnectorTest.java | 16 +++++++++++- ...drockEmbeddingPostProcessFunctionTest.java | 5 ++++ .../CohereRerankPostProcessFunctionTest.java | 5 ++++ .../EmbeddingPostProcessFunctionTest.java | 5 ++++ ...edrockEmbeddingPreProcessFunctionTest.java | 25 +++++++++++++++++-- ...CohereEmbeddingPreProcessFunctionTest.java | 5 ++++ .../CohereRerankPreProcessFunctionTest.java | 5 ++++ .../DefaultPreProcessFunctionTest.java | 5 ++++ ...OpenAIEmbeddingPreProcessFunctionTest.java | 5 ++++ ...RemoteInferencePreProcessFunctionTest.java | 5 ++++ 10 files changed, 78 insertions(+), 3 deletions(-) diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 8e51a06c38..f6bbebf8a5 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -173,6 +173,16 @@ public void createPredictPayload_Invalid() { connector.validatePayload(predictPayload); } + @Test + public void createPredictPayload_InvalidJson() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Invalid payload: {\"input\": ${parameters.input} }"); + String requestBody = "{\"input\": ${parameters.input} }"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + String predictPayload = connector.createPredictPayload(null); + connector.validatePayload(predictPayload); + } + @Test public void createPredictPayload() { HttpConnector connector = createHttpConnector(); @@ -268,12 +278,16 @@ public void fillNullParameters() { } public static HttpConnector createHttpConnector() { + String requestBody = "{\"input\": \"${parameters.input}\"}"; + return createHttpConnectorWithRequestBody(requestBody); + } + + public static HttpConnector createHttpConnectorWithRequestBody(String requestBody) { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("api_key", "${credential.key}"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java index 9a97ac0374..224e807031 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.postprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java index 8a17188169..5e8cfd4319 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.postprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java index bc3403eb1b..01240759ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.postprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index 40d3e0effe..eb50befdf9 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before; @@ -11,6 +16,7 @@ import org.opensearch.ml.common.input.MLInput; import java.util.Arrays; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -22,12 +28,22 @@ public class BedrockEmbeddingPreProcessFunctionTest { TextSimilarityInputDataSet textSimilarityInputDataSet; TextDocsInputDataSet textDocsInputDataSet; + RemoteInferenceInputDataSet remoteInferenceInputDataSet; + + MLInput textEmbeddingInput; + MLInput textSimilarityInput; + MLInput remoteInferenceInput; @Before public void setUp() { function = new BedrockEmbeddingPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); + + textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); } @Test @@ -41,8 +57,7 @@ public void process_NullInput() { public void process_WrongInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); - function.apply(mlInput); + function.apply(textSimilarityInput); } @Test @@ -52,4 +67,10 @@ public void process_CorrectInput() { assertEquals(1, dataSet.getParameters().size()); assertEquals("hello", dataSet.getParameters().get("inputText")); } + + @Test + public void process_RemoteInferenceInput() { + RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); + assertEquals(remoteInferenceInputDataSet, dataSet); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java index 2782681431..f739796ae8 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java index 5fd7edb015..d8a6f4d311 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java index 2b50827225..93d23b338a 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java index 6085e9e644..e4a08ed550 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java index fa7f116030..14fed71efc 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector.functions.preprocess; import org.junit.Before;