Skip to content

Commit

Permalink
[serving] Read x-synchronus and x-starting-token from input payload
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 18, 2023
1 parent 0776f3b commit a18f738
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.serving.cache.CacheEngine;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -33,8 +34,6 @@
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteRequest;
import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
Expand Down Expand Up @@ -63,7 +62,8 @@ public final class DdbCacheEngine implements CacheEngine {

private static final Logger logger = LoggerFactory.getLogger(DdbCacheEngine.class);

private static final String TABLE_NAME = "djl-serving-pagination-table";
private static final String TABLE_NAME =
Utils.getenv("DDB_TABLE_NAME", "djl-serving-pagination-table");
private static final String CACHE_ID = "CACHE_ID";
private static final String INDEX = "INDEX_KEY";
private static final String HEADER = "HEADER";
Expand All @@ -85,7 +85,7 @@ public final class DdbCacheEngine implements CacheEngine {
private DdbCacheEngine(DynamoDbClient ddbClient) {
this.ddbClient = ddbClient;
cacheTtl = Duration.ofMillis(30).toMillis();
writeBatch = 5;
writeBatch = Integer.parseInt(Utils.getenv("SERVING_DDB_BATCH", "5"));
}

/**
Expand Down Expand Up @@ -208,7 +208,7 @@ public CompletableFuture<Void> put(String key, Output output) {
/** {@inheritDoc} */
@Override
public Output get(String key, int limit) {
int start = 0;
int start = -1;
if (key.length() > 36) {
start = Integer.parseInt(key.substring(36));
key = key.substring(0, 36);
Expand All @@ -222,30 +222,24 @@ public Output get(String key, int limit) {
.tableName(TABLE_NAME)
.keyConditionExpression(EXPRESSION)
.expressionAttributeValues(attrValues)
.limit(limit)
.limit(limit == Integer.MAX_VALUE ? limit : limit + 1)
.build();

QueryResponse response = ddbClient.query(request);
if (response.count() == 0) {
if (start == 0) {
Map<String, AttributeValue> map = new ConcurrentHashMap<>(2);
map.put(CACHE_ID, AttributeValue.builder().s(key).build());
map.put(INDEX, AttributeValue.builder().n("-1").build());
GetItemRequest get =
GetItemRequest.builder().tableName(TABLE_NAME).key(map).build();
GetItemResponse resp = ddbClient.getItem(get);
if (resp.hasItem()) {
AttributeValue header = resp.item().get(HEADER);
return decode(header);
}
}
return null;
}

Output output = new Output();
boolean complete = false;
boolean first = true;
List<byte[]> list = new ArrayList<>();
for (Map<String, AttributeValue> item : response.items()) {
// skip first one
if (first) {
first = false;
continue;
}
AttributeValue header = item.get(HEADER);
if (header != null) {
Output o = decode(header);
Expand All @@ -261,13 +255,14 @@ public Output get(String key, int limit) {
if (lastContent != null) {
complete = true;
}
start++;
start = Integer.parseInt(item.get(INDEX).n());
}
if (!list.isEmpty()) {
output.add(join(list));
}
if (!complete) {
output.addProperty("x-next-token", key + start);
output.addProperty("X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key + start);
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ public void testDdbCacheEngine() throws InterruptedException, ExecutionException

// query before model generate output
o = engine.get(key1, Integer.MAX_VALUE);
Assert.assertEquals(o.getCode(), 202);
Assert.assertEquals(o.getCode(), 200);
Assert.assertNull(o.getData());
String nextToken = o.getProperty("x-next-token", null);
Assert.assertEquals(nextToken, key1);
Assert.assertEquals(nextToken, key1 + "-1");

// retry before model generate output
o = engine.get(nextToken, Integer.MAX_VALUE);
Assert.assertEquals(o.getCode(), 202);
Assert.assertEquals(o.getCode(), 200);

// real output from model
Output output1 = new Output();
Expand Down
2 changes: 1 addition & 1 deletion serving/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RUN scripts/install_python.sh && \
echo "${djl_version} cpufull" > /opt/djl/bin/telemetry && \
djl-serving -i ai.djl.mxnet:mxnet-native-mkl:1.9.1:linux-x86_64 && \
djl-serving -i ai.djl.pytorch:pytorch-native-cpu:$torch_version:linux-x86_64 && \
djl-serving -i ai.djl.tensorflow:tensorflow-native-cpu:2.7.4:linux-x86_64 && \
djl-serving -i ai.djl.tensorflow:tensorflow-native-cpu:2.10.1:linux-x86_64 && \
scripts/patch_oss_dlc.sh python && \
rm -rf /opt/djl/logs && \
chown -R djl:djl /opt/djl && \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public Output get(String key, int limit) {
}
if (cbs.hasNext()) {
o.addProperty("x-next-token", key);
o.addProperty("X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key);
} else {
// clean up cache
remove(key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class InferenceRequestHandler extends HttpRequestHandler {
private static final String X_STARTING_TOKEN = "x-starting-token";
private static final String X_NEXT_TOKEN = "x-next-token";
private static final String X_MAX_ITEMS = "x-max-items";
private static final String X_CUSTOM_ATTRIBUTES = "X-Amzn-SageMaker-Custom-Attributes";

private RequestParser requestParser;

Expand Down Expand Up @@ -292,11 +293,13 @@ void runJob(
pending.setMessage("The model result is not yet available");
pending.setCode(202);
pending.addProperty(X_NEXT_TOKEN, nextToken);
pending.addProperty(X_CUSTOM_ATTRIBUTES, X_NEXT_TOKEN + '=' + nextToken);
cache.put(nextToken, pending);

// Send back token to user
Output out = new Output();
out.addProperty(X_NEXT_TOKEN, nextToken);
out.addProperty(X_CUSTOM_ATTRIBUTES, X_NEXT_TOKEN + '=' + nextToken);
sendOutput(out, ctx);

// Run model
Expand Down Expand Up @@ -325,7 +328,7 @@ private void getCacheResult(ChannelHandlerContext ctx, Input input, String start
CacheEngine cache = CacheManager.getCacheEngine();
Output output = cache.get(startingToken, limit);
if (output == null) {
throw new BadRequestException("Invalid " + X_STARTING_TOKEN);
throw new BadRequestException("Invalid " + X_STARTING_TOKEN + ": " + startingToken);
}
sendOutput(output, ctx);
}
Expand Down Expand Up @@ -392,6 +395,7 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
void onException(Throwable t, ChannelHandlerContext ctx) {
HttpResponseStatus status;
if (t instanceof TranslateException || t instanceof BadRequestException) {
logger.debug(t.getMessage(), t);
SERVER_METRIC.info("{}", RESPONSE_4_XX);
status = HttpResponseStatus.BAD_REQUEST;
} else if (t instanceof WlmException) {
Expand Down
10 changes: 9 additions & 1 deletion serving/src/main/java/ai/djl/serving/http/RequestParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ public Input parseRequest(FullHttpRequest req, QueryStringDecoder decoder) {

for (Map.Entry<String, String> entry : req.headers().entries()) {
String key = entry.getKey();
if (!HttpHeaderNames.CONTENT_TYPE.contentEqualsIgnoreCase(key)) {
if ("X-Amzn-SageMaker-Custom-Attributes".equalsIgnoreCase(key)) {
String[] tokens = entry.getValue().split(";");
for (String token : tokens) {
String[] pair = token.split("=", 2);
if (pair.length == 2) {
input.addProperty(pair[0].trim(), pair[1].trim());
}
}
} else if (!HttpHeaderNames.CONTENT_TYPE.contentEqualsIgnoreCase(key)) {
input.addProperty(key, entry.getValue());
}
}
Expand Down
10 changes: 5 additions & 5 deletions serving/src/main/java/ai/djl/serving/util/NettyUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -242,24 +242,24 @@ public static void sendFile(
*/
public static void sendError(ChannelHandlerContext ctx, Throwable t) {
if (t instanceof ResourceNotFoundException || t instanceof ModelNotFoundException) {
logger.trace("", t);
logger.debug("", t);
NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, t);
} else if (t instanceof BadRequestException) {
logger.trace("", t);
logger.debug("", t);
BadRequestException e = (BadRequestException) t;
HttpResponseStatus status = HttpResponseStatus.valueOf(e.getCode(), e.getMessage());
NettyUtils.sendError(ctx, status, t);
} else if (t instanceof WlmOutOfMemoryException) {
logger.warn("", t);
NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, t);
} else if (t instanceof ModelException) {
logger.trace("", t);
logger.debug("", t);
NettyUtils.sendError(ctx, HttpResponseStatus.BAD_REQUEST, t);
} else if (t instanceof MethodNotAllowedException) {
logger.trace("", t);
logger.debug("", t);
NettyUtils.sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, t);
} else if (t instanceof ServiceUnavailableException || t instanceof WlmException) {
logger.trace("", t);
logger.warn("", t);
NettyUtils.sendError(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, t);
} else {
logger.error("", t);
Expand Down

0 comments on commit a18f738

Please sign in to comment.