[serving] Adds workflow model loading for SageMaker (deepjavalibrary#661)

* [serving] Adds workflow model loading for SageMaker

1. Removed global preferences for workflow definition
2. Updated README
frankfliu authored Apr 25, 2023
1 parent e5c2113 commit fb68c5b
Showing 11 changed files with 131 additions and 228 deletions.
12 changes: 3 additions & 9 deletions serving/docs/configuration.md
@@ -25,7 +25,6 @@ usage: djl-serving [OPTIONS]
-h,--help Print this help.
-m,--models <MODELS> Models to be loaded at startup.
-s,--model-store <MODELS-STORE> Model store location where models can be loaded.
-w,--workflows <WORKFLOWS> Workflows to be loaded at startup.
```

Details about the models, model-store, and workflows can be found in the equivalent configuration properties.
@@ -65,7 +64,7 @@ model_store=build/models

**Load Models**

The `load_models` config property can be used to define a list of models to be loaded.
The `load_models` config property can be used to define a list of models (or workflows) to be loaded.
The list should be defined as a comma separated list of urls to load models from.

Each model can be defined either as a URL directly or optionally with prepended endpoint data like `[EndpointData]=modelUrl`.
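For reference, a hedged illustration of the endpoint-data form described above; the name, version, engine, and device values below are placeholders rather than text from this commit:

```properties
# Format: [modelName:version:engine:deviceNames]=modelUrl (trailing fields may be omitted)
load_models=[mlp:v1:MXNet:*]=https://example.com/models/mlp.tar.gz
```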
@@ -89,15 +88,10 @@ load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=htt

**Workflows**

Use the `load_workflows` config property to define initial workflows that should be loaded on startup.
It should be a comma separated list of workflow URLs.

You can also specify the device that the model should be loaded on by using `modelUrl:deviceNames`.
The `deviceNames` matches the format used in the `load_models` property described above.
An example is shown below:
Use the `load_models` config property to define initial workflows that should be loaded on startup.

```properties
load_workflows=https://resources.djl.ai/test-models/basic-serving-workflow.json
load_models=https://resources.djl.ai/test-models/basic-serving-workflow.json
```

View the [workflow documentation](workflows.md) to see more information about workflows and their configuration format.
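As a further hedged sketch of the behavior this commit enables, a plain model URL and a workflow definition could share one `load_models` list; the endpoint name and URLs below are placeholders:

```properties
# The .tar.gz entry is loaded as a model; the .json entry is parsed as a
# workflow definition and registered under the endpoint name "myWorkflow".
load_models=https://example.com/mlp.tar.gz,[myWorkflow]=https://example.com/serving-workflow.json
```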
9 changes: 0 additions & 9 deletions serving/docs/workflows.md
@@ -42,15 +42,6 @@ As the system is built in YAML, the overall structure is a configuration object
name: "MyWorkflow"
version: "1.2.0"
# Default model properties based on https://github.com/pytorch/serve/blob/master/docs/workflows.md#workflow-model-properties
# optional
minWorkers: 1
maxWorkers: 4
batchSize: 3
maxBatchDelayMillis: 5000
retryAttempts: 3
timeout: 5000
# Defined below
models: ...
functions: ...
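For orientation, a hedged sketch of a minimal definition after this removal; the model URL is a placeholder and the routing section is an assumption based on the workflow documentation rather than part of this diff:

```yaml
name: "MyWorkflow"
version: "1.2.0"
models:
  m: "https://example.com/models/mlp.tar.gz"  # placeholder model URL
workflow:
  out: ["m", "in"]  # assumed routing: feed the workflow input through model "m"
```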
18 changes: 0 additions & 18 deletions serving/src/main/java/ai/djl/serving/Arguments.java
@@ -28,7 +28,6 @@ public final class Arguments {
private String configFile;
private String modelStore;
private String[] models;
private String[] workflows;
private boolean help;

/**
@@ -40,7 +39,6 @@ public Arguments(CommandLine cmd) {
configFile = cmd.getOptionValue("config-file");
modelStore = cmd.getOptionValue("model-store");
models = cmd.getOptionValues("models");
workflows = cmd.getOptionValues("workflows");
help = cmd.hasOption("help");
}

@@ -74,13 +72,6 @@ public static Options getOptions() {
.argName("MODELS-STORE")
.desc("Model store location where models can be loaded.")
.build());
options.addOption(
Option.builder("w")
.longOpt("workflows")
.hasArgs()
.argName("WORKFLOWS")
.desc("Workflows to be loaded at startup.")
.build());
options.addOption(
Option.builder("i")
.longOpt("install")
@@ -140,15 +131,6 @@ public String[] getModels() {
return models;
}

/**
* Returns the workflow urls that specified in command line.
*
* @return the workflow urls that specified in command line
*/
public String[] getWorkflows() {
return workflows;
}

/**
* Returns if the command line has help option.
*
107 changes: 29 additions & 78 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -52,8 +52,7 @@

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
@@ -186,8 +185,7 @@ public List<ChannelFuture> start()

try {
initModelStore();
initWorkflows();
} catch (URISyntaxException | BadWorkflowException e) {
} catch (BadWorkflowException e) {
throw new ServerStartupException(
"Failed to initialize startup models and workflows", e);
}
@@ -298,7 +296,7 @@ private ChannelFuture initializeServer(
return f;
}

private void initModelStore() throws IOException {
private void initModelStore() throws IOException, BadWorkflowException {
Set<String> startupModels = ModelManager.getInstance().getStartupWorkflows();

String loadModels = configManager.getLoadModels();
@@ -355,7 +353,7 @@ private void initModelStore() throws IOException {
String version = null;
String engineName = null;
String deviceMapping = null;
String modelName;
String modelName = null;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
modelName = tokens[0];
@@ -368,26 +366,33 @@ private void initModelStore() throws IOException {
if (tokens.length > 3) {
deviceMapping = tokens[3];
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
modelUrl,
version,
engineName,
deviceMapping,
Input.class,
Output.class,
-1,
-1,
-1,
-1,
-1,
-1);
Workflow workflow = new Workflow(modelInfo);
Workflow workflow;
URI uri = WorkflowDefinition.toWorkflowUri(modelUrl);
if (uri != null) {
workflow = WorkflowDefinition.parse(modelName, uri).toWorkflow();
} else {
if (modelName == null) {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
modelUrl,
version,
engineName,
deviceMapping,
Input.class,
Output.class,
-1,
-1,
-1,
-1,
-1,
-1);
workflow = new Workflow(modelInfo);
}
CompletableFuture<Void> f =
modelManager
.registerWorkflow(workflow)
@@ -416,60 +421,6 @@ private void initModelStore() throws IOException {
}
}

private void initWorkflows() throws IOException, URISyntaxException, BadWorkflowException {
Set<String> startupWorkflows = ModelManager.getInstance().getStartupWorkflows();
String loadWorkflows = configManager.getLoadWorkflows();
if (loadWorkflows == null || loadWorkflows.isEmpty()) {
return;
}

ModelManager modelManager = ModelManager.getInstance();
String[] urls = loadWorkflows.split("[, ]+");

for (String url : urls) {
logger.info("Initializing workflow: {}", url);
Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
if (!matcher.matches()) {
throw new AssertionError("Invalid model store url: " + url);
}
String endpoint = matcher.group(2);
String workflowUrlString = matcher.group(3);
String workflowName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
workflowName = tokens[0];
} else {
workflowName = ModelInfo.inferModelNameFromUrl(workflowUrlString);
}

URL workflowUrl = new URL(workflowUrlString);
Workflow workflow =
WorkflowDefinition.parse(workflowUrl.toURI(), workflowUrl.openStream())
.toWorkflow();

CompletableFuture<Void> f =
modelManager
.registerWorkflow(workflow)
.exceptionally(
t -> {
logger.error("Failed register workflow", t);
// delay 3 seconds, allows REST API to send PING
// response (health check)
try {
Thread.sleep(3000);
} catch (InterruptedException ignore) {
// ignore
}
stop();
return null;
});
if (configManager.waitModelLoading()) {
f.join();
}
startupWorkflows.add(workflowName);
}
}

String mapModelUrl(Path path) {
try {
logger.info("Found file in model_store: {}", path);
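To summarize the dispatch that the initModelStore() and handleRegisterModel() changes above apply to each URL, here is a hedged, consolidated sketch; the class name, method names, and import package paths are assumptions, while the WorkflowDefinition, ModelInfo, and ModelManager calls mirror the diff:

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.workflow.BadWorkflowException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.serving.workflow.WorkflowDefinition;

import java.io.IOException;
import java.net.URI;

/** Hedged sketch, not code from this commit: how a startup URL becomes a Workflow. */
final class WorkflowOrModelLoader {

    static Workflow load(String modelName, String modelUrl)
            throws IOException, BadWorkflowException {
        URI uri = WorkflowDefinition.toWorkflowUri(modelUrl);
        if (uri != null) {
            // The URL points at a workflow definition: parse it directly.
            return WorkflowDefinition.parse(modelName, uri).toWorkflow();
        }
        if (modelName == null) {
            modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
        }
        // A plain model URL: wrap a single ModelInfo in a one-step workflow,
        // using -1 defaults as the diff does for startup models.
        ModelInfo<Input, Output> modelInfo =
                new ModelInfo<>(
                        modelName, modelUrl, null, null, null,
                        Input.class, Output.class, -1, -1, -1, -1, -1, -1);
        return new Workflow(modelInfo);
    }

    static void register(Workflow workflow) {
        // Mirrors the ModelManager.registerWorkflow(...) usage shown in the diff.
        ModelManager.getInstance().registerWorkflow(workflow).join();
    }
}
```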
@@ -35,8 +35,7 @@
import io.netty.util.CharsetUtil;

import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -190,22 +189,33 @@ private void handleRegisterModel(
req = new LoadModelRequest(decoder);
}

ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
req.getModelName(),
req.getModelUrl(),
req.getVersion(),
req.getEngineName(),
req.getDeviceName(),
Input.class,
Output.class,
req.getJobQueueSize(),
req.getMaxIdleSeconds(),
req.getMaxBatchDelayMillis(),
req.getBatchSize(),
req.getMinWorkers(),
req.getMaxWorkers());
Workflow workflow = new Workflow(modelInfo);
Workflow workflow;
URI uri = WorkflowDefinition.toWorkflowUri(req.getModelUrl());
if (uri != null) {
try {
workflow = WorkflowDefinition.parse(req.getModelName(), uri).toWorkflow();
} catch (IOException | BadWorkflowException e) {
NettyUtils.sendError(ctx, e.getCause());
return;
}
} else {
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
req.getModelName(),
req.getModelUrl(),
req.getVersion(),
req.getEngineName(),
req.getDeviceName(),
Input.class,
Output.class,
req.getJobQueueSize(),
req.getMaxIdleSeconds(),
req.getMaxBatchDelayMillis(),
req.getBatchSize(),
req.getMinWorkers(),
req.getMaxWorkers());
workflow = new Workflow(modelInfo);
}
final ModelManager modelManager = ModelManager.getInstance();
CompletableFuture<Void> f =
modelManager
@@ -240,9 +250,8 @@ private void handleRegisterWorkflow(
NettyUtils.getParameter(decoder, LoadModelRequest.SYNCHRONOUS, "true"));

try {
URL url = new URL(workflowUrl);
Workflow workflow =
WorkflowDefinition.parse(url.toURI(), url.openStream()).toWorkflow();
URI uri = URI.create(workflowUrl);
Workflow workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
String workflowName = workflow.getName();

final ModelManager modelManager = ModelManager.getInstance();
@@ -263,8 +272,7 @@
NettyUtils.sendJsonResponse(
ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED);
}

} catch (URISyntaxException | IOException | BadWorkflowException e) {
} catch (IOException | BadWorkflowException e) {
NettyUtils.sendError(ctx, e.getCause());
}
}
16 changes: 0 additions & 16 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
@@ -54,7 +54,6 @@ public final class ConfigManager {
private static final String INFERENCE_ADDRESS = "inference_address";
private static final String MANAGEMENT_ADDRESS = "management_address";
private static final String LOAD_MODELS = "load_models";
private static final String LOAD_WORKFLOWS = "load_workflows";
private static final String WAIT_MODEL_LOADING = "wait_model_loading";
private static final String ALLOW_MULTI_STATUS = "allow_multi_status";
private static final String NUMBER_OF_NETTY_THREADS = "number_of_netty_threads";
@@ -110,10 +109,6 @@ private ConfigManager(Arguments args) {
if (models != null) {
prop.setProperty(LOAD_MODELS, String.join(",", models));
}
String[] workflows = args.getWorkflows();
if (workflows != null) {
prop.setProperty(LOAD_WORKFLOWS, String.join(",", workflows));
}
for (Map.Entry<String, String> env : Utils.getenv().entrySet()) {
String key = env.getKey();
if (key.startsWith("SERVING_")) {
@@ -266,15 +261,6 @@ public String getLoadModels() {
return prop.getProperty(LOAD_MODELS);
}

/**
* Returns the workflow urls that to be loaded at startup.
*
* @return the workflow urls that to be loaded at startup
*/
public String getLoadWorkflows() {
return prop.getProperty(LOAD_WORKFLOWS);
}

/**
* Returns the devices the default model will be loaded on at startup.
*
@@ -440,8 +426,6 @@ public String dumpConfigurations() {
+ (getModelStore() == null ? "N/A" : getModelStore())
+ "\nInitial Models: "
+ (getLoadModels() == null ? "N/A" : getLoadModels())
+ "\nInitial Workflows: "
+ (getLoadWorkflows() == null ? "N/A" : getLoadWorkflows())
+ "\nNetty threads: "
+ getNettyThreads()
+ "\nMaximum Request Size: "