Worker type (deepjavalibrary#1022)
zachgk authored Sep 12, 2023
1 parent cece592 commit 71845be
Showing 24 changed files with 814 additions and 511 deletions.
6 changes: 4 additions & 2 deletions benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java
@@ -73,10 +73,12 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration) {
ModelInfo<Void, float[]> modelInfo =
new ModelInfo<>("model", arguments.getModelUrl(), criteria);

WorkerPool<Void, float[]> wp = wlm.registerModel(modelInfo);
int workersPerDevice = numOfWorkers / devices.length;
modelInfo.setMinWorkers(workersPerDevice);
modelInfo.setMaxWorkers(workersPerDevice);
WorkerPool<Void, float[]> wp = wlm.registerWorkerPool(modelInfo);
for (String deviceName : devices) {
wp.initWorkers(deviceName, workersPerDevice, workersPerDevice);
wp.initWorkers(deviceName);
}

// Measure memory before worker kickoff
@@ -29,6 +29,7 @@
import ai.djl.serving.plugins.RequestHandler;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.translate.TranslateException;
@@ -184,8 +185,9 @@ private void handleKServeDescribeModel(ChannelHandlerContext ctx, String[] segme
continue;
}

for (ModelInfo<Input, Output> modelInfo : wf.getModels()) {
if (modelInfo.getStatus() == ModelInfo.Status.READY) {
for (WorkerPoolConfig<Input, Output> wpc : wf.getWpcs()) {
if (wpc.getStatus() == WorkerPoolConfig.Status.READY && wpc instanceof ModelInfo) {
ModelInfo<Input, Output> modelInfo = (ModelInfo<Input, Output>) wpc;
response.name = wf.getName();
response.setPlatform(modelInfo.getEngineName());
Device device = modelInfo.getModels().keySet().iterator().next();
@@ -242,25 +244,25 @@ private void handleKServeDescribeModelReady(ChannelHandlerContext ctx, String[]
if (segments.length > 5) {
modelVersion = segments[5];
}
ModelInfo<Input, Output> modelInfo = getModelInfo(modelName, modelVersion);
ModelInfo.Status status = modelInfo.getStatus();
WorkerPoolConfig<Input, Output> wpc = getWpc(modelName, modelVersion);
ModelInfo.Status status = wpc.getStatus();

HttpResponseStatus httpResponseStatus;
if (status == ModelInfo.Status.READY) {
if (status == WorkerPoolConfig.Status.READY) {
httpResponseStatus = HttpResponseStatus.OK;
} else {
httpResponseStatus = HttpResponseStatus.FAILED_DEPENDENCY;
}
NettyUtils.sendJsonResponse(ctx, EMPTY_BODY, httpResponseStatus);
}

private ModelInfo<Input, Output> getModelInfo(String modelName, String modelVersion)
private WorkerPoolConfig<Input, Output> getWpc(String modelName, String modelVersion)
throws ModelNotFoundException {
ModelManager modelManager = ModelManager.getInstance();
Workflow workflow = modelManager.getWorkflow(modelName, modelVersion, false);
Collection<ModelInfo<Input, Output>> models;
Collection<WorkerPoolConfig<Input, Output>> models;
if (workflow != null) {
models = workflow.getModels();
models = workflow.getWpcs();
} else {
models = Collections.emptyList();
}
58 changes: 55 additions & 3 deletions serving/docs/architecture.md
@@ -3,16 +3,68 @@
DJL serving is built on top of [Deep Java Library](https://djl.ai). You can visit the
[DJL github repository](https://github.com/deepjavalibrary/djl) to learn more. For module, dependency and class overview refer to [generated diagrams](https://sourcespy.com/github/deepjavalibrarydjlserving/).

DJL Serving uses a [Netty](https://netty.io/) frontend on top of backend worker thread pools.
The frontend uses a single Netty setup with multiple [HttpRequestHandler](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/HttpRequestHandler.html)s.
DJL Serving exists in roughly three layers:

1. Frontend - A [Netty](https://netty.io/) HTTP server that accepts and manages incoming requests
2. Workflows - A system for combining multiple models and glue code into execution plans
3. WorkLoadManager (WLM) - A worker thread management system that supports batching and request routing

There is also the model store, specified in the [configuration](configuration.md), which lists the models to be loaded into the ModelManager on startup.

![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture-02.png)


## Frontend

![Frontend Diagram](https://resources.djl.ai/images/djl-serving/frontend-01.png)

DJL Serving uses a [Netty](https://netty.io/) frontend to handle the incoming requests.
It has a single Netty setup with multiple [HttpRequestHandler](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/HttpRequestHandler.html)s.
Different request handlers will provide support for the [inference API](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/InferenceRequestHandler.html), [Management API](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/ManagementRequestHandler.html), or other APIs available from various plugins.

Those request handlers then communicate with the other layers of the DJL Serving stack.
Mainly, this happens through the ModelManager, which manages the various endpoints.
Each endpoint can have multiple versions, and each version corresponds to a Workflow.
The inference API dispatches requests to workflows, assigned round-robin among the workflows in an endpoint.
The management API can be used to create, list, update, and delete the endpoints and the workflows within them.
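
As a rough sketch of how a handler interacts with these layers, the following resolves an endpoint's workflow through the ModelManager and checks whether all of its worker pool configs are ready, mirroring the handlers changed in this commit (exception handling is omitted, and the boolean flag to `getWorkflow` is passed as in those handlers without further interpretation):

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.workflow.Workflow;

/** Illustrative readiness check, following the KServe and management handlers. */
final class ReadinessCheck {

    static boolean isReady(String modelName, String version) {
        ModelManager modelManager = ModelManager.getInstance();
        // Returns null when the endpoint or version does not exist.
        Workflow workflow = modelManager.getWorkflow(modelName, version, false);
        if (workflow == null) {
            return false;
        }
        // A workflow is ready once every worker pool config (model) in it is READY.
        for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
            if (wpc.getStatus() != WorkerPoolConfig.Status.READY) {
                return false;
            }
        }
        return true;
    }

    private ReadinessCheck() {}
}
```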

## Workflows

The workflow system is used to support various use cases involving multiple levels of models and glue code.
It is configured using a `workflow.json` file where you can describe the workflow pipeline and the models involved.
More details about workflows can be found in the [workflows guide](workflows.md).

The main purpose of a workflow is to call various models.
This is done through the WorkLoadManager.
Creating or removing a workflow adds or removes its models from the WLM.
The same model used in multiple workflows corresponds to a single model in the WLM.

## WorkLoadManager

![WorkLoadManager Diagram](https://resources.djl.ai/images/djl-serving/wlm-01.png)

The backend is based around the [WorkLoadManager](../../wlm/README.md) module.
The WLM manages the worker threads for each model and handles batching and request routing to them.
It is also available as a standalone module (`ai.djl.serving:wlm`).
This may be useful if you want the DJL Serving worker scaling support, but not the HTTP frontend.
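
A minimal sketch of standalone WLM usage, assuming an `Input`/`Output` model: `registerWorkerPool`, `setMinWorkers`/`setMaxWorkers`, and `initWorkers` appear in this change, while the no-arg `WorkLoadManager` constructor, the `Job`-based `runJob` call, `close()`, and the model URL and device name are assumptions based on the WLM README and may differ between versions:

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.WorkerPool;

import java.util.concurrent.CompletableFuture;

public final class WlmExample {

    public static void main(String[] args) throws Exception {
        String modelUrl = "https://resources.djl.ai/test-models/mlp.tar.gz"; // illustrative
        Criteria<Input, Output> criteria =
                Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelUrls(modelUrl)
                        .build();

        // The ModelInfo (a WorkerPoolConfig) now carries the min/max worker settings.
        ModelInfo<Input, Output> modelInfo = new ModelInfo<>("mlp", modelUrl, criteria);
        modelInfo.setMinWorkers(1);
        modelInfo.setMaxWorkers(4);

        // Register a worker pool for the model and start workers on a device.
        WorkLoadManager wlm = new WorkLoadManager();
        WorkerPool<Input, Output> wp = wlm.registerWorkerPool(modelInfo);
        wp.initWorkers("cpu"); // device-name format is an assumption

        // Submit a job; the WLM batches it and routes it to a worker thread.
        Input input = new Input();
        input.add("data", "hello");
        CompletableFuture<Output> future = wlm.runJob(new Job<>(modelInfo, input)); // assumed API
        System.out.println(future.join().getData().getAsString());

        wlm.close(); // assumed to shut down the worker threads
    }
}
```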

Each model has a corresponding worker pool that encompasses all of the workers running it.
The worker pool has a job queue to manage incoming requests and a worker pool config describing the task (model) it runs.

The worker pool can then contain multiple worker groups.
Each group corresponds to the model running on a particular device.
So, the same model can have worker groups for both CPU and GPU, or for multiple GPUs.

Finally, each worker group can contain the individual worker threads.
This allows for multiple threads on the same device (typically CPU) for the same model.
The number of workers in a worker group can be scaled automatically between a minimum and maximum value.
A constant number of workers can be set by making the minimum and maximum equal, as in the sketch below.
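
For instance, the management layer adjusts these bounds through `ModelManager.scaleWorkers`, the call used by the scaling handler in this change; the device name and worker counts below are illustrative:

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.WorkerPoolConfig;

final class ScalingExample {

    /** Pins a worker pool to exactly four workers on one device by making min == max. */
    static void pinWorkers(WorkerPoolConfig<Input, Output> wpc) {
        ModelManager modelManager = ModelManager.getInstance();
        // "gpu0" and the counts are illustrative; device naming depends on your setup.
        modelManager.scaleWorkers(wpc, "gpu0", 4, 4);
    }

    private ScalingExample() {}
}
```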

Within each worker thread inside the WLM, there is a DJL Predictor.
Depending on which Engine backs the Predictor, it can run models from PyTorch, TensorFlow, XGBoost, or any of the other engines supported by DJL.
Notably, there is also a [Python Engine](../../engines/python/README.md) which can be used to run models, preprocessing, and postprocessing defined in a Python script.
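
Conceptually, each worker thread wraps the plain DJL `Predictor` pattern, roughly like this generic sketch (criteria construction and translator setup are left to the caller):

```java
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

final class PredictorExample {

    /** Loads a model from the given criteria and runs a single prediction. */
    static <I, O> O runOnce(Criteria<I, O> criteria, I input) throws Exception {
        // Each WLM worker thread holds one Predictor like this and reuses it across requests.
        try (ZooModel<I, O> model = criteria.loadModel();
                Predictor<I, O> predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }

    private PredictorExample() {}
}
```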

![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png)
When using the Python engine, the DJL Python Predictor (PyPredictor) contains a Python process.
This means that each worker thread (in each worker group in each worker pool) has its own process.
The process can be used and closed through the PyPredictor.
@@ -16,10 +16,10 @@
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.WorkerGroup;
import ai.djl.serving.wlm.WorkerPool;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.wlm.WorkerThread;

import java.util.ArrayList;
@@ -32,7 +32,7 @@ public class DescribeWorkflowResponse {

private String workflowName;
private String version;
private List<Model> models;
private List<DescribeWorkerPoolConfig> wpcs;

/**
* Constructs a new {@code DescribeWorkflowResponse} instance.
@@ -42,31 +42,31 @@ public class DescribeWorkflowResponse {
public DescribeWorkflowResponse(ai.djl.serving.workflow.Workflow workflow) {
this.workflowName = workflow.getName();
this.version = workflow.getVersion();
models = new ArrayList<>();
wpcs = new ArrayList<>();

ModelManager manager = ModelManager.getInstance();
WorkLoadManager wlm = manager.getWorkLoadManager();
Set<String> startupWorkflows = manager.getStartupWorkflows();

for (ModelInfo<Input, Output> model : workflow.getModels()) {
ModelInfo.Status status = model.getStatus();
for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
WorkerPoolConfig.Status status = wpc.getStatus();
int activeWorker = 0;
int targetWorker = 0;

Model m = new Model();
models.add(m);
WorkerPool<Input, Output> pool = wlm.getWorkerPool(model);
DescribeWorkerPoolConfig m = new DescribeWorkerPoolConfig();
wpcs.add(m);
WorkerPool<Input, Output> pool = wlm.getWorkerPool(wpc);
if (pool != null) {
pool.cleanup();

m.setModelName(model.getId());
m.setModelUrl(model.getModelUrl());
m.setBatchSize(model.getBatchSize());
m.setMaxBatchDelayMillis(model.getMaxBatchDelayMillis());
m.setMaxIdleSeconds(model.getMaxIdleSeconds());
m.setQueueSize(model.getQueueSize());
m.setModelName(wpc.getId());
m.setModelUrl(wpc.getModelUrl());
m.setBatchSize(wpc.getBatchSize());
m.setMaxBatchDelayMillis(wpc.getMaxBatchDelayMillis());
m.setMaxIdleSeconds(wpc.getMaxIdleSeconds());
m.setQueueSize(wpc.getQueueSize());
m.setRequestInQueue(pool.getJobQueue().size());
m.setLoadedAtStartup(startupWorkflows.contains(model.getId()));
m.setLoadedAtStartup(startupWorkflows.contains(wpc.getId()));

for (WorkerGroup<Input, Output> group : pool.getWorkerGroups().values()) {
Device device = group.getDevice();
@@ -86,7 +86,7 @@ public DescribeWorkflowResponse(ai.djl.serving.workflow.Workflow workflow) {
}
}

if (status == ModelInfo.Status.READY) {
if (status == WorkerPoolConfig.Status.READY) {
m.setStatus(activeWorker >= targetWorker ? "Healthy" : "Unhealthy");
} else {
m.setStatus(status.name());
@@ -117,12 +117,12 @@ public String getVersion() {
*
* @return a list of models
*/
public List<Model> getModels() {
return models;
public List<DescribeWorkerPoolConfig> getWpcs() {
return wpcs;
}

/** A class represents model information. */
public static final class Model {
public static final class DescribeWorkerPoolConfig {

private String modelName;
private String modelUrl;
@@ -20,6 +20,7 @@
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.workflow.BadWorkflowException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.serving.workflow.WorkflowDefinition;
@@ -130,16 +131,16 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco
for (int i = pagination.pageToken; i < pagination.last; ++i) {
String workflowName = keys.get(i);
for (Workflow workflow : endpoints.get(workflowName).getWorkflows()) {
for (ModelInfo<Input, Output> m : workflow.getModels()) {
String status = m.getStatus().toString();
String id = m.getId();
String modelName;
for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
String status = wpc.getStatus().toString();
String id = wpc.getId();
String name;
if (workflowName.equals(id)) {
modelName = workflowName;
name = workflowName;
} else {
modelName = workflowName + ':' + id;
name = workflowName + ':' + id;
}
list.addModel(modelName, workflow.getVersion(), m.getModelUrl(), status);
list.addModel(name, workflow.getVersion(), wpc.getModelUrl(), status);
}
}
}
@@ -322,16 +323,16 @@ private void handleScaleWorkflow(

List<String> messages = new ArrayList<>();
for (Workflow workflow : workflows) {
// make sure all models are loaded and ready
for (ModelInfo<Input, Output> modelInfo : workflow.getModels()) {
if (modelInfo.getStatus() != ModelInfo.Status.READY) {
// make sure all WorkerPoolConfigs (models) are loaded and ready
for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
if (wpc.getStatus() != WorkerPoolConfig.Status.READY) {
throw new ServiceUnavailableException(
"Model or workflow is not ready: " + workflow.getName());
}
}

for (ModelInfo<Input, Output> modelInfo : workflow.getModels()) {
modelManager.scaleWorkers(modelInfo, deviceName, minWorkers, maxWorkers);
for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
modelManager.scaleWorkers(wpc, deviceName, minWorkers, maxWorkers);
String msg =
"Workflow \""
+ workflow.getName()