Skip to content

Commit

Permalink
[ML] Bug fix: truncate input text for the inference API (#103027) (#103338)
Browse files Browse the repository at this point in the history
* Refactoring to support truncation

* Adding failing test reminders

* Adding tests

* Tracking truncation in request

* Passing tests

* Adding parser test

* refactoring access

* Fixing bug

* Adding test for checkModelConfig

* addressing pr feedback

* Fixing elser token limit and using array of booleans

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
jonathan-buttner and elasticmachine authored Dec 12, 2023
1 parent dc3d00e commit 013fa94
Show file tree
Hide file tree
Showing 63 changed files with 1,559 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.common.xcontent;

import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
Expand Down Expand Up @@ -178,4 +179,33 @@ public static <T> List<T> parseList(XContentParser parser, CheckedFunction<XCont
} while (parser.nextToken() != Token.END_ARRAY);
return list;
}

/**
 * Index-aware variant of {@link #parseList(XContentParser, CheckedFunction)}: parses a list of a
 * given type from the given {@code parser}, invoking {@code valueParser} with the zero-based
 * position of each element as it is consumed.
 * Assumes that the parser is currently positioned on a {@link Token#START_ARRAY} token and will fail if it is not.
 * The returned list may or may not be mutable.
 *
 * @param parser x-content parser
 * @param valueParser parser for expected list value type; also receives the element's array index
 * @return list parsed from parser
 */
public static <T> List<T> parseList(XContentParser parser, CheckedBiFunction<XContentParser, Integer, T, IOException> valueParser)
    throws IOException {
    XContentParserUtils.ensureExpectedToken(Token.START_ARRAY, parser.currentToken(), parser);

    // Advance to the first element (or the closing bracket for an empty array).
    if (parser.nextToken() == Token.END_ARRAY) {
        return List.of();
    }

    final var values = new ArrayList<T>();
    // The parser is already positioned on an element here, so consume it before advancing.
    for (int i = 0; ; i++) {
        values.add(valueParser.apply(parser, i));
        if (parser.nextToken() == Token.END_ARRAY) {
            break;
        }
    }

    return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
Expand Down Expand Up @@ -285,4 +289,35 @@ public void testParseTypedKeysObjectErrors() throws IOException {
}
}
}

public void testParseListWithIndex_IncrementsIndexBy1ForEachEntryInList() throws IOException {
    String jsonString = """
        ["a", "b", "c"]
        """;

    var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

    // Collects the index passed to the value parser for each element, in order.
    var seenIndices = new ArrayList<Integer>();
    List<String> parsedValues;

    try (
        XContentParser parser = XContentFactory.xContent(XContentType.JSON)
            .createParser(config, jsonString.getBytes(StandardCharsets.UTF_8))
    ) {
        // Position the parser on the first token if it has not been advanced yet.
        if (parser.currentToken() == null) {
            parser.nextToken();
        }

        parsedValues = XContentParserUtils.parseList(parser, (p, index) -> {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, p.currentToken(), p);
            seenIndices.add(index);
            return p.text();
        });
    }

    assertThat(parsedValues, Matchers.is(List.of("a", "b", "c")));
    assertThat(seenIndices, Matchers.is(List.of(0, 1, 2)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
Expand Down Expand Up @@ -114,7 +115,8 @@ public List<RestHandler> getRestHandlers(
@Override
public Collection<?> createComponents(PluginServices services) {
var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings));
var truncator = new Truncator(settings, services.clusterService());
serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));

httpManager.set(HttpClientManager.create(settings, services.threadPool(), services.clusterService(), throttlerManager));

Expand Down Expand Up @@ -211,7 +213,8 @@ public List<Setting<?>> getSettings() {
HttpClientManager.getSettings(),
HttpRequestSenderFactory.HttpRequestSender.getSettings(),
ThrottlerManager.getSettings(),
RetrySettings.getSettingsDefinitions()
RetrySettings.getSettingsDefinitions(),
Truncator.getSettings()
).flatMap(Collection::stream).collect(Collectors.toList());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Provides truncation logic for inference requests
 */
public class Truncator {

    /**
     * Defines the percentage to reduce the input text for an inference request.
     * NOTE(review): the key spells "reducation" (sic) — this looks like a typo for "reduction",
     * but renaming the key would break existing cluster configurations; a rename would need a
     * deprecated fallback setting.
     */
    static final Setting<Double> REDUCTION_PERCENTAGE_SETTING = Setting.doubleSetting(
        "xpack.inference.truncator.reducation_percentage",
        0.5,
        0.01,
        0.99,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    /**
     * OpenAI estimates that there are 4 characters per token
     * <a href="https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them">here</a>.
     * We'll take a conservative approach and assume there's a token every 3 characters.
     */
    private static final double CHARS_PER_TOKEN = 3;

    public static List<Setting<?>> getSettings() {
        return List.of(REDUCTION_PERCENTAGE_SETTING);
    }

    /**
     * Estimates the token count of {@code text} using the conservative characters-per-token heuristic.
     */
    public static double countTokens(String text) {
        return Math.ceil(text.length() / CHARS_PER_TOKEN);
    }

    // Updated dynamically via the cluster settings consumer registered in the constructor.
    private volatile double reductionPercentage;

    public Truncator(Settings settings, ClusterService clusterService) {
        reductionPercentage = REDUCTION_PERCENTAGE_SETTING.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(REDUCTION_PERCENTAGE_SETTING, this::updateReductionPercentage);
    }

    private void updateReductionPercentage(double percentage) {
        this.reductionPercentage = percentage;
    }

    /**
     * Truncate each entry in the list to the specified number of tokens.
     * @param input list of strings
     * @param tokenLimit the number of tokens to limit the text to; {@code null} means no truncation
     * @return the resulting list of text and whether each entry was truncated
     */
    public static TruncationResult truncate(List<String> input, @Nullable Integer tokenLimit) {
        if (tokenLimit == null) {
            // No limit configured: pass everything through untouched (all flags default to false).
            return new TruncationResult(input, new boolean[input.size()]);
        }

        var charLimit = charLimitForTokens(tokenLimit);
        var resultText = new ArrayList<String>(input.size());
        var truncatedFlags = new boolean[input.size()];

        int position = 0;
        for (var text : input) {
            var entry = clampToLength(text, charLimit);
            resultText.add(entry.input());
            truncatedFlags[position++] = entry.truncated();
        }

        return new TruncationResult(resultText, truncatedFlags);
    }

    private static int charLimitForTokens(Integer maxTokens) {
        // Defensive null check; the public entry point already filters out null limits.
        if (maxTokens == null) {
            return Integer.MAX_VALUE;
        }

        return (int) Math.floor(maxTokens * CHARS_PER_TOKEN);
    }

    private static TruncationEntry clampToLength(String text, int charLimit) {
        var kept = text.substring(0, Math.min(text.length(), charLimit));
        return new TruncationEntry(kept, kept.length() < text.length());
    }

    /**
     * Truncate each entry in the list by the percentage value specified in the {@link #REDUCTION_PERCENTAGE_SETTING} setting.
     * @param input list of strings
     * @return the resulting list of text and whether each entry was truncated
     */
    public TruncationResult truncate(List<String> input) {
        var resultText = new ArrayList<String>(input.size());
        var truncatedFlags = new boolean[input.size()];

        int position = 0;
        for (var text : input) {
            var entry = clampByPercentage(text);
            resultText.add(entry.input());
            truncatedFlags[position++] = entry.truncated();
        }

        return new TruncationResult(resultText, truncatedFlags);
    }

    private TruncationEntry clampByPercentage(String text) {
        // Keep reductionPercentage of the original characters (rounded down).
        var keepLength = (int) Math.floor(text.length() * reductionPercentage);
        return clampToLength(text, keepLength);
    }

    // A single text entry together with whether it was shortened.
    private record TruncationEntry(String input, boolean truncated) {}

    /**
     * Holds the (possibly truncated) texts and a parallel flag array marking which entries were shortened.
     * equals/hashCode are overridden because the record contains an array component.
     */
    public record TruncationResult(List<String> input, boolean[] truncated) {

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            // Records are final, so an instanceof check is equivalent to the getClass() comparison
            // and also rejects null.
            if (o instanceof TruncationResult other) {
                return Objects.equals(input, other.input) && Arrays.equals(truncated, other.truncated);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Objects.hash(input, Arrays.hashCode(truncated));
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.huggingface.HuggingFaceAccount;
import org.elasticsearch.xpack.inference.external.request.huggingface.HuggingFaceInferenceRequest;
import org.elasticsearch.xpack.inference.external.request.huggingface.HuggingFaceInferenceRequestEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

Expand All @@ -37,6 +38,8 @@ public class HuggingFaceAction implements ExecutableAction {
private final String errorMessage;
private final RetryingHttpSender sender;
private final ResponseHandler responseHandler;
private final Truncator truncator;
private final Integer tokenLimit;

public HuggingFaceAction(
Sender sender,
Expand All @@ -60,15 +63,19 @@ public HuggingFaceAction(
);
this.account = new HuggingFaceAccount(model.getUri(), model.getApiKey());
this.errorMessage = format("Failed to send Hugging Face %s request to [%s]", requestType, model.getUri().toString());
this.truncator = Objects.requireNonNull(serviceComponents.truncator());
this.tokenLimit = model.getTokenLimit();
}

@Override
public void execute(List<String> input, ActionListener<InferenceServiceResults> listener) {
try {
HuggingFaceInferenceRequest request = new HuggingFaceInferenceRequest(account, new HuggingFaceInferenceRequestEntity(input));
var truncatedInput = truncate(input, tokenLimit);

HuggingFaceInferenceRequest request = new HuggingFaceInferenceRequest(truncator, account, truncatedInput);
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);

sender.send(request.createRequest(), responseHandler, wrappedListener);
sender.send(request, responseHandler, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount;
import org.elasticsearch.xpack.inference.external.openai.OpenAiClient;
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequestEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;

Expand All @@ -25,6 +25,7 @@
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

Expand All @@ -34,6 +35,7 @@ public class OpenAiEmbeddingsAction implements ExecutableAction {
private final OpenAiClient client;
private final OpenAiEmbeddingsModel model;
private final String errorMessage;
private final Truncator truncator;

public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
Expand All @@ -44,6 +46,7 @@ public OpenAiEmbeddingsAction(Sender sender, OpenAiEmbeddingsModel model, Servic
);
this.client = new OpenAiClient(Objects.requireNonNull(sender), Objects.requireNonNull(serviceComponents));
this.errorMessage = getErrorMessage(this.model.getServiceSettings().uri());
this.truncator = Objects.requireNonNull(serviceComponents.truncator());
}

private static String getErrorMessage(@Nullable URI uri) {
Expand All @@ -57,10 +60,9 @@ private static String getErrorMessage(@Nullable URI uri) {
@Override
public void execute(List<String> input, ActionListener<InferenceServiceResults> listener) {
try {
OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(
account,
new OpenAiEmbeddingsRequestEntity(input, model.getTaskSettings().model(), model.getTaskSettings().user())
);
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());

OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, account, truncatedInput, model.getTaskSettings());
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);

client.send(request, wrappedListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.io.IOException;
Expand Down Expand Up @@ -50,7 +51,7 @@ public String getRequestType() {
}

@Override
public InferenceServiceResults parseResult(HttpResult result) throws RetryException {
public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException {
try {
return parseFunction.apply(result);
} catch (Exception e) {
Expand Down
Loading

0 comments on commit 013fa94

Please sign in to comment.