Skip to content

Commit

Permalink
[api] Add batch support to TextEmbeddingServingTranslator (#3084)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Apr 15, 2024
1 parent 4a434c2 commit 3813405
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 40 deletions.
58 changes: 27 additions & 31 deletions api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -153,42 +155,46 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
* @return a list of output objects defined by the user
* @throws TranslateException if an error occurs during prediction
*/
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"})
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches", "unchecked"})
public List<O> batchPredict(List<I> inputs) throws TranslateException {
long begin = System.nanoTime();
try (PredictorContext context = new PredictorContext()) {
if (!prepared) {
translator.prepare(context);
prepared = true;
}
Batchifier batchifier = translator.getBatchifier();
if (batchifier == null) {
Translator<I[], O[]> batchTranslator = translator.toBatchTranslator();
if (batchTranslator == null) {
List<O> ret = new ArrayList<>(inputs.size());
for (I input : inputs) {
timestamp = System.nanoTime();
begin = timestamp;
long begin = timestamp;
NDList ndList = translator.processInput(context, input);
preprocessEnd(ndList);
preprocessEnd(ndList, 1);

NDList result = predictInternal(context, ndList);
predictEnd(result);
predictEnd(result, 1);

ret.add(translator.processOutput(context, result));
postProcessEnd(begin);
postProcessEnd(begin, 1);
}
return ret;
}

int batchSize = inputs.size();
I[] empty = (I[]) Array.newInstance(inputs.get(0).getClass(), 0);
I[] in = inputs.toArray(empty);

timestamp = System.nanoTime();
NDList inputBatch = processInputs(context, inputs);
preprocessEnd(inputBatch);
long begin = timestamp;
NDList ndList = batchTranslator.processInput(context, in);
preprocessEnd(ndList, batchSize);

NDList result = predictInternal(context, inputBatch);
predictEnd(result);
NDList result = predictInternal(context, ndList);
predictEnd(result, batchSize);

List<O> ret = processOutputs(context, result);
postProcessEnd(begin);
return ret;
O[] ret = batchTranslator.processOutput(context, result);
postProcessEnd(begin, batchSize);
return Arrays.asList(ret);
} catch (TranslateException e) {
throw e;
} catch (Exception e) {
Expand Down Expand Up @@ -302,40 +308,30 @@ private NDList processInputs(TranslatorContext ctx, List<I> inputs) throws Excep
return translator.getBatchifier().batchify(preprocessed);
}

@SuppressWarnings("PMD.SignatureDeclareThrowsException")
private List<O> processOutputs(TranslatorContext ctx, NDList list) throws Exception {
NDList[] unbatched = translator.getBatchifier().unbatchify(list);
List<O> outputs = new ArrayList<>(unbatched.length);
for (NDList output : unbatched) {
outputs.add(translator.processOutput(ctx, output));
}
return outputs;
}

private void preprocessEnd(NDList list) {
private void preprocessEnd(NDList list, int batchSize) {
if (metrics != null) {
waitToRead(list);
long tmp = System.nanoTime();
long duration = (tmp - timestamp) / 1000;
long duration = (tmp - timestamp) / 1000 / batchSize;
timestamp = tmp;
metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS, dimension);
}
}

private void predictEnd(NDList list) {
private void predictEnd(NDList list, int batchSize) {
if (metrics != null) {
waitToRead(list);
long tmp = System.nanoTime();
long duration = (tmp - timestamp) / 1000;
long duration = (tmp - timestamp) / 1000 / batchSize;
timestamp = tmp;
metrics.addMetric("Inference", duration, Unit.MICROSECONDS, dimension);
}
}

private void postProcessEnd(long begin) {
private void postProcessEnd(long begin, int batchSize) {
if (metrics != null) {
long tmp = System.nanoTime();
long duration = (tmp - timestamp) / 1000;
long duration = (tmp - timestamp) / 1000 / batchSize;
timestamp = tmp;
metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS, dimension);
long prediction = (tmp - begin) / 1000;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** A {@link Translator} that can handle generic text embedding {@link Input} and {@link Output}. */
public class TextEmbeddingServingTranslator implements NoBatchifyTranslator<Input, Output> {
public class TextEmbeddingServingTranslator implements Translator<Input, Output> {

private Translator<String, float[]> translator;
private Translator<String[], float[][]> batchTranslator;
Expand Down Expand Up @@ -84,4 +89,64 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
}
return output;
}

    /** {@inheritDoc} */
    @Override
    public Translator<Input[], Output[]> toBatchTranslator(Batchifier batchifier) {
        // Adapter that flattens a batch of Input requests -- each of which may itself carry
        // several prompts -- into a single tokenizer batch, then re-groups the embeddings
        // per original request on the way out.
        return new NoBatchifyTranslator<Input[], Output[]>() {

            /** {@inheritDoc} */
            @Override
            @SuppressWarnings("PMD.SignatureDeclareThrowsException")
            public NDList processInput(TranslatorContext ctx, Input[] inputs) throws Exception {
                List<String> prompts = new ArrayList<>(inputs.length);
                // mapping[i] == -1: request i carried a single prompt.
                // mapping[i] >= 0: request i contributed that many prompts to the flat list.
                int[] mapping = new int[inputs.length];
                for (int i = 0; i < inputs.length; ++i) {
                    TextPrompt prompt = TextPrompt.parseInput(inputs[i]);
                    if (prompt.isBatch()) {
                        String[] batch = prompt.getBatch();
                        mapping[i] = batch.length;
                        prompts.addAll(Arrays.asList(batch));
                    } else {
                        mapping[i] = -1;
                        prompts.add(prompt.getText());
                    }
                }
                // Stash the per-request prompt counts so processOutput can re-group results.
                ctx.setAttachment("mapping", mapping);
                return batchTranslator.processInput(ctx, prompts.toArray(Utils.EMPTY_ARRAY));
            }

            /** {@inheritDoc} */
            @Override
            @SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"})
            public Output[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
                NDList[] unbatched = batchifier.unbatchify(list);
                int[] mapping = (int[]) ctx.getAttachment("mapping");
                // NOTE(review): "encodings" is attached during processInput by the underlying
                // batch translator -- presumably one tokenizer encoding per flattened prompt;
                // verify against TextEmbeddingBatchTranslator.processInput.
                Object[] encodings = (Object[]) ctx.getAttachment("encodings");
                Output[] ret = new Output[mapping.length];
                // index walks the flattened prompt list; i walks the original requests.
                int index = 0;
                for (int i = 0; i < ret.length; ++i) {
                    Output output = new Output();
                    output.addProperty("Content-Type", "application/json");
                    if (mapping[i] == -1) {
                        // non-batching
                        ctx.setAttachment("encoding", encodings[index]);
                        float[] embedding = translator.processOutput(ctx, unbatched[index]);
                        ++index;
                        output.add(BytesSupplier.wrapAsJson(embedding));
                    } else {
                        // Re-group the next mapping[i] consecutive embeddings for request i.
                        float[][] embeddings = new float[mapping[i]][];
                        for (int j = 0; j < mapping[i]; ++j) {
                            ctx.setAttachment("encoding", encodings[index]);
                            embeddings[j] = translator.processOutput(ctx, unbatched[index]);
                            ++index;
                        }
                        output.add(BytesSupplier.wrapAsJson(embeddings));
                    }
                    ret[i] = output;
                }
                return ret;
            }
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ public class TextClassificationBatchTranslator
private PretrainedConfig config;

TextClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
Batchifier batchifier,
PretrainedConfig config) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
this.config = config;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
@Override
public TextClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TextClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
return new TextClassificationBatchTranslator(
tokenizer, includeTokenTypes, batchifier, config);
}

static Classifications toClassifications(PretrainedConfig config, NDList list) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

import java.util.Arrays;

/** The translator for Huggingface text embedding model. */
public class TextEmbeddingBatchTranslator implements NoBatchifyTranslator<String[], float[][]> {

Expand All @@ -47,11 +49,11 @@ public class TextEmbeddingBatchTranslator implements NoBatchifyTranslator<String
@Override
public NDList processInput(TranslatorContext ctx, String[] input) {
NDManager manager = ctx.getNDManager();
Encoding[] encodings = tokenizer.batchEncode(input);
Object[] encodings = Arrays.stream(tokenizer.batchEncode(input)).toArray();
ctx.setAttachment("encodings", encodings);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
batch[i] = ((Encoding) encodings[i]).toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand All @@ -60,13 +62,13 @@ public NDList processInput(TranslatorContext ctx, String[] input) {
@Override
public float[][] processOutput(TranslatorContext ctx, NDList list) {
NDList[] batch = batchifier.unbatchify(list);
Encoding[] encoding = (Encoding[]) ctx.getAttachment("encodings");
Object[] encoding = (Object[]) ctx.getAttachment("encodings");
NDManager manager = ctx.getNDManager();
float[][] ret = new float[batch.length][];
for (int i = 0; i < batch.length; ++i) {
NDArray array =
TextEmbeddingTranslator.processEmbedding(
manager, batch[i], encoding[i], pooling);
manager, batch[i], (Encoding) encoding[i], pooling);
if (normalize) {
array = array.normalize(2, 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ public class TokenClassificationBatchTranslator
private PretrainedConfig config;

TokenClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
Batchifier batchifier,
PretrainedConfig config) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
this.config = config;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) {
@Override
public TokenClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TokenClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
return new TokenClassificationBatchTranslator(
tokenizer, includeTokenTypes, batchifier, config);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class TextEmbeddingTranslatorTest {
Expand Down Expand Up @@ -275,4 +277,53 @@ public void testTextEmbeddingBatchTranslator()
});
}
}

    @Test
    public void testTextEmbeddingTranslatorServingBatch()
            throws ModelException, IOException, TranslateException {
        // Two prompts; each Input below wraps BOTH of them, so the flattened tokenizer
        // batch holds 4 sentences total.
        String[] text = {"This is an example sentence", "This is the second sentence"};

        // Stub model: ignores its input and emits a constant ones tensor of shape
        // (4, 7, 384) -- presumably batch=4 (the flattened prompts), seq=7, dim=384;
        // dim is pinned by the length assertion below.
        Block block =
                new LambdaBlock(
                        a -> {
                            NDManager manager = a.getManager();
                            NDArray arr = manager.ones(new Shape(4, 7, 384));
                            arr.setName("last_hidden_state");
                            return new NDList(arr);
                        },
                        "model");
        Path modelDir = Paths.get("build/model");
        Files.createDirectories(modelDir);

        Criteria<Input, Output> criteria =
                Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelPath(modelDir)
                        .optBlock(block)
                        .optEngine("PyTorch")
                        .optArgument("tokenizer", "bert-base-uncased")
                        .optOption("hasParameter", "false")
                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                        .build();

        try (ZooModel<Input, Output> model = criteria.loadModel();
                Predictor<Input, Output> predictor = model.newPredictor()) {
            // Request style 1: bare JSON array of strings.
            Input input1 = new Input();
            input1.add(JsonUtils.GSON.toJson(text));
            input1.addProperty("Content-Type", "application/json");

            // Request style 2: {"inputs": [...]} wrapper object.
            Input input2 = new Input();
            Map<String, String[]> map = new HashMap<>();
            map.put("inputs", text);
            input2.add(JsonUtils.GSON.toJson(map));
            input2.addProperty("Content-Type", "application/json");
            List<Input> batchInput = Arrays.asList(input1, input2);

            List<Output> batchOutput = predictor.batchPredict(batchInput);
            // One Output per Input; each Output holds a float[][] with one embedding
            // per prompt in that request.
            Assert.assertEquals(batchOutput.size(), 2);
            float[][] res = (float[][]) batchOutput.get(0).getData().getAsObject();
            Assert.assertEquals(res[0].length, 384);
            Assertions.assertAlmostEquals(res[0][0], 0.05103);
        }
    }
}

0 comments on commit 3813405

Please sign in to comment.