diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index bc72a31d1e8b8..2e3a0667bb13b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -158,6 +158,11 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl chunks.add(input.substring(chunkStart)); } + if (chunks.isEmpty()) { + // The input did not chunk, return the entire input + chunks.add(input); + } + return chunks; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 6b2a53de306aa..760882fa2b39d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -45,6 +45,15 @@ public void testEmptyInput_SentenceChunker() { assertThat(batches, empty()); } + public void testWhitespaceInput_SentenceChunker() { + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1)) + .batchRequestsWithListeners(testListener()); + assertThat(batches, hasSize(1)); + assertThat(batches.get(0).batch().inputs(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" ")); + } + public void testBlankInput_WordChunker() { var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener()); @@ -62,6 +71,25 @@ public void testBlankInput_SentenceChunker() { assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); } + public void testInputThatDoesNotChunk_WordChunker() { + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners( + testListener() + ); + assertThat(batches, hasSize(1)); + assertThat(batches.get(0).batch().inputs(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + } + + public void testInputThatDoesNotChunk_SentenceChunker() { + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1)) + .batchRequestsWithListeners(testListener()); + assertThat(batches, hasSize(1)); + assertThat(batches.get(0).batch().inputs(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + } + public void testShortInputsAreSingleBatch() { String input = "one chunk"; var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index 5a350439ceffc..e8b3f09b11750 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -33,6 +33,35 @@ public void testEmptyString() { assertThat(chunks.get(0), Matchers.is("")); } + public void testBlankString() { + var chunks = new SentenceBoundaryChunker().chunk(" ", 100, randomBoolean()); + assertThat(chunks, hasSize(1)); + assertThat(chunks.get(0), Matchers.is(" ")); + } + + public void testSingleChar() { + var chunks = new SentenceBoundaryChunker().chunk(" b", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(" b")); + + chunks = new SentenceBoundaryChunker().chunk("b", 100, randomBoolean()); + assertThat(chunks, Matchers.contains("b")); + + chunks = new SentenceBoundaryChunker().chunk(". ", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(". ")); + + chunks = new SentenceBoundaryChunker().chunk(" , ", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(" , ")); + + chunks = new SentenceBoundaryChunker().chunk(" ,", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(" ,")); + } + + public void testSingleCharRepeated() { + var input = "a".repeat(32_000); + var chunks = new SentenceBoundaryChunker().chunk(input, 100, randomBoolean()); + assertThat(chunks, Matchers.contains(input)); + } + public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index ef643a4b36fdc..795f90aa29b02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; import java.util.List; import java.util.Locale; @@ -226,6 +227,35 @@ public void testWhitespace() { assertThat(chunks, contains(" ")); } + public void testBlankString() { + var chunks = new WordBoundaryChunker().chunk(" ", 100, 10); + assertThat(chunks, hasSize(1)); + assertThat(chunks.get(0), Matchers.is(" ")); + } + + public void testSingleChar() { + var chunks = new WordBoundaryChunker().chunk(" b", 100, 10); + assertThat(chunks, Matchers.contains(" b")); + + chunks = new WordBoundaryChunker().chunk("b", 100, 10); + assertThat(chunks, Matchers.contains("b")); + + chunks = new WordBoundaryChunker().chunk(". ", 100, 10); + assertThat(chunks, Matchers.contains(". ")); + + chunks = new WordBoundaryChunker().chunk(" , ", 100, 10); + assertThat(chunks, Matchers.contains(" , ")); + + chunks = new WordBoundaryChunker().chunk(" ,", 100, 10); + assertThat(chunks, Matchers.contains(" ,")); + } + + public void testSingleCharRepeated() { + var input = "a".repeat(32_000); + var chunks = new WordBoundaryChunker().chunk(input, 100, 10); + assertThat(chunks, Matchers.contains(input)); + } + public void testPunctuation() { int chunkSize = 1; var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);