Support maximum batch size (opensearch-project#2428)
Signed-off-by: Liyun Xiu <xiliyun@amazon.com>
chishui committed May 17, 2024
1 parent d74c623 commit 45021e8
Showing 6 changed files with 367 additions and 27 deletions.
@@ -7,6 +7,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

@@ -29,4 +30,5 @@ public class ExecutionContext {
private CountDownLatch countDownLatch;
// This is to hold any exception thrown in a split-batch request
private AtomicReference<Exception> exceptionHolder;
private List<Integer> originalOrder;
}
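
For context, and not part of the commit: ExecutionContext is a Lombok-style data holder, and the updated tests below construct it with the new four-argument constructor. A minimal usage sketch under that assumption, with hypothetical batch values:

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;

public class ExecutionContextSketch {
    public static void main(String[] args) {
        int numberOfBatches = 3;                        // hypothetical batch count
        CountDownLatch latch = new CountDownLatch(numberOfBatches);
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
        List<Integer> originalOrder = List.of(2, 0, 1); // hypothetical pre-sort order
        // All split batches share the latch and exception holder; only the sequence differs.
        ExecutionContext first = new ExecutionContext(0, latch, exceptionHolder, originalOrder);
        System.out.println(first);
    }
}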
@@ -20,6 +20,7 @@
import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
@@ -137,10 +138,25 @@ private void processResponse(
private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
ModelTensors[] modelTensors = new ModelTensors[tensorOutputs.size()];
log.debug("Reordered tensor outputs size is {}", tensorOutputs.size());
// step 1: reorder batches
for (Map.Entry<Integer, ModelTensors> entry : tensorOutputs.entrySet()) {
modelTensors[entry.getKey()] = entry.getValue();
}
actionListener.onResponse(Arrays.asList(modelTensors));

// step 2: restore the original order, as textDocs might have been sorted
List<Integer> originalOrderIndexes = executionContext.getOriginalOrder();
List<ModelTensors> modelTensorsList = Arrays.asList(modelTensors);
if (CollectionUtils.isEmpty(originalOrderIndexes)) {
actionListener.onResponse(modelTensorsList);
} else {
// if the originalOrder is not empty, reorder based on it
List<ModelTensors> sortedModelTensors = Arrays.asList(new ModelTensors[modelTensorsList.size()]);
assert (originalOrderIndexes.size() == modelTensors.length);
for (int i = 0; i < originalOrderIndexes.size(); ++i) {
sortedModelTensors.set(i, modelTensorsList.get(originalOrderIndexes.get(i)));
}
actionListener.onResponse(sortedModelTensors);
}
}

protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
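The step-2 loop above applies a recorded index list to put responses back into their pre-sort order. A self-contained editorial sketch of that pattern, not the commit's exact code, with hypothetical data; it reads originalOrder.get(i) as "which output belongs at result position i":

import java.util.Arrays;
import java.util.List;

public class ReorderSketch {
    public static void main(String[] args) {
        // Outputs as they arrived, i.e. in sorted-input order (hypothetical values).
        List<String> outputs = Arrays.asList("emb-b", "emb-c", "emb-a");
        // originalOrder.get(i) = index in `outputs` that belongs at position i.
        List<Integer> originalOrder = Arrays.asList(2, 0, 1);
        List<String> restored = Arrays.asList(new String[outputs.size()]);
        for (int i = 0; i < originalOrder.size(); ++i) {
            restored.set(i, outputs.get(originalOrder.get(i)));
        }
        System.out.println(restored); // [emb-a, emb-b, emb-c]
    }
}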
@@ -8,13 +8,17 @@
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
@@ -41,6 +45,9 @@
import org.opensearch.script.ScriptService;

public interface RemoteConnectorExecutor {
int DEFAULT_BATCH_SIZE = -1;
String MAX_BATCH_SIZE_KEY = "max_batch_size";
String STEP_SIZE_KEY = "input_docs_processed_step_size";

default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
ActionListener<List<ModelTensors>> tensorActionListener = ActionListener.wrap(r -> {
@@ -51,28 +58,44 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
int maxBatchSize = getMaxBatchSize();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(
textDocsInputDataSet,
maxBatchSize
);
List<Integer> textDocOriginalOrder = Collections.emptyList();
if (shouldSortBeforeCuttingBatches(maxBatchSize, calculatedChunkSize)) {
Tuple<TextDocsInputDataSet, List<Integer>> sortedData = sortTextDocsByTextLength(textDocsInputDataSet);
textDocsInputDataSet = sortedData.v1();
textDocOriginalOrder = Collections.unmodifiableList(sortedData.v2());
}

CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1());
int sequence = 0;
for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize
.v2()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize.v2()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs,
Math.min(processedDocs + calculatedChunkSize.v2(), textDocsInputDataSet.getDocs().size()));
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
modelTensors,
new ExecutionContext(sequence++, countDownLatch, exceptionHolder),
new ExecutionContext(
sequence++,
countDownLatch,
exceptionHolder,
textDocOriginalOrder
),
tensorActionListener
);
}
} else {
preparePayloadAndInvokeRemoteModel(
mlInput,
modelTensors,
new ExecutionContext(0, new CountDownLatch(1), exceptionHolder),
new ExecutionContext(0, new CountDownLatch(1), exceptionHolder, Collections.emptyList()),
tensorActionListener
);
}
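
Worth noting: the rewritten loop clamps the upper subList bound with Math.min, where the earlier line appears to have sliced to the end of the list. An editorial sketch of the slicing pattern, with hypothetical values:

import java.util.Arrays;
import java.util.List;

public class BatchSliceSketch {
    public static void main(String[] args) {
        List<String> docs = Arrays.asList("a", "b", "c", "d", "e"); // hypothetical docs
        int stepSize = 2; // chunk width, i.e. v2 of the calculated tuple
        for (int processed = 0; processed < docs.size(); processed += stepSize) {
            // Clamp the end index so the final partial batch stays in bounds.
            List<String> batch = docs.subList(processed, Math.min(processed + stepSize, docs.size()));
            System.out.println(batch); // [a, b] then [c, d] then [e]
        }
    }
}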
@@ -86,20 +109,16 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
* @param textDocsInputDataSet the text docs to split into batches
* @param maxBatchSize the user-configured maximum batch size, or the default when not configured
* @return Tuple of chunk size and step size.
*/
private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) {
private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet, int maxBatchSize) {
int textDocsLength = textDocsInputDataSet.getDocs().size();
Map<String, String> parameters = getConnector().getParameters();
if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
if (parameters != null && parameters.containsKey(STEP_SIZE_KEY)) {
int stepSize = Integer.parseInt(parameters.get(STEP_SIZE_KEY));
// We need to check the parameter at runtime, as it can be passed in the predict request
if (stepSize <= 0) {
throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
} else {
boolean isDivisible = textDocsLength % stepSize == 0;
if (isDivisible) {
return Tuple.tuple(textDocsLength / stepSize, stepSize);
}
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
return Tuple.tuple((int) Math.ceil((double) textDocsLength / stepSize), stepSize);
}
} else {
Optional<ConnectorAction> predictAction = getConnector().findPredictAction();
@@ -111,9 +130,53 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
// user-defined preprocess script; in this case, the chunk size always equals the text docs length.
return Tuple.tuple(textDocsLength, 1);
}
// consider as batch.
return Tuple.tuple(1, textDocsLength);

if (maxBatchSize == DEFAULT_BATCH_SIZE || textDocsLength <= maxBatchSize) {
return Tuple.tuple(1, textDocsLength);
}
return Tuple.tuple((int) Math.ceil((double) textDocsLength / maxBatchSize), maxBatchSize);
}
}

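The ceiling division above replaces the old divisibility branch; both forms give the same chunk count. A quick editorial check with hypothetical numbers:

public class ChunkCountSketch {
    public static void main(String[] args) {
        int textDocsLength = 10; // hypothetical
        int stepSize = 3;        // hypothetical
        // ceil(10 / 3) = 4 chunks: three of size 3 and one of size 1.
        int chunks = (int) Math.ceil((double) textDocsLength / stepSize);
        System.out.println(chunks); // 4
        // Old logic for comparison: 10 % 3 != 0, so 10 / 3 + 1 = 4 as well.
        System.out.println(textDocsLength % stepSize == 0 ? textDocsLength / stepSize : textDocsLength / stepSize + 1); // 4
    }
}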
/**
* Get the user-configured max_batch_size parameter; throw IllegalArgumentException if it is invalid.
* Return the default value if it is not configured.
* @return max batch size
*/
private int getMaxBatchSize() {
Map<String, String> parameters = getConnector().getParameters();
if (parameters == null || !parameters.containsKey(MAX_BATCH_SIZE_KEY)) {
return DEFAULT_BATCH_SIZE;
}
int maxBatchSize = Integer.parseInt(parameters.get(MAX_BATCH_SIZE_KEY));
if (maxBatchSize <= 0) {
throw new IllegalArgumentException("Invalid parameter: " + MAX_BATCH_SIZE_KEY + ". It must be positive integer.");
}
return maxBatchSize;
}

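An editorial sketch of how the new parameter behaves, detached from the Connector interface; the maps below are hypothetical stand-ins for connector parameters:

import java.util.Map;

public class MaxBatchSizeSketch {
    static int getMaxBatchSize(Map<String, String> parameters) {
        if (parameters == null || !parameters.containsKey("max_batch_size")) {
            return -1; // DEFAULT_BATCH_SIZE: no limit configured
        }
        int maxBatchSize = Integer.parseInt(parameters.get("max_batch_size"));
        if (maxBatchSize <= 0) {
            throw new IllegalArgumentException("Invalid parameter: max_batch_size. It must be a positive integer.");
        }
        return maxBatchSize;
    }

    public static void main(String[] args) {
        System.out.println(getMaxBatchSize(Map.of("max_batch_size", "8"))); // 8
        System.out.println(getMaxBatchSize(Map.of()));                      // -1 (default)
    }
}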
private boolean shouldSortBeforeCuttingBatches(int maxBatchSize, Tuple<Integer, Integer> calculatedChunkSize) {
if (maxBatchSize <= 1 || calculatedChunkSize.v1() <= 1 || calculatedChunkSize.v2() <= 1) {
return false;
}
// skip sorting when a step size is explicitly configured
Map<String, String> parameters = getConnector().getParameters();
if (parameters != null && parameters.containsKey(STEP_SIZE_KEY)) {
return false;
}
return true;
}

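The predicate above skips sorting whenever it cannot help: a single batch, batches of width 1, or an explicit step size. An editorial mirror of that decision, with hypothetical inputs:

import java.util.Map;

public class SortDecisionSketch {
    static boolean shouldSort(int maxBatchSize, int chunkCount, int chunkWidth, Map<String, String> params) {
        if (maxBatchSize <= 1 || chunkCount <= 1 || chunkWidth <= 1) {
            return false; // nothing to gain from sorting a single or degenerate batch
        }
        // An explicit step size disables sorting, matching the commit's predicate.
        return params == null || !params.containsKey("input_docs_processed_step_size");
    }

    public static void main(String[] args) {
        System.out.println(shouldSort(3, 4, 3, Map.of()));   // true: 4 batches of up to 3 docs
        System.out.println(shouldSort(-1, 1, 10, Map.of())); // false: everything fits in one batch
    }
}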
private Tuple<TextDocsInputDataSet, List<Integer>> sortTextDocsByTextLength(TextDocsInputDataSet textDocsInputDataSet) {
List<Tuple<Integer, String>> docsWithIndex = new ArrayList<>();
for (int i = 0; i < textDocsInputDataSet.getDocs().size(); ++i) {
docsWithIndex.add(Tuple.tuple(i, textDocsInputDataSet.getDocs().get(i)));
}
docsWithIndex.sort(Comparator.comparingInt(t -> t.v2().length()));
List<String> sortedDocs = docsWithIndex.stream().map(Tuple::v2).collect(Collectors.toList());
List<Integer> originalIndexOrder = docsWithIndex.stream().map(Tuple::v1).collect(Collectors.toList());
TextDocsInputDataSet sortedTextDocsInputDataSet = TextDocsInputDataSet.builder().docs(sortedDocs).build();
return Tuple.tuple(sortedTextDocsInputDataSet, originalIndexOrder);
}

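sortTextDocsByTextLength pairs each doc with its index before sorting, presumably so that similar-length texts land in the same batch; the returned index list is what step 2 of reOrderTensorResponses later consumes. A compact editorial sketch of the same record-then-sort idea:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class SortByLengthSketch {
    public static void main(String[] args) {
        List<String> docs = List.of("medium doc", "a", "the longest document"); // hypothetical
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < docs.size(); ++i) {
            order.add(i);
        }
        // Sort index positions by the length of the doc each one points at.
        order.sort(Comparator.comparingInt(i -> docs.get(i).length()));
        List<String> sortedDocs = order.stream().map(docs::get).collect(Collectors.toList());
        System.out.println(sortedDocs); // [a, medium doc, the longest document]
        System.out.println(order);      // [1, 0, 2]: original index of each sorted doc
    }
}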
default void setScriptService(ScriptService scriptService) {}
@@ -11,6 +11,7 @@

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@@ -94,7 +95,7 @@ public void invokeRemoteModel_invalidIpAddress() {
new HashMap<>(),
"{\"input\": \"hello world\"}",
new HashMap<>(),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList()),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
@@ -126,7 +127,7 @@ public void invokeRemoteModel_Empty_payload() {
new HashMap<>(),
null,
new HashMap<>(),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList()),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
@@ -158,7 +159,7 @@ public void invokeRemoteModel_get_request() {
new HashMap<>(),
null,
new HashMap<>(),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList()),
actionListener
);
}
@@ -186,7 +187,7 @@ public void invokeRemoteModel_post_request() {
new HashMap<>(),
"hello world",
new HashMap<>(),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList()),
actionListener
);
}
@@ -217,7 +218,7 @@ public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFie
new HashMap<>(),
"hello world",
new HashMap<>(),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList()),
actionListener
);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
@@ -14,6 +14,7 @@

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -41,7 +42,7 @@
import software.amazon.awssdk.http.SdkHttpResponse;

public class MLSdkAsyncHttpResponseHandlerTest {
private final ExecutionContext executionContext = new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>());
private final ExecutionContext executionContext = new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>(), Collections.emptyList());
@Mock
private ActionListener<List<ModelTensors>> actionListener;
@Mock
@@ -260,7 +261,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(0, count, exceptionHolder),
new ExecutionContext(0, count, exceptionHolder, Collections.emptyList()),
actionListener,
parameters,
tensorOutputs,
@@ -269,7 +270,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(1, count, exceptionHolder),
new ExecutionContext(1, count, exceptionHolder, Collections.emptyList()),
actionListener,
parameters,
tensorOutputs,
@@ -328,7 +329,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(0, count, exceptionHolder),
new ExecutionContext(0, count, exceptionHolder, Collections.emptyList()),
actionListener,
parameters,
tensorOutputs,
@@ -337,7 +338,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new ExecutionContext(1, count, exceptionHolder),
new ExecutionContext(1, count, exceptionHolder, Collections.emptyList()),
actionListener,
parameters,
tensorOutputs,