From 71845be5a80d75fd1578023648a5665961fdef61 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 12 Sep 2023 10:23:08 -0700 Subject: [PATCH] Worker type (#1022) --- .../java/ai/djl/benchmark/WlmBenchmark.java | 6 +- .../serving/kserve/KServeRequestHandler.java | 18 +- serving/docs/architecture.md | 58 ++- .../http/DescribeWorkflowResponse.java | 38 +- .../http/ManagementRequestHandler.java | 25 +- .../ai/djl/serving/models/ModelManager.java | 105 +++--- .../ai/djl/serving/workflow/Workflow.java | 52 +-- .../serving/workflow/WorkflowDefinition.java | 5 +- ...Function.java => WlmWorkflowFunction.java} | 29 +- serving/src/main/puml/frontend.puml | 39 ++ serving/src/main/puml/wlm.puml | 50 +++ .../java/ai/djl/serving/ModelServerTest.java | 7 +- .../java/ai/djl/serving/WorkflowTest.java | 18 +- .../ai/djl/serving/wlm/BatchAggregator.java | 11 +- wlm/src/main/java/ai/djl/serving/wlm/Job.java | 16 +- .../java/ai/djl/serving/wlm/ModelInfo.java | 306 +++------------- .../serving/wlm/PermanentBatchAggregator.java | 6 +- .../serving/wlm/TemporaryBatchAggregator.java | 8 +- .../ai/djl/serving/wlm/WorkLoadManager.java | 85 +++-- .../java/ai/djl/serving/wlm/WorkerGroup.java | 12 +- .../java/ai/djl/serving/wlm/WorkerPool.java | 34 +- .../ai/djl/serving/wlm/WorkerPoolConfig.java | 341 ++++++++++++++++++ .../java/ai/djl/serving/wlm/WorkerThread.java | 54 +-- .../djl/serving/wlm/WorkLoadManagerTest.java | 2 +- 24 files changed, 814 insertions(+), 511 deletions(-) rename serving/src/main/java/ai/djl/serving/workflow/function/{ModelWorkflowFunction.java => WlmWorkflowFunction.java} (57%) create mode 100644 serving/src/main/puml/frontend.puml create mode 100644 serving/src/main/puml/wlm.puml create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/WorkerPoolConfig.java diff --git a/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java b/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java index b01ccb22073..d5a2850cf7f 100644 --- 
a/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java +++ b/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java @@ -73,10 +73,12 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration) { ModelInfo modelInfo = new ModelInfo<>("model", arguments.getModelUrl(), criteria); - WorkerPool wp = wlm.registerModel(modelInfo); int workersPerDevice = numOfWorkers / devices.length; + modelInfo.setMinWorkers(workersPerDevice); + modelInfo.setMaxWorkers(workersPerDevice); + WorkerPool wp = wlm.registerWorkerPool(modelInfo); for (String deviceName : devices) { - wp.initWorkers(deviceName, workersPerDevice, workersPerDevice); + wp.initWorkers(deviceName); } // Measure memory before worker kickoff diff --git a/plugins/kserve/src/main/java/ai/djl/serving/kserve/KServeRequestHandler.java b/plugins/kserve/src/main/java/ai/djl/serving/kserve/KServeRequestHandler.java index 57d5216446f..66a183fd5c0 100644 --- a/plugins/kserve/src/main/java/ai/djl/serving/kserve/KServeRequestHandler.java +++ b/plugins/kserve/src/main/java/ai/djl/serving/kserve/KServeRequestHandler.java @@ -29,6 +29,7 @@ import ai.djl.serving.plugins.RequestHandler; import ai.djl.serving.util.NettyUtils; import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.wlm.util.WlmException; import ai.djl.serving.workflow.Workflow; import ai.djl.translate.TranslateException; @@ -184,8 +185,9 @@ private void handleKServeDescribeModel(ChannelHandlerContext ctx, String[] segme continue; } - for (ModelInfo modelInfo : wf.getModels()) { - if (modelInfo.getStatus() == ModelInfo.Status.READY) { + for (WorkerPoolConfig wpc : wf.getWpcs()) { + if (wpc.getStatus() == WorkerPoolConfig.Status.READY && wpc instanceof ModelInfo) { + ModelInfo modelInfo = (ModelInfo) wpc; response.name = wf.getName(); response.setPlatform(modelInfo.getEngineName()); Device device = modelInfo.getModels().keySet().iterator().next(); @@ -242,11 +244,11 @@ private void 
handleKServeDescribeModelReady(ChannelHandlerContext ctx, String[] if (segments.length > 5) { modelVersion = segments[5]; } - ModelInfo modelInfo = getModelInfo(modelName, modelVersion); - ModelInfo.Status status = modelInfo.getStatus(); + WorkerPoolConfig wpc = getWpc(modelName, modelVersion); + ModelInfo.Status status = wpc.getStatus(); HttpResponseStatus httpResponseStatus; - if (status == ModelInfo.Status.READY) { + if (status == WorkerPoolConfig.Status.READY) { httpResponseStatus = HttpResponseStatus.OK; } else { httpResponseStatus = HttpResponseStatus.FAILED_DEPENDENCY; @@ -254,13 +256,13 @@ private void handleKServeDescribeModelReady(ChannelHandlerContext ctx, String[] NettyUtils.sendJsonResponse(ctx, EMPTY_BODY, httpResponseStatus); } - private ModelInfo getModelInfo(String modelName, String modelVersion) + private WorkerPoolConfig getWpc(String modelName, String modelVersion) throws ModelNotFoundException { ModelManager modelManager = ModelManager.getInstance(); Workflow workflow = modelManager.getWorkflow(modelName, modelVersion, false); - Collection> models; + Collection> models; if (workflow != null) { - models = workflow.getModels(); + models = workflow.getWpcs(); } else { models = Collections.emptyList(); } diff --git a/serving/docs/architecture.md b/serving/docs/architecture.md index 664641649f5..95617b9b9f4 100644 --- a/serving/docs/architecture.md +++ b/serving/docs/architecture.md @@ -3,16 +3,68 @@ DJL serving is built on top of [Deep Java Library](https://djl.ai). You can visit the [DJL github repository](https://github.com/deepjavalibrary/djl) to learn more. For module, dependency and class overview refer to [generated diagrams](https://sourcespy.com/github/deepjavalibrarydjlserving/). -DJL Serving uses a [Netty](https://netty.io/) frontend on top of backend worker thread pools. 
-The frontend uses a single Netty setup with multiple [HttpRequestHandler](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/HttpRequestHandler.html)s. +DJL Serving exists in roughly four layers: + +1. Frontend - A [Netty](https://netty.io/) HTTP client that accepts and manages incoming requests +2. Workflows - A system for combining multiple models and glue code together into execution plans +3. WorkLoadManager (WLM) - A worker thread management system supporting batching and routing. + +There is also the model store, specified in the [configuration](configuration.md), that specifies the models to be loaded on startup in the ModelManager. + +![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture-02.png) + + +## Frontend + +![Frontend Diagram](https://resources.djl.ai/images/djl-serving/frontend-01.png) + +DJL Serving uses a [Netty](https://netty.io/) frontend to handle the incoming requests. +It has a single Netty setup with multiple [HttpRequestHandler](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/HttpRequestHandler.html)s. Different request handlers will provide support for the [inference API](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/InferenceRequestHandler.html), [Management API](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/http/ManagementRequestHandler.html), or other APIs available from various plugins. +Those request handlers can then communicate with other layers of the DJL stack. +Mainly, this is through the ModelManager which manages the various endpoints. +Each endpoint can have multiple versions that each correspond to a Workflow. +The inference API will call various workflows (assigned round-robin to the workflows in an endpoint). +The management API can be used to CRUD the various endpoints and workflows in them. 
+
+## Workflows
+
+The workflow system is used to support various use cases involving multiple levels of models and glue code.
+It is configured using a `workflow.json` file where you can describe the workflow pipeline and the models involved.
+More details about workflows can be found in the [workflows guide](workflows.md).
+
+Within the workflow, the main purpose is to call various models.
+This is done through the WorkLoadManager.
+Creating and removing the workflow will add and remove the respective models from the WLM.
+The same model in multiple workflows will correspond to a single model in the WLM.
+
+## WorkLoadManager
+
+![WorkLoadManager Diagram](https://resources.djl.ai/images/djl-serving/wlm-01.png)
+
 The backend is based around the [WorkLoadManager](../../wlm/README.md) module.
 The WLM takes care of multiple worker threads for each model along with the batching and request routing to them.
 It is also available separately and can be utilized through the WLM module (`ai.djl.serving:wlm`).
+This may be useful if you want the DJL Serving worker scaling support, but not the HTTP frontend.
+
+For each model, there is a worker pool corresponding to its full support.
+Each worker pool has a job queue to manage the incoming requests.
+It also has a worker pool config describing the task (model) it runs.
+
+The worker pool can then contain multiple worker groups.
+The groups correspond to the support for the model on a particular device.
+So, the same model can have worker groups for both CPU and GPU or on multiple GPUs by creating multiple worker groups.
+
+Finally, each worker group can contain the individual worker threads.
+This allows for multiple threads on the same device (typically CPU) for the same model.
+The number of workers in a worker group can be automatically scaled with a minimum and maximum value.
+A constant number of workers can be set by having the same minimum and maximum workers.
 Within each worker thread inside the WLM, there is a DJL Predictor.
 Depending on what Engine the Predictor is, it can run various models such as those from PyTorch, Tensorflow, XGBoost, or any of the other engines supported by DJL.
 Notably, there is also a [Python Engine](../../engines/python/README.md) which can be used to run models, preprocessing, and postprocessing defined in a python script.
 
-![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png)
+When using the Python engine, the DJL Python Predictor (PyPredictor) contains a python process.
+This means that each worker thread (in each worker group in each worker pool) has its own process.
+The process can be used and closed through the PyPredictor.
\ No newline at end of file
diff --git a/serving/src/main/java/ai/djl/serving/http/DescribeWorkflowResponse.java b/serving/src/main/java/ai/djl/serving/http/DescribeWorkflowResponse.java
index f129b41f750..c5badc169f0 100644
--- a/serving/src/main/java/ai/djl/serving/http/DescribeWorkflowResponse.java
+++ b/serving/src/main/java/ai/djl/serving/http/DescribeWorkflowResponse.java
@@ -16,10 +16,10 @@
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
 import ai.djl.serving.models.ModelManager;
-import ai.djl.serving.wlm.ModelInfo;
 import ai.djl.serving.wlm.WorkLoadManager;
 import ai.djl.serving.wlm.WorkerGroup;
 import ai.djl.serving.wlm.WorkerPool;
+import ai.djl.serving.wlm.WorkerPoolConfig;
 import ai.djl.serving.wlm.WorkerThread;
 
 import java.util.ArrayList;
@@ -32,7 +32,7 @@ public class DescribeWorkflowResponse {
 
     private String workflowName;
     private String version;
-    private List models;
+    private List wpcs;
 
     /**
      * Constructs a new {@code DescribeWorkflowResponse} instance.
@@ -42,31 +42,31 @@ public class DescribeWorkflowResponse { public DescribeWorkflowResponse(ai.djl.serving.workflow.Workflow workflow) { this.workflowName = workflow.getName(); this.version = workflow.getVersion(); - models = new ArrayList<>(); + wpcs = new ArrayList<>(); ModelManager manager = ModelManager.getInstance(); WorkLoadManager wlm = manager.getWorkLoadManager(); Set startupWorkflows = manager.getStartupWorkflows(); - for (ModelInfo model : workflow.getModels()) { - ModelInfo.Status status = model.getStatus(); + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + WorkerPoolConfig.Status status = wpc.getStatus(); int activeWorker = 0; int targetWorker = 0; - Model m = new Model(); - models.add(m); - WorkerPool pool = wlm.getWorkerPool(model); + DescribeWorkerPoolConfig m = new DescribeWorkerPoolConfig(); + wpcs.add(m); + WorkerPool pool = wlm.getWorkerPool(wpc); if (pool != null) { pool.cleanup(); - m.setModelName(model.getId()); - m.setModelUrl(model.getModelUrl()); - m.setBatchSize(model.getBatchSize()); - m.setMaxBatchDelayMillis(model.getMaxBatchDelayMillis()); - m.setMaxIdleSeconds(model.getMaxIdleSeconds()); - m.setQueueSize(model.getQueueSize()); + m.setModelName(wpc.getId()); + m.setModelUrl(wpc.getModelUrl()); + m.setBatchSize(wpc.getBatchSize()); + m.setMaxBatchDelayMillis(wpc.getMaxBatchDelayMillis()); + m.setMaxIdleSeconds(wpc.getMaxIdleSeconds()); + m.setQueueSize(wpc.getQueueSize()); m.setRequestInQueue(pool.getJobQueue().size()); - m.setLoadedAtStartup(startupWorkflows.contains(model.getId())); + m.setLoadedAtStartup(startupWorkflows.contains(wpc.getId())); for (WorkerGroup group : pool.getWorkerGroups().values()) { Device device = group.getDevice(); @@ -86,7 +86,7 @@ public DescribeWorkflowResponse(ai.djl.serving.workflow.Workflow workflow) { } } - if (status == ModelInfo.Status.READY) { + if (status == WorkerPoolConfig.Status.READY) { m.setStatus(activeWorker >= targetWorker ? 
"Healthy" : "Unhealthy"); } else { m.setStatus(status.name()); @@ -117,12 +117,12 @@ public String getVersion() { * * @return a list of models */ - public List getModels() { - return models; + public List getWpcs() { + return wpcs; } /** A class represents model information. */ - public static final class Model { + public static final class DescribeWorkerPoolConfig { private String modelName; private String modelUrl; diff --git a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 55de30430da..8a2e1019490 100644 --- a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -20,6 +20,7 @@ import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.NettyUtils; import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.BadWorkflowException; import ai.djl.serving.workflow.Workflow; import ai.djl.serving.workflow.WorkflowDefinition; @@ -130,16 +131,16 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco for (int i = pagination.pageToken; i < pagination.last; ++i) { String workflowName = keys.get(i); for (Workflow workflow : endpoints.get(workflowName).getWorkflows()) { - for (ModelInfo m : workflow.getModels()) { - String status = m.getStatus().toString(); - String id = m.getId(); - String modelName; + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + String status = wpc.getStatus().toString(); + String id = wpc.getId(); + String name; if (workflowName.equals(id)) { - modelName = workflowName; + name = workflowName; } else { - modelName = workflowName + ':' + id; + name = workflowName + ':' + id; } - list.addModel(modelName, workflow.getVersion(), m.getModelUrl(), status); + list.addModel(name, workflow.getVersion(), wpc.getModelUrl(), status); } } } @@ -322,16 +323,16 @@ private 
void handleScaleWorkflow( List messages = new ArrayList<>(); for (Workflow workflow : workflows) { - // make sure all models are loaded and ready - for (ModelInfo modelInfo : workflow.getModels()) { - if (modelInfo.getStatus() != ModelInfo.Status.READY) { + // make sure all WorkerPoolConfigs (models) are loaded and ready + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + if (wpc.getStatus() != WorkerPoolConfig.Status.READY) { throw new ServiceUnavailableException( "Model or workflow is not ready: " + workflow.getName()); } } - for (ModelInfo modelInfo : workflow.getModels()) { - modelManager.scaleWorkers(modelInfo, deviceName, minWorkers, maxWorkers); + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + modelManager.scaleWorkers(wpc, deviceName, minWorkers, maxWorkers); String msg = "Workflow \"" + workflow.getName() diff --git a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java index c1b68b1e983..e9091a4b436 100644 --- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -24,6 +24,7 @@ import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; import ai.djl.serving.wlm.WorkerPool; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.Workflow; import org.slf4j.Logger; @@ -88,41 +89,56 @@ public CompletableFuture registerWorkflow(Workflow workflow) { return CompletableFuture.supplyAsync( () -> { - Map> models = workflow.getModelMap(); - for (Map.Entry> entry : models.entrySet()) { + Map> wpcs = workflow.getWpcMap(); + for (Map.Entry> entry : + wpcs.entrySet()) { String key = entry.getKey(); - ModelInfo model = entry.getValue(); + WorkerPoolConfig workerPoolConfig = entry.getValue(); try { // download model and configure per model settings - model.initialize(); + workerPoolConfig.initialize(); // Install engine if necessary - String engine = model.getEngineName(); 
- DependencyManager dm = DependencyManager.getInstance(); - dm.installEngine(engine); - Thread.currentThread() - .setContextClassLoader(MutableClassLoader.getInstance()); - WorkerPool wp = wlm.getWorkerPool(model); - if (wp != null) { - models.put(key, wp.getModel()); - wp.increaseRef(); - logger.info("Model {} is registered by other workflow", model); - continue; + String engine = null; + if (workerPoolConfig instanceof ModelInfo) { + ModelInfo model = + (ModelInfo) workerPoolConfig; + engine = model.getEngineName(); + DependencyManager dm = DependencyManager.getInstance(); + dm.installEngine(engine); + Thread.currentThread() + .setContextClassLoader(MutableClassLoader.getInstance()); + WorkerPool wp = wlm.getWorkerPool(model); + if (wp != null) { + wpcs.put(key, wp.getWpc()); + wp.increaseRef(); + logger.info("Model {} is registered by other workflow", model); + continue; + } } - wlm.registerModel(model); - String[] devices = model.getLoadOnDevices(); - logger.info("Loading model on {}:{}", engine, Arrays.toString(devices)); + wlm.registerWorkerPool(workerPoolConfig); + String[] devices = workerPoolConfig.getLoadOnDevices(); + if (engine != null) { + logger.info( + "Loading model on {}:{}", engine, Arrays.toString(devices)); + } else { + logger.info("Loading worker: {}", Arrays.toString(devices)); + } ExecutorService pool = null; List> futures = new ArrayList<>(); - if (model.isParallelLoading()) { + if (workerPoolConfig.isParallelLoading()) { pool = Executors.newFixedThreadPool(devices.length); } for (String deviceName : devices) { if (pool != null) { - futures.add(pool.submit(() -> initWorkers(model, deviceName))); + futures.add( + pool.submit( + () -> + initWorkers( + workerPoolConfig, deviceName))); } else { - initWorkers(model, deviceName); + initWorkers(workerPoolConfig, deviceName); } } if (pool != null) { @@ -159,11 +175,11 @@ public boolean unregisterWorkflow(String workflowName, String version) { logger.warn("Model not found: {}", workflowName); return 
false; } - Set> candidateModelsToUnregister = new HashSet<>(); + Set> candidateWpcsToUnregister = new HashSet<>(); if (version == null) { // unregister all versions for (Workflow workflow : endpoint.getWorkflows()) { - candidateModelsToUnregister.addAll(workflow.getModels()); + candidateWpcsToUnregister.addAll(workflow.getWpcs()); workflow.stop(); } startupWorkflows.remove(workflowName); @@ -175,7 +191,7 @@ public boolean unregisterWorkflow(String workflowName, String version) { logger.warn("Workflow not found: {}:{}", workflowName, version); return false; } - candidateModelsToUnregister.addAll(workflow.getModels()); + candidateWpcsToUnregister.addAll(workflow.getWpcs()); workflow.stop(); startupWorkflows.remove(workflowName); logger.info("Model {}/{} unregistered.", workflowName, version); @@ -185,39 +201,42 @@ public boolean unregisterWorkflow(String workflowName, String version) { } // Unregister candidate models if they are not used for a remaining endpoint - candidateModelsToUnregister.removeAll(getModels()); - for (ModelInfo model : candidateModelsToUnregister) { - wlm.unregisterModel(model); + candidateWpcsToUnregister.removeAll(getWpcs()); + for (WorkerPoolConfig wpc : candidateWpcsToUnregister) { + wlm.unregisterWorkerPool(wpc); } return true; } /** - * Initializes the workers for a model. + * Initializes the workers for a workerPoolConfig. * - * @param model the model to scale workers for - * @param deviceName the device for the model + * @param wpc the workerPoolConfig to scale workers for + * @param deviceName the device for the workerPoolConfig * @see WorkerPool#initWorkers(String) */ - public void initWorkers(ModelInfo model, String deviceName) { + public void initWorkers(WorkerPoolConfig wpc, String deviceName) { Thread.currentThread().setContextClassLoader(MutableClassLoader.getInstance()); - wlm.getWorkerPool(model).initWorkers(deviceName); + wlm.getWorkerPool(wpc).initWorkers(deviceName); } /** * Scales the workers for a model. 
* - * @param model the model to scale workers for + * @param wpc the model to scale workers for * @param deviceName the device for the model * @param minWorkers the min workers, -1 for auto-scale * @param maxWorkers the max workers, -1 for auto-scale * @see WorkerPool#scaleWorkers(String, int, int) */ public void scaleWorkers( - ModelInfo model, String deviceName, int minWorkers, int maxWorkers) { + WorkerPoolConfig wpc, + String deviceName, + int minWorkers, + int maxWorkers) { Thread.currentThread().setContextClassLoader(MutableClassLoader.getInstance()); - wlm.getWorkerPool(model).scaleWorkers(deviceName, minWorkers, maxWorkers); + wlm.getWorkerPool(wpc).scaleWorkers(deviceName, minWorkers, maxWorkers); } /** @@ -230,14 +249,14 @@ public Map getEndpoints() { } /** - * Returns all models in an endpoint. + * Returns all {@link WorkerPoolConfig}s in an endpoint. * - * @return all models in an endpoint + * @return all {@link WorkerPoolConfig}s in an endpoint */ - public Set> getModels() { + public Set> getWpcs() { return getEndpoints().values().stream() .flatMap(e -> e.getWorkflows().stream()) - .flatMap(w -> w.getModels().stream()) + .flatMap(w -> w.getWpcs().stream()) .collect(Collectors.toSet()); } @@ -350,12 +369,12 @@ public CompletableFuture> workerStatus() { for (Endpoint endpoint : endpoints.values()) { for (Workflow wf : endpoint.getWorkflows()) { String workflowName = wf.getName(); - for (ModelInfo m : wf.getModels()) { - String modelName = m.getId(); + for (WorkerPoolConfig wpc : wf.getWpcs()) { + String modelName = wpc.getId(); if (!modelName.equals(workflowName)) { modelName = workflowName + ':' + modelName; // NOPMD } - ModelInfo.Status status = m.getStatus(); + WorkerPoolConfig.Status status = wpc.getStatus(); switch (status) { case FAILED: data.put(modelName, new StatusResponse(status.name())); @@ -366,7 +385,7 @@ public CompletableFuture> workerStatus() { hasPending = true; break; default: - if (wlm.getWorkerPool(m).isFullyScaled()) { + if 
(wlm.getWorkerPool(wpc).isFullyScaled()) { data.put(modelName, new StatusResponse("Healthy")); } else { data.put(modelName, new StatusResponse("Unhealthy")); diff --git a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java index c23b8268ca4..396288a1ca2 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java +++ b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java @@ -14,14 +14,14 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; -import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.WorkflowExpression.Item; import ai.djl.serving.workflow.WorkflowExpression.Item.ItemType; import ai.djl.serving.workflow.function.EnsembleMerge; import ai.djl.serving.workflow.function.FunctionsApply; import ai.djl.serving.workflow.function.IdentityWF; -import ai.djl.serving.workflow.function.ModelWorkflowFunction; +import ai.djl.serving.workflow.function.WlmWorkflowFunction; import ai.djl.serving.workflow.function.WorkflowFunction; import org.slf4j.Logger; @@ -58,21 +58,21 @@ public class Workflow { String name; String version; - Map> models; + Map> wpcs; Map expressions; Map funcs; Map> configs; /** - * Constructs a workflow containing only a single model. + * Constructs a workflow containing only a single workerPoolConfig. 
* - * @param model the model for the workflow + * @param wpc the workerPoolConfig for the workflow */ - public Workflow(ModelInfo model) { + public Workflow(WorkerPoolConfig wpc) { String modelName = "model"; - this.name = model.getId(); - this.version = model.getVersion(); - models = Collections.singletonMap(modelName, model); + this.name = wpc.getId(); + this.version = wpc.getVersion(); + wpcs = Collections.singletonMap(modelName, wpc); expressions = Collections.singletonMap( OUT, new WorkflowExpression(new Item(modelName), new Item(IN))); @@ -85,7 +85,7 @@ public Workflow(ModelInfo model) { * * @param name workflow name * @param version workflow version - * @param models a map of executableNames for a model (how it is referred to in the {@link + * @param wpcs a map of executableNames for a wpc (how it is referred to in the {@link * WorkflowExpression}s to model * @param expressions a map of names to refer to an expression to the expression * @param configs the configuration objects @@ -94,34 +94,34 @@ public Workflow(ModelInfo model) { public Workflow( String name, String version, - Map> models, + Map> wpcs, Map expressions, Map> configs, Map funcs) { this.name = name; this.version = version; - this.models = models; + this.wpcs = wpcs; this.expressions = expressions; this.funcs = funcs; this.configs = configs; } /** - * Returns the models used in the workflow. + * Returns the {@link WorkerPoolConfig}s used in the workflow. * - * @return the models used in the workflow + * @return the wpcs used in the workflow */ - public Collection> getModels() { - return models.values(); + public Collection> getWpcs() { + return wpcs.values(); } /** - * Returns the model map in the workflow. + * Returns the wpc map in the workflow. 
* - * @return the model map in the workflow + * @return the wpc map in the workflow */ - public Map> getModelMap() { - return models; + public Map> getWpcMap() { + return wpcs; } /** @@ -194,10 +194,10 @@ public String toString() { return name; } - /** Stops the workflow and unloads all the models in the workflow. */ + /** Stops the workflow and unloads all the wpcs in the workflow. */ public void stop() { - for (ModelInfo m : getModels()) { - m.close(); + for (WorkerPoolConfig wpc : getWpcs()) { + wpc.close(); } } @@ -299,9 +299,9 @@ public WorkflowFunction getExecutable(WorkflowArgument arg) { * @return the function to execute the found executable */ public WorkflowFunction getExecutable(String name) { - ModelInfo model = models.get(name); - if (model != null) { - return new ModelWorkflowFunction(model); + WorkerPoolConfig wpc = wpcs.get(name); + if (wpc != null) { + return new WlmWorkflowFunction(wpc); } if (funcs.containsKey(name)) { diff --git a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java index 1ff2d31f9c9..964705f15d5 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java +++ b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java @@ -17,6 +17,7 @@ import ai.djl.repository.FilenameUtils; import ai.djl.serving.util.MutableClassLoader; import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.WorkflowExpression.Item; import ai.djl.serving.workflow.function.WorkflowFunction; import ai.djl.util.ClassLoaderUtils; @@ -233,7 +234,9 @@ public Workflow toWorkflow() throws BadWorkflowException { } } - return new Workflow(name, version, models, expressions, configs, loadedFunctions); + Map> wpcs = new ConcurrentHashMap<>(models); + wpcs.putAll(models); + return new Workflow(name, version, wpcs, expressions, configs, loadedFunctions); } private static final class 
ModelDefinitionDeserializer diff --git a/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java b/serving/src/main/java/ai/djl/serving/workflow/function/WlmWorkflowFunction.java similarity index 57% rename from serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java rename to serving/src/main/java/ai/djl/serving/workflow/function/WlmWorkflowFunction.java index 136e6658355..42b282677fe 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java +++ b/serving/src/main/java/ai/djl/serving/workflow/function/WlmWorkflowFunction.java @@ -15,25 +15,28 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.serving.wlm.Job; -import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.Workflow; import ai.djl.serving.workflow.WorkflowExpression.Item; import java.util.List; import java.util.concurrent.CompletableFuture; -/** An internal {@link WorkflowFunction} that is used to execute models in the workflow. */ -public class ModelWorkflowFunction extends WorkflowFunction { +/** + * An internal {@link WorkflowFunction} that is used to execute a {@link WorkerPoolConfig} + * (typically a model) through the {@link ai.djl.serving.wlm.WorkLoadManager} in the workflow. + */ +public class WlmWorkflowFunction extends WorkflowFunction { - ModelInfo model; + WorkerPoolConfig workerPoolConfig; /** - * Constructs a {@link ModelWorkflowFunction} with a given model. + * Constructs a {@link WlmWorkflowFunction} with a given workerPoolConfig. 
* - * @param model the model to run + * @param wpc the workerPoolConfig to run */ - public ModelWorkflowFunction(ModelInfo model) { - this.model = model; + public WlmWorkflowFunction(WorkerPoolConfig wpc) { + this.workerPoolConfig = wpc; } /** {@inheritDoc} */ @@ -43,14 +46,20 @@ public CompletableFuture run( Workflow.WorkflowExecutor executor, List args) { if (args.size() != 1) { throw new IllegalArgumentException( - "The model " + model.getId() + " should have one arg, but has " + args.size()); + "The model or worker type " + + workerPoolConfig.getId() + + " should have one arg, but has " + + args.size()); } return evaluateArgs(args) .thenComposeAsync( processedArgs -> executor.getWlm() - .runJob(new Job<>(model, processedArgs.get(0).getInput())) + .runJob( + new Job<>( + workerPoolConfig, + processedArgs.get(0).getInput())) .thenApply(Item::new)); } } diff --git a/serving/src/main/puml/frontend.puml b/serving/src/main/puml/frontend.puml new file mode 100644 index 00000000000..bd9fd22d671 --- /dev/null +++ b/serving/src/main/puml/frontend.puml @@ -0,0 +1,39 @@ +@startuml +skinparam componentStyle uml1 +package "DJL Serving - single process" { + + frame ModelManager as mm { + frame "Endpoints" { + frame "Endpoint mlp" { + [Workflow 0] as w0 + } + } + } + + frame Netty { + package REST_API [ + /resnet18 + .. + /resnet18/v1 + .. + /resnet18/v2 + -- + /BERT + === + Management API + ] + + package HttpRequestHandlers [ + InferenceRequestHandler + ManagementRequestHandler + KServeRequestHandler + ... 
+ ] + HTTP -r-> REST_API + } + + REST_API -r-> HttpRequestHandlers + HttpRequestHandlers -r-> w0 + +} +@enduml diff --git a/serving/src/main/puml/wlm.puml b/serving/src/main/puml/wlm.puml new file mode 100644 index 00000000000..8efcba2794e --- /dev/null +++ b/serving/src/main/puml/wlm.puml @@ -0,0 +1,50 @@ +@startuml +skinparam componentStyle uml1 + +package "DJL Serving WorkLoad Manager" as wlm { + frame "Worker pool (mlp)" as wp_mlp { + + queue "Job queue\nauto batch" as jq_mlp + + frame "WorkerGroup (CPU)" as wg_mlp { + + [Predictor WorkerThread 1] + [Predictor WorkerThread 2] + [Predictor WorkerThread 3] + [Predictor WorkerThread 4] + } + } + + frame "Worker pool (resnet18_v1)" as wp_resnet { + + queue "Job queue\nauto batch" as jq_resnet + + frame "WorkerGroup (GPU0)" as wg0 { + [Predictor WorkerThread] as wt0 + } + } + + frame "Worker pool (llama)" as wp_llama { + queue "Job queue\nauto batch" as jq_llama + + frame "WorkerGroup (GPU1)" as wg1 { + [Predictor WorkerThread] as wt1 + } + frame "WorkerGroup (GPU2)" as wg2 { + [Predictor WorkerThread] as wt2 + } + wg1-[hidden]down-wg2 + } + wp_mlp-[hidden]down-wp_resnet + wp_resnet-[hidden]down-wp_llama +} + +frame "Python Workers" { + [Worker llama GPU1] as py1 + [Worker llama GPU2] as py2 + py1-[hidden]down-py2 +} + +wt1 -> py1 +wt2 -> py2 +@enduml diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 3fae09a87ad..d83d62eaa49 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -23,6 +23,7 @@ import ai.djl.repository.MRL; import ai.djl.repository.Repository; import ai.djl.serving.http.DescribeWorkflowResponse; +import ai.djl.serving.http.DescribeWorkflowResponse.DescribeWorkerPoolConfig; import ai.djl.serving.http.ErrorResponse; import ai.djl.serving.http.ListModelsResponse; import ai.djl.serving.http.ListWorkflowsResponse; @@ -476,7 +477,7 @@ 
private void testPerModelWorkers(Channel channel) throws InterruptedException { Type type = new TypeToken() {}.getType(); DescribeWorkflowResponse[] resp = JsonUtils.GSON.fromJson(result, type); DescribeWorkflowResponse wf = resp[0]; - DescribeWorkflowResponse.Model model = wf.getModels().get(0); + DescribeWorkerPoolConfig model = wf.getWpcs().get(0); assertEquals(model.getQueueSize(), 10); DescribeWorkflowResponse.Group group = model.getWorkGroups().get(0); @@ -638,8 +639,8 @@ private void testDescribeModel(Channel channel) throws InterruptedException { assertEquals(wf.getWorkflowName(), "mlp_2"); assertNull(wf.getVersion()); - List models = wf.getModels(); - DescribeWorkflowResponse.Model model = models.get(0); + List models = wf.getWpcs(); + DescribeWorkerPoolConfig model = models.get(0); assertEquals(model.getModelName(), "mlp_2"); assertNotNull(model.getModelUrl()); assertEquals(model.getBatchSize(), 1); diff --git a/serving/src/test/java/ai/djl/serving/WorkflowTest.java b/serving/src/test/java/ai/djl/serving/WorkflowTest.java index 869a3c71864..845aba75a9d 100644 --- a/serving/src/test/java/ai/djl/serving/WorkflowTest.java +++ b/serving/src/test/java/ai/djl/serving/WorkflowTest.java @@ -15,8 +15,8 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.serving.util.ConfigManager; -import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.BadWorkflowException; import ai.djl.serving.workflow.Workflow; import ai.djl.serving.workflow.WorkflowDefinition; @@ -78,12 +78,12 @@ public void testFunctions() throws IOException, BadWorkflowException { public void testLocalPerf() throws IOException, BadWorkflowException { Path workflowFile = Paths.get("src/test/resources/workflows/localPerf.json"); Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); - ModelInfo m = workflow.getModels().iterator().next(); + WorkerPoolConfig wpc = 
workflow.getWpcs().iterator().next(); - Assert.assertEquals(m.getQueueSize(), 102); - Assert.assertEquals(m.getMaxIdleSeconds(), 62); - Assert.assertEquals(m.getMaxBatchDelayMillis(), 302); - Assert.assertEquals(m.getBatchSize(), 3); + Assert.assertEquals(wpc.getQueueSize(), 102); + Assert.assertEquals(wpc.getMaxIdleSeconds(), 62); + Assert.assertEquals(wpc.getMaxBatchDelayMillis(), 302); + Assert.assertEquals(wpc.getBatchSize(), 3); } @Test @@ -96,9 +96,9 @@ private Input runWorkflow(Path workflowFile, Input input) throws IOException, BadWorkflowException { Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); WorkLoadManager wlm = new WorkLoadManager(); - for (ModelInfo model : workflow.getModels()) { - model.setMaxWorkers(1); - wlm.registerModel(model).initWorkers("-1"); + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + wpc.setMaxWorkers(1); + wlm.registerWorkerPool(wpc).initWorkers("-1"); } Output output = workflow.execute(wlm, input).join(); diff --git a/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java index db0fe3857d2..a03da89782e 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java @@ -44,13 +44,14 @@ abstract class BatchAggregator { /** * Constructs a new {@code BbatchAggregator} instance. * - * @param model the model to use. + * @param wpc the workerPoolConfig to use. * @param jobQueue the job queue for polling data from. 
*/ - public BatchAggregator(ModelInfo model, LinkedBlockingDeque> jobQueue) { - this.dimension = new Dimension("Model", model.getId()); - this.batchSize = model.getBatchSize(); - this.maxBatchDelayMicros = model.getMaxBatchDelayMillis() * 1000L; + public BatchAggregator( + WorkerPoolConfig wpc, LinkedBlockingDeque> jobQueue) { + this.dimension = new Dimension("Model", wpc.getId()); + this.batchSize = wpc.getBatchSize(); + this.maxBatchDelayMicros = wpc.getMaxBatchDelayMillis() * 1000L; this.jobQueue = jobQueue; wjs = new ArrayList<>(); } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Job.java b/wlm/src/main/java/ai/djl/serving/wlm/Job.java index 02276745af3..1bdf11f3d99 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/Job.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/Job.java @@ -15,30 +15,30 @@ /** A class represents an inference job. */ public class Job { - private ModelInfo modelInfo; + private WorkerPoolConfig workerPoolConfig; private I input; private long begin; /** * Constructs a new {@code Job} instance. * - * @param modelInfo the model to run the job + * @param wpc the model to run the job * @param input the input data */ - public Job(ModelInfo modelInfo, I input) { - this.modelInfo = modelInfo; + public Job(WorkerPoolConfig wpc, I input) { + this.workerPoolConfig = wpc; this.input = input; begin = System.nanoTime(); } /** - * Returns the model that associated with this job. + * Returns the worker pool config that is associated with this job. 
* - * @return the model that associated with this job + * @return the worker pool config that is associated with this job */ - public ModelInfo getModel() { - return modelInfo; + public WorkerPoolConfig getWpc() { + return workerPoolConfig; } /** diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index bb2d789c088..d8e02636805 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -18,6 +18,8 @@ import ai.djl.ModelException; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.inference.Predictor; +import ai.djl.metric.Metrics; import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.ndarray.NDManager; @@ -30,6 +32,7 @@ import ai.djl.repository.zoo.ZooModel; import ai.djl.serving.wlm.util.WlmConfigManager; import ai.djl.serving.wlm.util.WlmOutOfMemoryException; +import ai.djl.translate.TranslateException; import ai.djl.util.NeuronUtils; import ai.djl.util.Utils; import ai.djl.util.cuda.CudaUtils; @@ -58,25 +61,16 @@ import java.util.stream.Stream; /** A class represent a loaded model and it's metadata. 
*/ -public final class ModelInfo { +public final class ModelInfo extends WorkerPoolConfig { private static final Logger logger = LoggerFactory.getLogger(ModelInfo.class); + private static final Logger MODEL_METRIC = LoggerFactory.getLogger("model_metric"); private static final Pattern PATTERN = Pattern.compile("MemAvailable:\\s+(\\d+) kB"); - private transient String id; - private String version; - private String modelUrl; private String engineName; private String loadOnDevices; - private int queueSize; - private int batchSize; - private int maxBatchDelayMillis; - private int maxIdleSeconds; - private Integer minWorkers; // Integer so it becomes null when parsed from JSON - private Integer maxWorkers; - // the following fields can be loaded from workflow json file private Map filters; private Map arguments; @@ -184,13 +178,8 @@ public void postWorkflowParsing(String workflowDir) { modelUrl = modelUrl.replaceAll("\\{model_dir}", workflowDir); } - /** - * Loads the model to the specified device. - * - * @param device the device to load model on - * @throws IOException if failed to read model file - * @throws ModelException if failed to load the specified model - */ + /** {@inheritDoc} */ + @Override @SuppressWarnings("unchecked") public void load(Device device) throws ModelException, IOException { if (getModels().containsKey(device)) { @@ -279,31 +268,10 @@ public ZooModel getModel(Device device) { return getModels().get(device); } - /** - * Sets the model ID. - * - * @param id the model ID - */ - public void setId(String id) { - this.id = id; - } - - /** - * Returns the model ID. - * - * @return the model ID - */ - public String getId() { - return id; - } - - /** - * Returns the model version. 
- * - * @return the model version - */ - public String getVersion() { - return version; + /** {@inheritDoc} */ + @Override + public ThreadType newThread(Device device) { + return new ModelThread(device); } /** @@ -315,20 +283,8 @@ public String getEngineName() { return engineName; } - /** - * Returns the model url. - * - * @return the model url - */ - public String getModelUrl() { - return modelUrl; - } - - /** - * Returns the model loading status. - * - * @return the model loading status - */ + /** {@inheritDoc} */ + @Override public Status getStatus() { if (status == null) { return Status.PENDING; @@ -385,102 +341,8 @@ public void hasInputOutputClass(Class inputClass, Class outputClass) { this.outputClass = outputClass; } - /** - * Sets the configured max idle time in seconds of workers. - * - * @param maxIdleSeconds the configured max idle time in seconds of workers - */ - public void setMaxIdleSeconds(int maxIdleSeconds) { - this.maxIdleSeconds = maxIdleSeconds; - } - - /** - * Returns the configured max idle time in seconds of workers. - * - * @return the max idle time in seconds - */ - public int getMaxIdleSeconds() { - return maxIdleSeconds; - } - - /** - * Sets the configured batch size. - * - * @param batchSize the configured batch size - */ - public void setBatchSize(int batchSize) { - this.batchSize = batchSize; - } - - /** - * Returns the configured batch size. - * - * @return the configured batch size - */ - public int getBatchSize() { - return batchSize; - } - - /** - * Sets the maximum delay in milliseconds to aggregate a batch. - * - * @param maxBatchDelayMillis the maximum delay in milliseconds to aggregate a batch - */ - public void setMaxBatchDelayMillis(int maxBatchDelayMillis) { - this.maxBatchDelayMillis = maxBatchDelayMillis; - } - - /** - * Returns the maximum delay in milliseconds to aggregate a batch. 
- * - * @return the maximum delay in milliseconds to aggregate a batch - */ - public int getMaxBatchDelayMillis() { - return maxBatchDelayMillis; - } - - /** - * Sets the configured size of the workers queue. - * - * @param queueSize the configured size of the workers queue - */ - public void setQueueSize(int queueSize) { - this.queueSize = queueSize; - } - - /** - * Returns the configured size of the workers queue. - * - * @return requested size of the workers queue. - */ - public int getQueueSize() { - return queueSize; - } - - /** - * Sets the minimum number of workers. - * - * @param minWorkers the new minimum number of workers - */ - public void setMinWorkers(int minWorkers) { - if (maxWorkers != null && maxWorkers < minWorkers) { - throw new IllegalArgumentException( - "The max workers for a model can't be smaller than the min workers"); - } - if (minWorkers == 0) { - throw new IllegalArgumentException( - "Having a minWorkers of 0 is not currently supported"); - } - - this.minWorkers = minWorkers; - } - - /** - * Returns the minimum number of workers. - * - * @param device the device to get the min workers for - * @return the minimum number of workers - */ + /** {@inheritDoc} */ + @Override public int getMinWorkers(Device device) { if (minWorkers != null && minWorkers >= 0) { return minWorkers; @@ -489,52 +351,8 @@ public int getMinWorkers(Device device) { return getWorkersMinMaxProperty(getModel(device), device, "minWorkers", 1); } - /** - * Sets the maximum number of workers. 
- * - * @param maxWorkers the new maximum number of workers - */ - public void setMaxWorkers(int maxWorkers) { - if (minWorkers != null && maxWorkers < minWorkers) { - throw new IllegalArgumentException( - "The max workers for a model can't be smaller than the min workers"); - } - if (maxWorkers == 0) { - throw new IllegalArgumentException("Models must have a maxWorkers greater than 0"); - } - - this.maxWorkers = maxWorkers; - } - - /** - * Sets the minimum and maximum number of workers. - * - * @param minWorkers the new minimum number of workers - * @param maxWorkers the new maximum number of workers - */ - public void setMinMaxWorkers(int minWorkers, int maxWorkers) { - if (maxWorkers < minWorkers) { - throw new IllegalArgumentException( - "The max workers for a model can't be smaller than the min workers"); - } - if (minWorkers == 0) { - throw new IllegalArgumentException( - "Having a minWorkers of 0 is not currently supported"); - } - if (maxWorkers == 0) { - throw new IllegalArgumentException("Models must have a maxWorkers greater than 0"); - } - - this.minWorkers = minWorkers; - this.maxWorkers = maxWorkers; - } - - /** - * Returns the maximum number of workers. - * - * @param device the device to get the min workers for - * @return the maximum number of workers - */ + /** {@inheritDoc} */ + @Override public int getMaxWorkers(Device device) { if (maxWorkers != null && maxWorkers >= 0) { return maxWorkers; @@ -591,12 +409,8 @@ private int getWorkersMinMaxProperty(Model model, Device device, String key, int return def; } - /** - * Initialize the model. - * - * @throws IOException if failed to download model - * @throws ModelNotFoundException if model not found - */ + /** {@inheritDoc} */ + @Override public void initialize() throws IOException, ModelException { downloadModel(); loadServingProperties(); @@ -618,7 +432,8 @@ public void initialize() throws IOException, ModelException { } } - /** Close all loaded models. 
*/ + /** {@inheritDoc} */ + @Override public void close() { if (!getModels().isEmpty() && !Boolean.getBoolean("ai.djl.serving.keep_cache")) { logger.info("Unloading model: {}{}", id, version == null ? "" : '/' + version); @@ -668,12 +483,8 @@ public static String inferModelNameFromUrl(String url) { return modelName; } - /** - * Returns the default device for this model if device is null. - * - * @param deviceName the device to use if it is not null - * @return a non-null device - */ + /** {@inheritDoc} */ + @Override public Device withDefaultDevice(String deviceName) { return Device.fromName(deviceName, Engine.getEngine(engineName)); } @@ -881,11 +692,8 @@ void checkAvailableMemory(Device device) throws IOException { } } - /** - * Returns the devices the model will be loaded on at startup. - * - * @return the devices the model will be loaded on at startup - */ + /** {@inheritDoc} */ + @Override public String[] getLoadOnDevices() { Engine engine = Engine.getEngine(engineName); if ("*".equals(loadOnDevices)) { @@ -948,11 +756,8 @@ public String[] getLoadOnDevices() { return new String[] {"-1"}; } - /** - * Returns if the model can be load parallel on multiple devices. 
- * - * @return if the model can be load parallel on multiple devices - */ + /** {@inheritDoc} */ + @Override public boolean isParallelLoading() { return Boolean.parseBoolean(prop.getProperty("option.parallel_loading")); } @@ -1067,38 +872,37 @@ private static int intValue(Properties prop, String key, int defValue) { return Integer.parseInt(value); } - /** {@inheritDoc} */ - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ModelInfo)) { - return false; - } - ModelInfo modelInfo = (ModelInfo) o; - return id.equals(modelInfo.id) && Objects.equals(version, modelInfo.version); - } + protected class ModelThread extends ThreadType { - /** {@inheritDoc} */ - @Override - public int hashCode() { - return Objects.hash(id, version); - } + private Predictor predictor; + ZooModel model; - /** {@inheritDoc} */ - @Override - public String toString() { - if (version != null) { - return id + ':' + version + " (" + getStatus() + ')'; + protected ModelThread(Device device) { + super(device); + model = getModel(device); + predictor = model.newPredictor(); + + boolean logModelMetric = Boolean.parseBoolean(model.getProperty("log_model_metric")); + if (logModelMetric) { + int metricsAggregation = + Integer.parseInt(model.getProperty("metrics_aggregation", "1000")); + Metrics metrics = new Metrics(); + metrics.setLimit(metricsAggregation); + metrics.setOnLimit((m, s) -> MODEL_METRIC.info("{}-{}", id, m.percentile(s, 50))); + predictor.setMetrics(metrics); + } + } + + /** {@inheritDoc} */ + @Override + public List run(List input) throws TranslateException { + return predictor.batchPredict(input); } - return id + " (" + getStatus() + ')'; - } - /** An enum represents state of a model. 
*/ - public enum Status { - PENDING, - READY, - FAILED + /** {@inheritDoc} */ + @Override + public void close() { + predictor.close(); + } } } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java index e147f9b984c..116e85896bc 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java @@ -34,12 +34,12 @@ public class PermanentBatchAggregator extends BatchAggregator { /** * Constructs a {@code PermanentBatchAggregator} instance. * - * @param model the model to use. + * @param wpc the workerPoolConfig to use. * @param jobQueue the job queue for polling data from. */ public PermanentBatchAggregator( - ModelInfo model, LinkedBlockingDeque> jobQueue) { - super(model, jobQueue); + WorkerPoolConfig wpc, LinkedBlockingDeque> jobQueue) { + super(wpc, jobQueue); } /** {@inheritDoc} */ diff --git a/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java index 11a04f9ee5e..d505f33a69a 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java @@ -37,14 +37,14 @@ public class TemporaryBatchAggregator extends BatchAggregator { /** * a batch aggregator that terminates after a maximum idle time. * - * @param model the model to run for. + * @param wpc the workerPoolConfig to run for. * @param jobQueue reference to external job queue for polling. 
*/ public TemporaryBatchAggregator( - ModelInfo model, LinkedBlockingDeque> jobQueue) { - super(model, jobQueue); + WorkerPoolConfig wpc, LinkedBlockingDeque> jobQueue) { + super(wpc, jobQueue); this.idleSince = System.currentTimeMillis(); - this.maxIdleSeconds = model.getMaxIdleSeconds(); + this.maxIdleSeconds = wpc.getMaxIdleSeconds(); } /** {@inheritDoc} */ diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index 36f0d6554c1..e6e09d12546 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -29,7 +29,7 @@ /** * WorkLoadManager is responsible to manage the work load of worker thread. the manage scales - * up/down the required amount of worker threads per model. + * up/down the required amount of worker threads per wpc. * * @author erik.bamberg@web.de */ @@ -38,7 +38,7 @@ public class WorkLoadManager { private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); private ExecutorService threadPool; - private ConcurrentHashMap, WorkerPool> workerPools; + private ConcurrentHashMap, WorkerPool> workerPools; /** Constructs a {@link WorkLoadManager} instance. */ public WorkLoadManager() { @@ -47,34 +47,33 @@ public WorkLoadManager() { } /** - * Registers a model and returns the {@link WorkerPool} for it. + * Registers a {@link WorkerPool} (model). * - *

This operation is idempotent and will return the existing workerpool if the model was + *

This operation is idempotent and will return the existing workerpool if the wpc was * already registered. * - * @param the model input class - * @param the model output class - * @param modelInfo the model to create the worker pool for + * @param the wpc input class + * @param the wpc output class + * @param wpc the wpc to create the worker pool for * @return the {@link WorkerPool} */ @SuppressWarnings("unchecked") - public WorkerPool registerModel(ModelInfo modelInfo) { + public WorkerPool registerWorkerPool(WorkerPoolConfig wpc) { return (WorkerPool) - workerPools.computeIfAbsent( - modelInfo, k -> new WorkerPool<>(modelInfo, threadPool)); + workerPools.computeIfAbsent(wpc, k -> new WorkerPool<>(wpc, threadPool)); } /** - * Removes a model from management. + * Removes a worker pool from management. * - * @param model the model to remove + * @param wpc the wpc to remove */ - public void unregisterModel(ModelInfo model) { - WorkerPool pool = getWorkerPool(model); + public void unregisterWorkerPool(WorkerPoolConfig wpc) { + WorkerPool pool = getWorkerPool(wpc); if (pool.decreaseRef() <= 0) { - logger.info("Unloading model: {}", model); + logger.info("Unloading model: {}", wpc); pool.shutdownWorkers(); - workerPools.remove(model); + workerPools.remove(wpc); } } @@ -82,24 +81,24 @@ public void unregisterModel(ModelInfo model) { * Adds an inference job to the job queue of the next free worker. scales up worker if * necessary. * - * @param the model input class - * @param the model output class + * @param the wpc input class + * @param the wpc output class * @param job an inference job to be executed. * @return {@code true} if submit success, false otherwise. 
*/ public CompletableFuture runJob(Job job) { CompletableFuture result = new CompletableFuture<>(); - ModelInfo modelInfo = job.getModel(); - if (modelInfo.getStatus() != ModelInfo.Status.READY) { - result.completeExceptionally(new WlmException("Model is not ready: " + modelInfo)); + WorkerPoolConfig wpc = job.getWpc(); + if (wpc.getStatus() != WorkerPoolConfig.Status.READY) { + result.completeExceptionally(new WlmException("Model is not ready: " + wpc)); return result; } - WorkerPool pool = getWorkerPool(modelInfo); + WorkerPool pool = getWorkerPool(wpc); int maxWorkers = pool.getMaxWorkers(); if (maxWorkers == 0) { result.completeExceptionally( - new WlmShutdownException("All model workers has been shutdown: " + modelInfo)); + new WlmShutdownException("All model workers has been shutdown: " + wpc)); return result; } LinkedBlockingDeque> queue = pool.getJobQueue(); @@ -107,41 +106,39 @@ public CompletableFuture runJob(Job job) { || pool.isAllWorkerDied() || !queue.offer(new WorkerJob<>(job, result))) { result.completeExceptionally( - new WlmCapacityException( - "Worker queue capacity exceeded for model: " + modelInfo)); - scaleUp(pool, modelInfo, maxWorkers); + new WlmCapacityException("Worker queue capacity exceeded for model: " + wpc)); + scaleUp(pool, wpc, maxWorkers); return result; } - int currentWorkers = getNumRunningWorkers(modelInfo); + int currentWorkers = getNumRunningWorkers(wpc); if (currentWorkers == 0 - || currentWorkers < maxWorkers && queue.size() > modelInfo.getBatchSize() * 2) { - scaleUp(pool, modelInfo, maxWorkers); + || currentWorkers < maxWorkers && queue.size() > wpc.getBatchSize() * 2) { + scaleUp(pool, wpc, maxWorkers); } return result; } - private void scaleUp(WorkerPool pool, ModelInfo modelInfo, int maxWorkers) { + private void scaleUp(WorkerPool pool, WorkerPoolConfig wpc, int maxWorkers) { synchronized (pool) { - int currentWorkers = getNumRunningWorkers(modelInfo); // check again + int currentWorkers = getNumRunningWorkers(wpc); // 
check again if (currentWorkers < maxWorkers) { - logger.info( - "Scaling up workers for model {} to {} ", modelInfo, currentWorkers + 1); + logger.info("Scaling up workers for model {} to {} ", wpc, currentWorkers + 1); pool.addThreads(); } } } /** - * Returns the number of running workers of a model. running workers are workers which are not + * Returns the number of running workers of a wpc. running workers are workers which are not * stopped, in error or scheduled to scale down. * - * @param modelInfo the model we are interested in. + * @param wpc the wpc we are interested in. * @return number of running workers. */ - public int getNumRunningWorkers(ModelInfo modelInfo) { + public int getNumRunningWorkers(WorkerPoolConfig wpc) { int numWorking = 0; - WorkerPool pool = workerPools.get(modelInfo); + WorkerPool pool = workerPools.get(wpc); if (pool != null) { pool.cleanup(); List> threads = pool.getWorkers(); @@ -157,19 +154,19 @@ public int getNumRunningWorkers(ModelInfo modelInfo) { } /** - * Returns the {@link WorkerPool} for a model. + * Returns the {@link WorkerPool} for a wpc. * - * @param the model input class - * @param the model output class - * @param modelInfo the model to get the worker pool for + * @param the wpc input class + * @param the wpc output class + * @param wpc the worker type to get the worker pool for * @return the {@link WorkerPool} */ @SuppressWarnings("unchecked") - public WorkerPool getWorkerPool(ModelInfo modelInfo) { - return (WorkerPool) workerPools.get(modelInfo); + public WorkerPool getWorkerPool(WorkerPoolConfig wpc) { + return (WorkerPool) workerPools.get(wpc); } - /** Close all models related to the {@code WorkloadManager}. */ + /** Close all wpcs related to the {@code WorkloadManager}. 
*/ public void close() { threadPool.shutdownNow(); for (WorkerPool wp : workerPools.values()) { diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerGroup.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerGroup.java index 4625a7c2316..8d5b9889c23 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerGroup.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerGroup.java @@ -31,11 +31,11 @@ public class WorkerGroup { this.workerPool = workerPool; this.device = device; workers = new CopyOnWriteArrayList<>(); - ModelInfo model = workerPool.getModel(); + WorkerPoolConfig wpc = workerPool.getWpc(); - // Default workers from model, may be overridden by configureWorkers on init or scale - minWorkers = model.getMinWorkers(device); - maxWorkers = model.getMaxWorkers(device); + // Default workers from worker type, may be overridden by configureWorkers on init or scale + minWorkers = wpc.getMinWorkers(device); + maxWorkers = wpc.getMaxWorkers(device); minWorkers = Math.min(minWorkers, maxWorkers); } @@ -91,11 +91,11 @@ public void configureWorkers(int minWorkers, int maxWorkers) { } void addThreads(int count, boolean permanent) { - ModelInfo model = workerPool.getModel(); + WorkerPoolConfig wpc = workerPool.getWpc(); ExecutorService threadPool = workerPool.getThreadPool(); for (int i = 0; i < count; ++i) { WorkerThread thread = - WorkerThread.builder(model) + WorkerThread.builder(wpc) .setDevice(device) .setJobQueue(workerPool.getJobQueue()) .optFixPoolThread(permanent) diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java index 89274d40f48..e4ff7e952d5 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java @@ -40,7 +40,7 @@ public class WorkerPool { private static final Logger logger = LoggerFactory.getLogger(WorkerPool.class); - private final ModelInfo model; + private final WorkerPoolConfig wpc; private ExecutorService threadPool; 
private Map> workerGroups; private LinkedBlockingDeque> jobQueue; @@ -49,11 +49,11 @@ public class WorkerPool { /** * Construct and initial data structure. * - * @param model the model this WorkerPool belongs to + * @param wpc the model this WorkerPool belongs to * @param threadPool the thread pool executor */ - WorkerPool(ModelInfo model, ExecutorService threadPool) { - this.model = model; + WorkerPool(WorkerPoolConfig wpc, ExecutorService threadPool) { + this.wpc = wpc; this.threadPool = threadPool; workerGroups = new ConcurrentHashMap<>(); refCnt = new AtomicInteger(1); @@ -78,8 +78,8 @@ public int decreaseRef() { * * @return the model of the worker pool */ - public ModelInfo getModel() { - return model; + public WorkerPoolConfig getWpc() { + return wpc; } ExecutorService getThreadPool() { @@ -188,23 +188,23 @@ public void initWorkers(String deviceName) { */ public void initWorkers(String deviceName, int minWorkers, int maxWorkers) { Device device; - synchronized (model) { + synchronized (wpc) { try { - model.initialize(); - device = model.withDefaultDevice(deviceName); - logger.info("loading model {} on {} ...", model, device); - model.load(device); + wpc.initialize(); + device = wpc.withDefaultDevice(deviceName); + logger.info("loading model {} on {} ...", wpc, device); + wpc.load(device); } catch (ModelException | IOException e) { throw new CompletionException(e); } - if (model.getStatus() != ModelInfo.Status.READY) { - logger.warn("Cannot scale workers while model is not READY: {}", model); + if (wpc.getStatus() != WorkerPoolConfig.Status.READY) { + logger.warn("Cannot scale workers while model is not READY: {}", wpc); } } // jobQueue should be initialized after model is configure if (jobQueue == null) { - jobQueue = new LinkedBlockingDeque<>(model.getQueueSize()); + jobQueue = new LinkedBlockingDeque<>(wpc.getQueueSize()); } cleanup(); @@ -267,7 +267,7 @@ private void doScaleWorker(WorkerGroup group) { /** Shutdown all works. 
*/ public void shutdownWorkers() { - synchronized (model) { + synchronized (wpc) { List> threads = getWorkers(); for (WorkerThread thread : threads) { thread.shutdown(WorkerState.WORKER_SCALED_DOWN); @@ -285,7 +285,7 @@ public void cleanup() { /** Shuts down all the worker threads in the work pool. */ public void shutdown() { - model.close(); + wpc.close(); for (WorkerGroup group : workerGroups.values()) { for (WorkerThread worker : group.workers) { worker.shutdown(WorkerState.WORKER_STOPPED); @@ -338,7 +338,7 @@ private void log() { buf.append("-tmpPool\n"); } }); - logger.debug("worker pool for model {}:\n {}", model, buf); + logger.debug("worker pool for model {}:\n {}", wpc, buf); } } } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPoolConfig.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPoolConfig.java new file mode 100644 index 00000000000..0e2f61a527c --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPoolConfig.java @@ -0,0 +1,341 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.translate.TranslateException; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * A {@link WorkerPoolConfig} represents a task that could be run in the {@link WorkLoadManager}. + * + *

<p>Each {@link WorkerThread} (also {@link WorkerPool} and {@link WorkerGroup}) focuses on
+ * executing a single worker type. They contain the configuration for the thread, any persistent
+ * data, and the code to run on the thread.
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ */
+public abstract class WorkerPoolConfig<I, O> {
+
+    protected transient String id;
+    protected String version;
+    protected String modelUrl;
+    protected int queueSize;
+    protected int batchSize;
+    protected int maxBatchDelayMillis;
+    protected int maxIdleSeconds;
+    protected Integer minWorkers; // Integer so it becomes null when parsed from JSON
+    protected Integer maxWorkers; // Integer so it becomes null when parsed from JSON
+
+    /**
+     * Loads the worker type to the specified device.
+     *
+     * @param device the device to load worker type on
+     * @throws IOException if failed to read worker type file
+     * @throws ModelException if failed to load the specified model
+     */
+    public abstract void load(Device device) throws ModelException, IOException;
+
+    /**
+     * Starts a new {@link WorkerThread} for this {@link WorkerPoolConfig}.
+     *
+     * @param device the device to run on
+     * @return the new {@link ThreadType}
+     */
+    public abstract ThreadType<I, O> newThread(Device device);
+
+    /**
+     * Initialize the worker.
+     *
+     * @throws IOException if failed to download worker
+     * @throws ModelNotFoundException if model not found
+     */
+    public abstract void initialize() throws IOException, ModelException;
+
+    /** Close all loaded workers. */
+    public abstract void close();
+
+    /**
+     * Returns the default device for this model if device is null.
+     *
+     * @param deviceName the device to use if it is not null
+     * @return a non-null device
+     */
+    public Device withDefaultDevice(String deviceName) {
+        return Device.fromName(deviceName);
+    }
+
+    /**
+     * Returns the worker type loading status.
+     *
+     * @return the worker type loading status
+     */
+    public abstract Status getStatus();
+
+    /**
+     * Returns if the worker type can be load parallel on multiple devices.
+     *
+     * @return if the worker type can be load parallel on multiple devices
+     */
+    public abstract boolean isParallelLoading();
+
+    /**
+     * Returns the devices the worker type will be loaded on at startup.
+     *
+     * @return the devices the worker type will be loaded on at startup
+     */
+    public abstract String[] getLoadOnDevices();
+
+    /**
+     * Sets the worker type ID.
+     *
+     * @param id the worker type ID
+     */
+    public void setId(String id) {
+        this.id = id;
+    }
+
+    /**
+     * Returns the worker type ID.
+     *
+     * @return the worker type ID
+     */
+    public String getId() {
+        return id;
+    }
+
+    /**
+     * Returns the worker type version.
+     *
+     * @return the worker type version
+     */
+    public String getVersion() {
+        return version;
+    }
+
+    /**
+     * Returns the worker type url.
+     *
+     * @return the worker type url
+     */
+    public String getModelUrl() {
+        return modelUrl;
+    }
+
+    /**
+     * Sets the configured max idle time in seconds of workers.
+     *
+     * @param maxIdleSeconds the configured max idle time in seconds of workers
+     */
+    public void setMaxIdleSeconds(int maxIdleSeconds) {
+        this.maxIdleSeconds = maxIdleSeconds;
+    }
+
+    /**
+     * Returns the configured max idle time in seconds of workers.
+     *
+     * @return the max idle time in seconds
+     */
+    public int getMaxIdleSeconds() {
+        return maxIdleSeconds;
+    }
+
+    /**
+     * Sets the configured batch size.
+     *
+     * @param batchSize the configured batch size
+     */
+    public void setBatchSize(int batchSize) {
+        this.batchSize = batchSize;
+    }
+
+    /**
+     * Returns the configured batch size.
+     *
+     * @return the configured batch size
+     */
+    public int getBatchSize() {
+        return batchSize;
+    }
+
+    /**
+     * Sets the maximum delay in milliseconds to aggregate a batch.
+     *
+     * @param maxBatchDelayMillis the maximum delay in milliseconds to aggregate a batch
+     */
+    public void setMaxBatchDelayMillis(int maxBatchDelayMillis) {
+        this.maxBatchDelayMillis = maxBatchDelayMillis;
+    }
+
+    /**
+     * Returns the maximum delay in milliseconds to aggregate a batch.
+     *
+     * @return the maximum delay in milliseconds to aggregate a batch
+     */
+    public int getMaxBatchDelayMillis() {
+        return maxBatchDelayMillis;
+    }
+
+    /**
+     * Sets the configured size of the workers queue.
+     *
+     * @param queueSize the configured size of the workers queue
+     */
+    public void setQueueSize(int queueSize) {
+        this.queueSize = queueSize;
+    }
+
+    /**
+     * Returns the configured size of the workers queue.
+     *
+     * @return requested size of the workers queue.
+     */
+    public int getQueueSize() {
+        return queueSize;
+    }
+
+    /**
+     * Sets the starting number of min workers.
+     *
+     * @param minWorkers Sets the starting number of min workers
+     */
+    public void setMinWorkers(int minWorkers) {
+        if (maxWorkers != null && maxWorkers < minWorkers) {
+            throw new IllegalArgumentException(
+                    "The max workers for a model or worker can't be smaller than the min workers");
+        }
+
+        this.minWorkers = minWorkers;
+    }
+
+    /**
+     * Returns the minimum number of workers.
+     *
+     * @param device the device to get the min workers for
+     * @return the minimum number of workers
+     */
+    public int getMinWorkers(Device device) {
+        return minWorkers;
+    }
+
+    /**
+     * Sets the starting number of max workers.
+     *
+     * @param maxWorkers Sets the starting number of max workers
+     */
+    public void setMaxWorkers(int maxWorkers) {
+        if (minWorkers != null && maxWorkers < minWorkers) {
+            throw new IllegalArgumentException(
+                    "The max workers for a model or worker can't be smaller than the min workers");
+        }
+        if (maxWorkers == 0) {
+            throw new IllegalArgumentException("Models must have a maxWorkers greater than 0");
+        }
+
+        this.maxWorkers = maxWorkers;
+    }
+
+    /**
+     * Returns the maximum number of workers.
+     *
+     * @param device the device to get the max workers for
+     * @return the maximum number of workers
+     */
+    public int getMaxWorkers(Device device) {
+        return maxWorkers;
+    }
+
+    /**
+     * Sets the starting minimum and maximum number of workers.
+     *
+     * @param minWorkers the new minimum number of workers
+     * @param maxWorkers the new maximum number of workers
+     */
+    public void setMinMaxWorkers(int minWorkers, int maxWorkers) {
+        if (maxWorkers < minWorkers) {
+            throw new IllegalArgumentException(
+                    "The max workers for a model or worker can't be smaller than the min workers");
+        }
+        if (minWorkers == 0) {
+            throw new IllegalArgumentException(
+                    "Having a minWorkers of 0 is not currently supported");
+        }
+        if (maxWorkers == 0) {
+            throw new IllegalArgumentException("Models must have a maxWorkers greater than 0");
+        }
+
+        this.minWorkers = minWorkers;
+        this.maxWorkers = maxWorkers;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (!(o instanceof WorkerPoolConfig)) {
+            return false;
+        }
+        WorkerPoolConfig<?, ?> wpc = (WorkerPoolConfig<?, ?>) o;
+        return id.equals(wpc.id) && Objects.equals(version, wpc.version);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, version);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        if (version != null) {
+            return id + ':' + version + " (" + getStatus() + ')';
+        }
+        return id + " (" + getStatus() + ')';
+    }
+
+    /** An enum represents state of a worker type. */
+    public enum Status {
+        PENDING,
+        READY,
+        FAILED
+    }
+
+    protected abstract static class ThreadType<I, O> {
+        Device device;
+
+        protected ThreadType(Device device) {
+            this.device = device;
+        }
+
+        /**
+         * Runs the work on the {@link WorkerThread}.
+ * + * @param input the work input + * @return the computed output + * @throws TranslateException if it failed to compute + */ + public abstract List run(List input) throws TranslateException; + + /** Closes the thread type and frees any resources. */ + public abstract void close(); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java index 34f79e4d492..fcd34fc0c2b 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java @@ -13,9 +13,7 @@ package ai.djl.serving.wlm; import ai.djl.Device; -import ai.djl.inference.Predictor; -import ai.djl.metric.Metrics; -import ai.djl.repository.zoo.ZooModel; +import ai.djl.serving.wlm.WorkerPoolConfig.ThreadType; import ai.djl.serving.wlm.util.WlmException; import ai.djl.serving.wlm.util.WorkerJob; import ai.djl.translate.TranslateException; @@ -33,12 +31,10 @@ public final class WorkerThread implements Runnable { private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class); - private static final Logger MODEL_METRIC = LoggerFactory.getLogger("model_metric"); private String workerName; - private String modelName; - private Predictor predictor; + private ThreadType threadType; private AtomicBoolean running = new AtomicBoolean(true); private BatchAggregator aggregator; @@ -48,8 +44,6 @@ public final class WorkerThread implements Runnable { private int workerId; private long startTime; private boolean fixPoolThread; - private boolean logModelMetric; - private int metricsAggregation; private long stateChangeTime; /** @@ -58,18 +52,13 @@ public final class WorkerThread implements Runnable { * @param builder build a new worker thread using this builder. 
*/ private WorkerThread(Builder builder) { - this.workerName = buildWorkerName(builder.model); + this.workerName = buildWorkerName(builder.workerPoolConfig); this.aggregator = builder.aggregator; this.workerId = new WorkerIdGenerator().generate(); this.startTime = System.currentTimeMillis(); this.fixPoolThread = builder.fixPoolThread; this.device = builder.device; - ZooModel model = builder.model.getModel(device); - - predictor = model.newPredictor(); - modelName = builder.model.getId(); - logModelMetric = Boolean.parseBoolean(model.getProperty("log_model_metric")); - metricsAggregation = Integer.parseInt(model.getProperty("metrics_aggregation", "1000")); + threadType = builder.workerPoolConfig.newThread(device); } /** {@inheritDoc} */ @@ -82,19 +71,12 @@ public void run() { List req = null; String errorMessage = "Worker shutting down"; try { - if (logModelMetric) { - Metrics metrics = new Metrics(); - metrics.setLimit(metricsAggregation); - metrics.setOnLimit( - (m, s) -> MODEL_METRIC.info("{}-{}", modelName, m.percentile(s, 50))); - predictor.setMetrics(metrics); - } while (isRunning() && !aggregator.isFinished()) { req = aggregator.getRequest(); if (req != null && !req.isEmpty()) { state = WorkerState.WORKER_BUSY; try { - List reply = predictor.batchPredict(req); + List reply = threadType.run(req); aggregator.sendResponse(reply); } catch (TranslateException e) { logger.warn("Failed to predict", e); @@ -181,11 +163,11 @@ public void shutdown(WorkerState state) { aggregator.sendError(e); } logger.info("shutdown temporary worker: {}", workerName); - predictor.close(); + threadType.close(); } - private String buildWorkerName(ModelInfo model) { - String modelId = model.getId(); + private String buildWorkerName(WorkerPoolConfig wpc) { + String modelId = wpc.getId(); if (modelId.length() > 25) { modelId = modelId.substring(0, 25); } @@ -222,26 +204,26 @@ public boolean isFixPoolThread() { /** * Creates a builder to build a {@code WorkerThread}. 
* - * @param the model input class - * @param the model output class - * @param model the {@code ModelInfo} the thread will be responsible for + * @param the workerPoolConfig input class + * @param the workerPoolConfig output class + * @param wpc the {@link WorkerPoolConfig} the thread will be responsible for * @return a new builder */ - public static Builder builder(ModelInfo model) { - return new Builder<>(model); + public static Builder builder(WorkerPoolConfig wpc) { + return new Builder<>(wpc); } /** A Builder to construct a {@code WorkerThread}. */ public static final class Builder { - private ModelInfo model; + private WorkerPoolConfig workerPoolConfig; private Device device; private BatchAggregator aggregator; private LinkedBlockingDeque> jobQueue; private boolean fixPoolThread; - Builder(ModelInfo model) { - this.model = model; + Builder(WorkerPoolConfig wpc) { + this.workerPoolConfig = wpc; this.fixPoolThread = true; } @@ -292,9 +274,9 @@ public WorkerThread build() { throw new IllegalArgumentException("jobQueue has to be set."); } if (fixPoolThread) { - aggregator = new PermanentBatchAggregator<>(model, jobQueue); + aggregator = new PermanentBatchAggregator<>(workerPoolConfig, jobQueue); } else { - aggregator = new TemporaryBatchAggregator<>(model, jobQueue); + aggregator = new TemporaryBatchAggregator<>(workerPoolConfig, jobQueue); } return new WorkerThread<>(this); } diff --git a/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java b/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java index f569471a6c1..07afd397dc5 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java @@ -38,7 +38,7 @@ public void testFromCriteria() throws IOException { .optModelUrls(modelUrl) .build(); ModelInfo modelInfo = new ModelInfo<>("model", modelUrl, criteria); - wlm.registerModel(modelInfo).initWorkers(null, 1, 2); + wlm.registerWorkerPool(modelInfo).initWorkers(null); 
Input input = new Input(); URL url = new URL("https://resources.djl.ai/images/0.png"); try (InputStream is = url.openStream()) {