Skip to content

Commit

Permalink
[python] Includes individual headers for server side batching (deepja…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Aug 8, 2023
1 parent 306225f commit 4052442
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
9 changes: 8 additions & 1 deletion engines/python/setup/djl_python/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,14 @@ def get_batches(self) -> list:
batch = []
for i in range(batch_size):
item = Input()
item.properties = self.properties
item.properties = {}
prefix = f"batch_{i}."
length = len(prefix)
for key, value in self.properties.items():
if key.startswith(prefix):
key = key[:length]
item.properties[key] = value

batch.append(item)

p = re.compile("batch_(\\d+)\\.(.*)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -82,16 +83,20 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {

Input batch = new Input();
List<O> ret = new ArrayList<>(size);
batch.setProperties(((Input) first).getProperties());
batch.addProperty("batch_size", String.valueOf(size));
for (int i = 0; i < size; ++i) {
Input in = (Input) inputs.get(i);
String prefix = "batch_" + i + '.';
for (Map.Entry<String, String> entry : in.getProperties().entrySet()) {
String key = prefix + entry.getKey();
batch.addProperty(key, entry.getValue());
}

PairList<String, BytesSupplier> content = in.getContent();
String prefix = "batch_" + i;
for (Pair<String, BytesSupplier> pair : content) {
String key = pair.getKey();
key = key == null ? "data" : key;
batch.add(prefix + '.' + key, pair.getValue());
batch.add(prefix + key, pair.getValue());
}
}
Output output = process.predict(batch, timeout, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -86,15 +87,16 @@ public void run() {
int size = list.size();
for (int i = 0; i < size; ++i) {
Request req = list.get(i);
String prefix = "batch_" + i + ".data";
if (i == 0) {
batch.setProperties(req.input.getProperties());
String prefix = "batch_" + i + '.';
for (Map.Entry<String, String> entry : req.input.getProperties().entrySet()) {
String key = prefix + entry.getKey();
batch.addProperty(key, entry.getValue());
}
batch.add(prefix, req.getRequest());

batch.add(prefix + "data", req.getRequest());
String seed = req.getSeed();
if (seed != null) {
String seedPrefix = "batch_" + i + ".seed";
batch.add(seedPrefix, req.seed);
batch.add(prefix + "seed", req.seed);
}
}
batch.addProperty("batch_size", String.valueOf(size));
Expand Down

0 comments on commit 4052442

Please sign in to comment.