Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fine tune connector process function #1954

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compileOnly "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"

compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

package org.opensearch.ml.common.connector;

import com.google.common.collect.ImmutableList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -20,58 +20,41 @@ public class MLPostProcessFunction {
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

private static final Map<String, Function<Object, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

static {
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
}

public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
if (embeddings.get(0) instanceof Number) {
embeddings = ImmutableList.of(embeddings);
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{eachEmbedding.size()})
.data(eachEmbedding.toArray(new Number[0]))
.build()
);
});
return modelTensors;
};
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
}

public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
public static Function<Object, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

public static boolean contains(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,48 @@

package org.opensearch.ml.common.connector;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPreProcessFunction {

private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
}

private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
}
public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";

static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
}

public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}

List<?> outerList = (List<?>) input;

if (!outerList.isEmpty() && !(((List<?>)input).get(0) instanceof Number)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
}
}

@Override
public List<ModelTensor> process(List<Float> embedding) {
List<ModelTensor> modelTensors = new ArrayList<>();
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build());
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}
List<?> outerList = (List<?>) input;
if (!outerList.isEmpty()) {
if (!(outerList.get(0) instanceof Map)) {
throw new IllegalArgumentException("Post process function input is not a List of Map.");
}
Map innerMap = (Map) outerList.get(0);

if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevance_score")) {
throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
}
}
}

@Override
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
List<ModelTensor> modelTensors = new ArrayList<>();

if (rerankResults.size() > 0) {
Double[] scores = new Double[rerankResults.size()];
for (int i = 0; i < rerankResults.size(); i++) {
Integer index = (Integer) rerankResults.get(i).get("index");
scores[index] = (Double) rerankResults.get(i).get("relevance_score");
}

for (int i = 0; i < scores.length; i++) {
modelTensors.add(ModelTensor.builder()
.name("similarity")
.shape(new long[]{1})
.data(new Number[]{scores[i]})
.dataType(MLResultDataType.FLOAT32)
.build());
}
}
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.List;
import java.util.function.Function;

public abstract class ConnectorPostProcessFunction<T> implements Function<Object, List<ModelTensor>> {

@Override
public List<ModelTensor> apply(Object input) {
if (input == null) {
throw new IllegalArgumentException("Can't run post process function as model output is null");
}
validate(input);
return process((T)input);
}

public abstract void validate(Object input);

public abstract List<ModelTensor> process(T input);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<List<Float>>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}

List<?> outerList = (List<?>) input;

if (!outerList.isEmpty()) {
if (!(outerList.get(0) instanceof List)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing List of Float values.");
}
List<?> innerList = (List<?>) outerList.get(0);

if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
}
}
}

@Override
public List<ModelTensor> process(List<List<Float>> embeddings) {
List<ModelTensor> modelTensors = new ArrayList<>();
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;


public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public BedrockEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Loading
Loading