Skip to content

Commit

Permalink
[ML] Refactor the Chunker classes to return offsets (elastic#117977)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Dec 9, 2024
1 parent 20dc928 commit ea721e5
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
import java.util.List;

public interface Chunker {
List<String> chunk(String input, ChunkingSettings chunkingSettings);
record ChunkOffset(int start, int end) {};

List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
private final EmbeddingType embeddingType;
private final ChunkingSettings chunkingSettings;

private List<List<String>> chunkedInputs;
private List<ChunkOffsetsAndInput> chunkedOffsets;
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
Expand Down Expand Up @@ -109,7 +109,7 @@ public EmbeddingRequestChunker(
}

private void splitIntoBatchedRequests(List<String> inputs) {
Function<String, List<String>> chunkFunction;
Function<String, List<Chunker.ChunkOffset>> chunkFunction;
if (chunkingSettings != null) {
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
chunkFunction = input -> chunker.chunk(input, chunkingSettings);
Expand All @@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap);
}

chunkedInputs = new ArrayList<>(inputs.size());
chunkedOffsets = new ArrayList<>(inputs.size());
switch (embeddingType) {
case FLOAT -> floatResults = new ArrayList<>(inputs.size());
case BYTE -> byteResults = new ArrayList<>(inputs.size());
Expand All @@ -128,18 +128,19 @@ private void splitIntoBatchedRequests(List<String> inputs) {

for (int i = 0; i < inputs.size(); i++) {
var chunks = chunkFunction.apply(inputs.get(i));
int numberOfSubBatches = addToBatches(chunks, i);
var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i));
int numberOfSubBatches = addToBatches(offSetsAndInput, i);
// size the results array with the expected number of request/responses
switch (embeddingType) {
case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
}
chunkedInputs.add(chunks);
chunkedOffsets.add(offSetsAndInput);
}
}

private int addToBatches(List<String> chunks, int inputIndex) {
private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) {
BatchRequest lastBatch;
if (batchedRequests.isEmpty()) {
lastBatch = new BatchRequest(new ArrayList<>());
Expand All @@ -157,16 +158,24 @@ private int addToBatches(List<String> chunks, int inputIndex) {

if (freeSpace > 0) {
// use any free space in the previous batch before creating new batches
int toAdd = Math.min(freeSpace, chunks.size());
lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
int toAdd = Math.min(freeSpace, chunk.offsets().size());
lastBatch.addSubBatch(
new SubBatch(
new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()),
new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
)
);
}

int start = freeSpace;
while (start < chunks.size()) {
int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start);
while (start < chunk.offsets().size()) {
int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start);
var batch = new BatchRequest(new ArrayList<>());
batch.addSubBatch(
new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))
new SubBatch(
new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()),
new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
)
);
batchedRequests.add(batch);
start += toAdd;
Expand Down Expand Up @@ -333,8 +342,8 @@ public void onFailure(Exception e) {
}

private void sendResponse() {
var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedInputs.size());
for (int i = 0; i < chunkedInputs.size(); i++) {
var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedOffsets.size());
for (int i = 0; i < chunkedOffsets.size(); i++) {
if (errors.get(i) != null) {
response.add(errors.get(i));
} else {
Expand All @@ -348,9 +357,9 @@ private void sendResponse() {

/**
 * Pair the chunk texts of the input at {@code resultIndex} with the
 * embeddings collected for it, dispatching on the embedding type.
 */
private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
    // Materialise the chunk substrings once instead of once per switch arm.
    var chunkText = chunkedOffsets.get(resultIndex).toChunkText();
    return switch (embeddingType) {
        case FLOAT -> mergeFloatResultsWithInputs(chunkText, floatResults.get(resultIndex));
        case BYTE -> mergeByteResultsWithInputs(chunkText, byteResults.get(resultIndex));
        case SPARSE -> mergeSparseResultsWithInputs(chunkText, sparseResults.get(resultIndex));
    };
}

Expand Down Expand Up @@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) {
}

/**
 * Flatten the chunk text of every sub-batch, in order, into the list of
 * request inputs sent to the inference service.
 */
public List<String> inputs() {
    return subBatches.stream()
        .flatMap(subBatch -> subBatch.requests().toChunkText().stream())
        .collect(Collectors.toList());
}
}

Expand All @@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
*/
record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}

record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
public int size() {
return requests.size();
record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) {
int size() {
return requests.offsets().size();
}
}

record ChunkOffsetsAndInput(List<Chunker.ChunkOffset> offsets, String input) {
List<String> toChunkText() {
return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker {
public SentenceBoundaryChunker() {
    // Locale.ROOT makes boundary analysis locale-independent, so chunk
    // boundaries are deterministic regardless of the JVM's default locale.
    sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
    wordIterator = BreakIterator.getWordInstance(Locale.ROOT);

}

/**
Expand All @@ -45,7 +44,7 @@ public SentenceBoundaryChunker() {
* @return The input text chunked
*/
@Override
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) {
return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0);
} else {
Expand All @@ -65,8 +64,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
* @param maxNumberWordsPerChunk Maximum size of the chunk
* @return The input text chunked
*/
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
var chunks = new ArrayList<String>();
public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
var chunks = new ArrayList<ChunkOffset>();

sentenceIterator.setText(input);
wordIterator.setText(input);
Expand All @@ -91,7 +90,7 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
int nextChunkWordCount = wordsInSentenceCount;
if (chunkWordCount > 0) {
// add a new chunk containing all the input up to this sentence
chunks.add(input.substring(chunkStart, chunkEnd));
chunks.add(new ChunkOffset(chunkStart, chunkEnd));

if (includePrecedingSentence) {
if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
Expand Down Expand Up @@ -127,12 +126,17 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
for (; i < sentenceSplits.size() - 1; i++) {
// Because the substring was passed to splitLongSentence()
// the returned positions need to be offset by chunkStart
chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end()));
chunks.add(
new ChunkOffset(
chunkStart + sentenceSplits.get(i).offsets().start(),
chunkStart + sentenceSplits.get(i).offsets().end()
)
);
}
// The final split is partially filled.
// Set the next chunk start to the beginning of the
// final split of the long sentence.
chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart
chunkStart = chunkStart + sentenceSplits.get(i).offsets().start(); // start pos needs to be offset by chunkStart
chunkWordCount = sentenceSplits.get(i).wordCount();
}
} else {
Expand All @@ -151,7 +155,7 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
}

if (chunkWordCount > 0) {
chunks.add(input.substring(chunkStart));
chunks.add(new ChunkOffset(chunkStart, input.length()));
}

return chunks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

/**
* Breaks text into smaller strings or chunks on Word boundaries.
Expand All @@ -35,7 +36,7 @@ public WordBoundaryChunker() {
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
}

record ChunkPosition(int start, int end, int wordCount) {}
record ChunkPosition(ChunkOffset offsets, int wordCount) {}

/**
* Break the input text into small chunks as dictated
Expand All @@ -45,7 +46,7 @@ record ChunkPosition(int start, int end, int wordCount) {}
* @return List of chunked text
*/
@Override
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) {
return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap);
} else {
Expand All @@ -64,18 +65,9 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
* Can be 0 but must be non-negative.
* @return List of chunked text
*/
public List<String> chunk(String input, int chunkSize, int overlap) {

if (input.isEmpty()) {
return List.of("");
}

public List<ChunkOffset> chunk(String input, int chunkSize, int overlap) {
var chunkPositions = chunkPositions(input, chunkSize, overlap);
var chunks = new ArrayList<String>(chunkPositions.size());
for (var pos : chunkPositions) {
chunks.add(input.substring(pos.start, pos.end));
}
return chunks;
return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList());
}

/**
Expand Down Expand Up @@ -127,7 +119,7 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
wordsSinceStartWindowWasMarked++;

if (wordsInChunkCountIncludingOverlap >= chunkSize) {
chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap));
chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap));
wordsInChunkCountIncludingOverlap = overlap;

if (overlap == 0) {
Expand All @@ -149,7 +141,7 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
// if it ends on a boundary than the count should equal overlap in which case
// we can ignore it, unless this is the first chunk in which case we want to add it
if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) {
chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap));
chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap));
}

return chunkPositions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public void testMultipleShortInputsAreSingleBatch() {
var subBatches = batches.get(0).batch().subBatches();
for (int i = 0; i < inputs.size(); i++) {
var subBatch = subBatches.get(i);
assertThat(subBatch.requests(), contains(inputs.get(i)));
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
assertEquals(0, subBatch.positions().chunkIndex());
assertEquals(i, subBatch.positions().inputIndex());
assertEquals(1, subBatch.positions().embeddingCount());
Expand Down Expand Up @@ -102,7 +102,7 @@ public void testManyInputsMakeManyBatches() {
var subBatches = batches.get(0).batch().subBatches();
for (int i = 0; i < batches.size(); i++) {
var subBatch = subBatches.get(i);
assertThat(subBatch.requests(), contains(inputs.get(i)));
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
assertEquals(0, subBatch.positions().chunkIndex());
assertEquals(inputIndex, subBatch.positions().inputIndex());
assertEquals(1, subBatch.positions().embeddingCount());
Expand Down Expand Up @@ -146,7 +146,7 @@ public void testChunkingSettingsProvided() {
var subBatches = batches.get(0).batch().subBatches();
for (int i = 0; i < batches.size(); i++) {
var subBatch = subBatches.get(i);
assertThat(subBatch.requests(), contains(inputs.get(i)));
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
assertEquals(0, subBatch.positions().chunkIndex());
assertEquals(inputIndex, subBatch.positions().inputIndex());
assertEquals(1, subBatch.positions().embeddingCount());
Expand Down Expand Up @@ -184,17 +184,17 @@ public void testLongInputChunkedOverMultipleBatches() {
assertEquals(0, subBatch.positions().inputIndex());
assertEquals(0, subBatch.positions().chunkIndex());
assertEquals(1, subBatch.positions().embeddingCount());
assertThat(subBatch.requests(), contains("1st small"));
assertThat(subBatch.requests().toChunkText(), contains("1st small"));
}
{
var subBatch = batch.subBatches().get(1);
assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
assertEquals(0, subBatch.positions().chunkIndex()); // 1st part of the 2nd input
assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks
assertThat(subBatch.requests().get(0), startsWith("passage_input0 "));
assertThat(subBatch.requests().get(1), startsWith(" passage_input20 "));
assertThat(subBatch.requests().get(2), startsWith(" passage_input40 "));
assertThat(subBatch.requests().get(3), startsWith(" passage_input60 "));
assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 "));
assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 "));
assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 "));
assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 "));
}
}
{
Expand All @@ -207,22 +207,22 @@ public void testLongInputChunkedOverMultipleBatches() {
assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input
assertEquals(2, subBatch.positions().embeddingCount());
assertThat(subBatch.requests().get(0), startsWith(" passage_input80 "));
assertThat(subBatch.requests().get(1), startsWith(" passage_input100 "));
assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 "));
assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 "));
}
{
var subBatch = batch.subBatches().get(1);
assertEquals(2, subBatch.positions().inputIndex()); // 3rd input
assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part
assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
assertThat(subBatch.requests(), contains("2nd small"));
assertThat(subBatch.requests().toChunkText(), contains("2nd small"));
}
{
var subBatch = batch.subBatches().get(2);
assertEquals(3, subBatch.positions().inputIndex()); // 4th input
assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part
assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
assertThat(subBatch.requests(), contains("3rd small"));
assertThat(subBatch.requests().toChunkText(), contains("3rd small"));
}
}
}
Expand Down
Loading

0 comments on commit ea721e5

Please sign in to comment.