Skip to content

Commit

Permalink
[Backport 2.16] Fix custom prompt substitute with List issue in ml in…
Browse files Browse the repository at this point in the history
…ference search response processor (#2887)

* match any placeholder starts from parameters and end with toString()

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

* fix custom prompt issues

Signed-off-by: Mingshi Liu <mingshl@amazon.com>

---------

Signed-off-by: Mingshi Liu <mingshl@amazon.com>
  • Loading branch information
mingshl authored Sep 6, 2024
1 parent 7915fc3 commit 302e070
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -308,20 +309,21 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}

@Override
public <T> T createPayload(String action, Map<String, String> parameters) {
public <T> T createPayload(String action, Map<String, String> parameters) {
Optional<ConnectorAction> connectorAction = findAction(action);
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
parseParameters(parameters);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
return (T) parameters.get("http_body");

}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;

@Log4j2
public class StringUtils {
Expand All @@ -56,6 +54,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -239,4 +238,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix);
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@

package org.opensearch.ml.common.utils;

import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.stream.Collectors.toList;
import static org.junit.Assert.assertEquals;
import org.apache.commons.text.StringSubstitutor;
import org.junit.Assert;
import org.junit.Test;

public class StringUtilsTest {

Expand Down Expand Up @@ -218,4 +223,203 @@ public void testGetErrorMessageWhenHiddenNull() {
// Assert
assertEquals(expected, result);
}

/**
* Tests the collectToStringPrefixes method with a map containing toString() method calls
* in the values. Verifies that the method correctly extracts the prefixes of the toString()
* method calls.
*/
@Test
public void testGetToStringPrefix() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
parameters.put("context", "${parameters.text.toString()}");

List<String> prefixes = collectToStringPrefixes(parameters);
List<String> expectPrefixes = new ArrayList<>();
expectPrefixes.add("text");
expectPrefixes.add("context");
expectPrefixes.add("history");
assertEquals(prefixes, expectPrefixes);
}

/**
* Tests the parseParameters method with a map containing a list of strings as the value
* for the "context" key. Verifies that the method correctly processes the list and adds
* the processed value to the map with the expected key. Also tests the string substitution
* using the processed values.
*/
@Test
public void testParseParametersListToString() {
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\"]\"}");
}

/**
* Tests the parseParameters method with a map containing a list of strings as the value
* for the "context" key, and the "prompt" value containing escaped characters. Verifies
* that the method correctly processes the list and adds the processed value to the map
* with the expected key. Also tests the string substitution using the processed values.
*/
@Test
public void testParseParametersListToStringWithEscapedPrompt() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}"
);
}

/**
* Tests the parseParameters method with a map containing a list of strings as the value
* for the "context" key, and the "prompt" value containing escaped characters. Verifies
* that the method correctly processes the list and adds the processed value to the map
* with the expected key. Also tests the string substitution using the processed values.
*/
@Test
public void testParseParametersListToStringModelConfig() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.model_config.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("model_config.context", toJson(listOfDocuments));

parseParameters(parameters);
assertEquals(parameters.get("model_config.context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}"
);
}

/**
* Tests the parseParameters method with a map containing a nested list of strings as the
* value for the "context" key. Verifies that the method correctly processes the nested
* list and adds the processed value to the map with the expected key. Also tests the
* string substitution using the processed values.
*/
@Test
public void testParseParametersNestedListToString() {
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]\"}"
);
}

/**
* Tests the parseParameters method with a map containing a map of strings as the value
* for the "context" key. Verifies that the method correctly processes the map and adds
* the processed value to the map with the expected key. Also tests the string substitution
* using the processed values.
*/
@Test
public void testParseParametersMapToString() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
parameters.put("history", "hello\n");
parseParameters(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"name\\\":\\\"John\\\"}");
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: {\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}"
);
}

/**
* Tests the parseParameters method with a map containing a nested map of strings as the
* value for the "context" key. Verifies that the method correctly processes the nested
* map and adds the processed value to the map with the expected key. Also tests the
* string substitution using the processed values.
*/
@Test
public void testParseParametersNestedMapToString() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
parameters.put("history", "hello\n");
parseParameters(parameters);
assertEquals(
parameters.get("context" + TO_STRING_FUNCTION_NAME),
"{\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"}"
);
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ private void processPredictionsManyToOne(
}
}
}

modelParameters = StringUtils.getParameterMap(modelInputParameters);
Map<String, String> modelParametersInString = StringUtils.getParameterMap(modelInputParameters);
modelParameters.putAll(modelParametersInString);

Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
inputMapKeys.removeAll(modelConfigs.keySet());
Expand Down

0 comments on commit 302e070

Please sign in to comment.