Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.17][ML] Fix timeout ingesting an empty string into a semantic_text field #118746

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/117840.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117840
summary: Fix timeout ingesting an empty string into a `semantic_text` field
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
*
* @param input Text to chunk
* @param maxNumberWordsPerChunk Maximum size of the chunk
* @return The input text chunked
* @param includePrecedingSentence Include the previous sentence
* @return The input text split into chunks
*/
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
var chunks = new ArrayList<String>();
Expand Down Expand Up @@ -154,6 +155,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
chunks.add(input.substring(chunkStart));
}

if (chunks.isEmpty()) {
// The input did not chunk, return the entire input
chunks.add(input);
}

return chunks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
}

if (input.isEmpty()) {
return List.of();
}

var chunkPositions = new ArrayList<ChunkPosition>();

// This position in the chunk is where the next overlapping chunk will start
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -31,16 +32,62 @@

public class EmbeddingRequestChunkerTests extends ESTestCase {

public void testEmptyInput() {
public void testEmptyInput_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
assertThat(batches, empty());
}

public void testBlankInput() {
public void testEmptyInput_SentenceChunker() {
    // No inputs at all must produce no batches when sentence-boundary chunking is used.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(), 10, type, new SentenceBoundaryChunkingSettings(250, 1));
    assertThat(chunker.batchRequestsWithListeners(testListener()), empty());
}

public void testWhitespaceInput_SentenceChunker() {
    // A whitespace-only input is not dropped: it must come back as one batch holding that exact string.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(" "), 10, type, new SentenceBoundaryChunkingSettings(250, 1));
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));
    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(" "));
}

public void testBlankInput_WordChunker() {
    // An empty string must still yield one batch containing the empty string,
    // rather than producing no batch at all.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, type).batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));
    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(""));
}

public void testBlankInput_SentenceChunker() {
    // Same guarantee as the word-chunker case: an empty string produces a single
    // batch with a single empty input.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(""), 10, type, new SentenceBoundaryChunkingSettings(250, 1));
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));
    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(""));
}

public void testInputThatDoesNotChunk_WordChunker() {
    // An input with no split points is passed through whole as one chunk.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, type);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));
    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is("ABBAABBA"));
}

public void testInputThatDoesNotChunk_SentenceChunker() {
    // Sentence-boundary chunking of an unsplittable input also returns the input unchanged.
    var type = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, type, new SentenceBoundaryChunkingSettings(250, 1));
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));
    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
Expand All @@ -27,6 +28,53 @@

public class SentenceBoundaryChunkerTests extends ESTestCase {

/**
 * Test helper that runs the chunker and returns the resulting chunks.
 *
 * @param chunker the chunker under test
 * @param input the text to chunk
 * @param maxNumberWordsPerChunk maximum number of words allowed per chunk
 * @param includePrecedingSentence whether to include the previous sentence in each chunk
 * @return the chunks produced for {@code input}
 */
private List<String> textChunks(
    SentenceBoundaryChunker chunker,
    String input,
    int maxNumberWordsPerChunk,
    boolean includePrecedingSentence
) {
    return chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence);
}

public void testEmptyString() {
    // An empty input must not vanish: exactly one (empty) chunk is returned.
    var chunks = textChunks(new SentenceBoundaryChunker(), "", 100, randomBoolean());
    assertThat(chunks, Matchers.contains(""));
}

public void testBlankString() {
    // Whitespace-only input round-trips as a single chunk.
    var chunks = textChunks(new SentenceBoundaryChunker(), " ", 100, randomBoolean());
    assertThat(chunks, Matchers.contains(" "));
}

public void testSingleChar() {
    // Inputs too small to ever split must each come back as a single chunk equal to the input.
    for (var input : List.of(" b", "b", ". ", " , ", " ,")) {
        var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean());
        assertThat(chunks, Matchers.contains(input));
    }
}

public void testSingleCharRepeated() {
    // A very long unbroken run of one character has no sentence boundaries,
    // so it is returned whole as a single chunk.
    var text = "a".repeat(32_000);
    var chunks = textChunks(new SentenceBoundaryChunker(), text, 100, randomBoolean());
    assertThat(chunks, Matchers.contains(text));
}

public void testChunkSplitLargeChunkSizes() {
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
var chunker = new SentenceBoundaryChunker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -226,6 +227,39 @@ public void testWhitespace() {
assertThat(chunks, contains(" "));
}

/**
 * Test helper that chunks {@code input} with the given chunk size and overlap
 * and returns the resulting chunks.
 */
private List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
    return chunker.chunk(input, chunkSize, overlap);
}

public void testBlankString() {
    // Whitespace-only input is preserved as a single chunk.
    var chunks = textChunks(new WordBoundaryChunker(), " ", 100, 10);
    assertThat(chunks, Matchers.contains(" "));
}

public void testSingleChar() {
    // Inputs with at most one word must each round-trip as a single chunk equal to the input.
    for (var input : List.of(" b", "b", ". ", " , ", " ,")) {
        var chunks = textChunks(new WordBoundaryChunker(), input, 100, 10);
        assertThat(chunks, Matchers.contains(input));
    }
}

public void testSingleCharRepeated() {
    // A long run of one character contains no word boundaries, so it cannot be
    // split and must be returned whole.
    var text = "a".repeat(32_000);
    var chunks = textChunks(new WordBoundaryChunker(), text, 100, 10);
    assertThat(chunks, Matchers.contains(text));
}

public void testPunctuation() {
int chunkSize = 1;
var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException {
}
}

public void testInferEmptyInput() throws IOException {
    // Inference over an empty docs array should return an empty result set
    // immediately instead of running into the request timeout.
    var modelId = "empty_input";
    createPassThroughModel(modelId);
    putModelDefinition(modelId);
    putVocabulary(List.of("these", "are", "my", "words"), modelId);
    startDeployment(modelId);

    var inferRequest = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
    inferRequest.setJsonEntity("""
        { "docs": [] }
        """);

    var inferResponse = client().performRequest(inferRequest);
    assertThat(EntityUtils.toString(inferResponse.getEntity()), equalTo("{\"inference_results\":[]}"));
}

// Convenience overload: uploads the default base64-encoded model definition
// (BASE_64_ENCODED_MODEL / RAW_MODEL_SIZE) for the given model id.
private void putModelDefinition(String modelId) throws IOException {
    putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
Response.Builder responseBuilder = Response.builder();
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());

if (request.numberOfDocuments() == 0) {
listener.onResponse(responseBuilder.setId(request.getId()).build());
return;
}

if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
responseBuilder.setLicensed(true);
doInfer(task, request, responseBuilder, parentTaskId, listener);
Expand Down