Skip to content

Commit

Permalink
handle inputs that do not chunk
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Dec 3, 2024
1 parent 9513926 commit 0ab1ca8
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
chunks.add(input.substring(chunkStart));
}

if (chunks.isEmpty()) {
// The input did not chunk, return the entire input
chunks.add(input);
}

return chunks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ public void testEmptyInput_SentenceChunker() {
assertThat(batches, empty());
}

public void testWhitespaceInput_SentenceChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
}

public void testBlankInput_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
Expand All @@ -62,6 +71,25 @@ public void testBlankInput_SentenceChunker() {
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
}

public void testInputThatDoesNotChunk_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
testListener()
);
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
}

public void testInputThatDoesNotChunk_SentenceChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,35 @@ public void testEmptyString() {
assertThat(chunks.get(0), Matchers.is(""));
}

public void testBlankString() {
var chunks = new SentenceBoundaryChunker().chunk(" ", 100, randomBoolean());
assertThat(chunks, hasSize(1));
assertThat(chunks.get(0), Matchers.is(" "));
}

public void testSingleChar() {
var chunks = new SentenceBoundaryChunker().chunk(" b", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" b"));

chunks = new SentenceBoundaryChunker().chunk("b", 100, randomBoolean());
assertThat(chunks, Matchers.contains("b"));

chunks = new SentenceBoundaryChunker().chunk(". ", 100, randomBoolean());
assertThat(chunks, Matchers.contains(". "));

chunks = new SentenceBoundaryChunker().chunk(" , ", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" , "));

chunks = new SentenceBoundaryChunker().chunk(" ,", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" ,"));
}

public void testSingleCharRepeated() {
var input = "a".repeat(32_000);
var chunks = new SentenceBoundaryChunker().chunk(input, 100, randomBoolean());
assertThat(chunks, Matchers.contains(input));
}

public void testChunkSplitLargeChunkSizes() {
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
var chunker = new SentenceBoundaryChunker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -226,6 +227,35 @@ public void testWhitespace() {
assertThat(chunks, contains(" "));
}

public void testBlankString() {
var chunks = new WordBoundaryChunker().chunk(" ", 100, 10);
assertThat(chunks, hasSize(1));
assertThat(chunks.get(0), Matchers.is(" "));
}

public void testSingleChar() {
var chunks = new WordBoundaryChunker().chunk(" b", 100, 10);
assertThat(chunks, Matchers.contains(" b"));

chunks = new WordBoundaryChunker().chunk("b", 100, 10);
assertThat(chunks, Matchers.contains("b"));

chunks = new WordBoundaryChunker().chunk(". ", 100, 10);
assertThat(chunks, Matchers.contains(". "));

chunks = new WordBoundaryChunker().chunk(" , ", 100, 10);
assertThat(chunks, Matchers.contains(" , "));

chunks = new WordBoundaryChunker().chunk(" ,", 100, 10);
assertThat(chunks, Matchers.contains(" ,"));
}

public void testSingleCharRepeated() {
var input = "a".repeat(32_000);
var chunks = new WordBoundaryChunker().chunk(input, 100, 10);
assertThat(chunks, Matchers.contains(input));
}

public void testPunctuation() {
int chunkSize = 1;
var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
Expand Down

0 comments on commit 0ab1ca8

Please sign in to comment.