
Commit

Merge branch '8.16' into 816-117840
davidkyle authored Dec 16, 2024
2 parents c81f06f + dfa1c87 commit 2c647e0
Showing 2 changed files with 192 additions and 33 deletions.
@@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;

@@ -656,25 +657,13 @@ public void chunkedInfer(
esModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var batch : batchedRequests) {
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
inputType,
timeout
);

ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
.delegateFailureAndWrap(
(l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
);

var maybeDeployListener = mlResultsListener.delegateResponse(
(l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
);

client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
if (batchedRequests.isEmpty()) {
listener.onResponse(List.of());
} else {
// Avoid filling the inference queue by executing the batches in series
// Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests
var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
sequentialRunner.run();
}
} else {
listener.onFailure(notElasticsearchModelException(model));
@@ -990,4 +979,80 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
return null;
}
}

/**
* Iterates over the batches, executing a limited number of requests at a time to
* avoid filling the ML node inference queue.
*
* First, a single request is executed, which can also trigger deploying a model
* if necessary. When that request completes successfully, a callback executes the
* next N requests in parallel. Each of these requests also has a callback that
* executes one more request, so that at all times N requests are in flight. This
* continues until all requests have been executed.
*/
class BatchIterator {
private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200

private final AtomicInteger index = new AtomicInteger();
private final ElasticsearchInternalModel esModel;
private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
private final InputType inputType;
private final TimeValue timeout;

BatchIterator(
ElasticsearchInternalModel esModel,
InputType inputType,
TimeValue timeout,
List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
) {
this.esModel = esModel;
this.requestAndListeners = requestAndListeners;
this.inputType = inputType;
this.timeout = timeout;
}

void run() {
// The first request may deploy the model and, upon completion, runs
// NUM_REQUESTS_INFLIGHT requests in parallel.
inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
}

private void inferBatch(int runAfterCount, boolean maybeDeploy) {
int batchIndex = index.getAndIncrement();
if (batchIndex >= requestAndListeners.size()) {
return;
}
executeRequest(batchIndex, maybeDeploy, () -> {
for (int i = 0; i < runAfterCount; i++) {
// Subsequent requests may not deploy the model, because the first request
// already did so. Upon completion, it runs one more request.
inferenceExecutor.execute(() -> inferBatch(1, false));
}
});
}

private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
inputType,
timeout
);
logger.trace("Executing batch index={}", batchIndex);

ActionListener<InferModelAction.Response> listener = batch.listener()
.delegateFailureAndWrap(
(l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
);
if (runAfter != null) {
listener = ActionListener.runAfter(listener, runAfter);
}
if (maybeDeploy) {
listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
}
client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
}
}
}
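
The BatchIterator javadoc above describes a rolling window of in-flight requests: one request runs first, fans out to N parallel requests on completion, and each completed request schedules one more. The following is a minimal standalone sketch of that pattern, not part of this change: the class name, the small in-flight limit, the plain ExecutorService, the Runnable stand-ins for inference calls, and the CountDownLatch are all illustrative assumptions rather than the real inferenceExecutor/ActionListener plumbing.

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class RollingBatchSketch {
    private static final int NUM_REQUESTS_INFLIGHT = 3; // the real service uses 20

    private final AtomicInteger index = new AtomicInteger();
    private final List<Runnable> batches;
    private final ExecutorService executor;
    private final CountDownLatch done;

    RollingBatchSketch(List<Runnable> batches, ExecutorService executor) {
        this.batches = batches;
        this.executor = executor;
        this.done = new CountDownLatch(batches.size());
    }

    void run() {
        // The first task runs alone (in the real service it may also trigger a model
        // deployment) and then fans out to NUM_REQUESTS_INFLIGHT parallel tasks.
        executor.execute(() -> runBatch(NUM_REQUESTS_INFLIGHT));
    }

    private void runBatch(int runAfterCount) {
        int batchIndex = index.getAndIncrement();
        if (batchIndex >= batches.size()) {
            return; // all batches claimed, let this chain die out
        }
        batches.get(batchIndex).run(); // stand-in for the async inference call
        done.countDown();
        // Each completed batch schedules runAfterCount more, keeping N chains alive.
        for (int i = 0; i < runAfterCount; i++) {
            executor.execute(() -> runBatch(1));
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        List<Runnable> batches = IntStream.range(0, 10)
            .<Runnable>mapToObj(i -> () -> System.out.println("batch " + i))
            .toList();
        RollingBatchSketch sketch = new RollingBatchSketch(batches, executor);
        sketch.run();
        sketch.done.await();
        executor.shutdown();
    }
}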
@@ -12,6 +12,7 @@
import org.apache.logging.log4j.Level;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
@@ -59,19 +60,22 @@
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -824,16 +828,16 @@ public void testParsePersistedConfig() {
}
}

public void testChunkInfer_E5WithNullChunkingSettings() {
public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
testChunkInfer_e5(null);
}

public void testChunkInfer_E5ChunkingSettingsSet() {
public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException {
testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -881,6 +885,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -889,22 +896,23 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

public void testChunkInfer_SparseWithNullChunkingSettings() {
public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
testChunkInfer_Sparse(null);
}

public void testChunkInfer_SparseWithChunkingSettingsSet() {
public void testChunkInfer_SparseWithChunkingSettingsSet() throws InterruptedException {
testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -928,6 +936,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
var service = createService(client);

var gotResults = new AtomicBoolean();

var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
assertThat(chunkedResponse, hasSize(2));
assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
@@ -947,6 +956,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -955,22 +967,23 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

public void testChunkInfer_ElserWithNullChunkingSettings() {
public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
testChunkInfer_Elser(null);
}

public void testChunkInfer_ElserWithChunkingSettingsSet() {
public void testChunkInfer_ElserWithChunkingSettingsSet() throws InterruptedException {
testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1014,6 +1027,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1022,9 +1038,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

@@ -1085,7 +1102,7 @@ public void testChunkInferSetsTokenization() {
}

@SuppressWarnings("unchecked")
public void testChunkInfer_FailsBatch() {
public void testChunkInfer_FailsBatch() throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -1121,6 +1138,9 @@ public void testChunkInfer_FailsBatch() {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1129,12 +1149,86 @@ public void testChunkInfer_FailsBatch() {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

@SuppressWarnings("unchecked")
public void testChunkingLargeDocument() throws InterruptedException {
int numBatches = randomIntBetween(3, 6);

// how many response objects to return in each batch
int[] numResponsesPerBatch = new int[numBatches];
for (int i = 0; i < numBatches - 1; i++) {
numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
}
numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
int numChunks = Arrays.stream(numResponsesPerBatch).sum();

// build a doc with enough words to produce numChunks chunks
int wordsPerChunk = 10;
int numWords = numChunks * wordsPerChunk;
var input = "word ".repeat(numWords);

Client client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);

// mock the inference response
doAnswer(invocationOnMock -> {
var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
var mlTrainedModelResults = new ArrayList<InferenceResults>();
for (int i = 0; i < request.numberOfDocuments(); i++) {
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
}
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
listener.onResponse(response);
return null;
}).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));

var service = createService(client);

var gotResults = new AtomicBoolean();
var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
assertThat(chunkedResponse, hasSize(1));
assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var embeddingResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
assertThat(embeddingResults.chunks(), hasSize(numChunks));

gotResults.set(true);
}, ESTestCase::fail);

// Create model using the word boundary chunker.
var model = new MultilingualE5SmallModel(
"foo",
TaskType.TEXT_EMBEDDING,
"e5",
new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
new WordBoundaryChunkingSettings(wordsPerChunk, 0)
);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

// For the given input we know how many requests will be made
service.chunkedInfer(
model,
null,
List.of(input),
Map.of(),
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
latchedListener
);

latch.await();
assertTrue("Listener not called with results", gotResults.get());
}
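
As a worked instance of the arithmetic in this test (assuming EMBEDDING_MAX_BATCH_SIZE is 10, consistent with the "* batch size = 200" comment on NUM_REQUESTS_INFLIGHT above): if randomization picks numBatches = 3 with a final batch of 4 responses, then numChunks = 10 + 10 + 4 = 24, the input is 24 * 10 = 240 words, the word boundary chunker with 10 words per chunk and no overlap yields exactly 24 chunks, and the service issues 3 InferModelAction requests.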

public void testParsePersistedConfig_Rerank() {
// with task settings
{
