From b50432e87be5d92020380ac4da9ac1a0309316eb Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 22 Mar 2023 23:34:02 -0700 Subject: [PATCH] [serving] support load entryPoint with url --- .../python/setup/djl_python/service_loader.py | 21 +++++++---- .../java/ai/djl/python/engine/PyModel.java | 37 +++++++++++++------ .../resources/huggingface/serving.properties | 1 + tests/integration/llm/prepare.py | 6 +-- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/engines/python/setup/djl_python/service_loader.py b/engines/python/setup/djl_python/service_loader.py index bb2bb1379..b46651b6e 100644 --- a/engines/python/setup/djl_python/service_loader.py +++ b/engines/python/setup/djl_python/service_loader.py @@ -15,6 +15,7 @@ import json import logging import os +from importlib.machinery import SourceFileLoader class ModelService(object): @@ -31,14 +32,20 @@ def invoke_handler(self, function_name, inputs): def load_model_service(model_dir, entry_point, device_id): manifest_file = os.path.join(model_dir, "MAR-INF/MANIFEST.json") if not os.path.exists(manifest_file): - entry_point_file = os.path.join(model_dir, entry_point) - if entry_point_file.endswith(".py"): - entry_point = entry_point[:-3] - if not os.path.exists(entry_point_file): - raise ValueError( - f"entry-point file not found {entry_point_file}.") + if os.path.isabs(entry_point): + if not os.path.exists(entry_point): + raise ValueError(f"entry-point file not found {entry_point}.") + module = SourceFileLoader("model", entry_point).load_module() + else: + if entry_point.endswith(".py"): + entry_point_file = os.path.join(model_dir, entry_point) + entry_point = entry_point[:-3] + if not os.path.exists(entry_point_file): + raise ValueError( + f"entry-point file not found {entry_point_file}.") + + module = importlib.import_module(entry_point) - module = importlib.import_module(entry_point) if module is None: raise ValueError( f"Unable to load entry_point {model_dir}/{entry_point}.py") diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java index 3eaeba399..79acbe1c5 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java @@ -19,6 +19,7 @@ import ai.djl.inference.Predictor; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; +import ai.djl.training.util.DownloadUtils; import ai.djl.translate.Translator; import ai.djl.util.Utils; @@ -28,11 +29,13 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -157,13 +160,25 @@ public void load(Path modelPath, String prefix, Map options) throws I throw new FileNotFoundException(".py file not found in: " + modelPath); } } + } else if (entryPoint.toLowerCase(Locale.ROOT).startsWith("http")) { + logger.info("downloading entryPoint file: {}", entryPoint); + Path modelFile = getDownloadDir().resolve("model.py"); + DownloadUtils.download(new URL(entryPoint), modelFile, null); + entryPoint = modelFile.toAbsolutePath().toString(); } pyEnv.setEntryPoint(entryPoint); String s3Url = pyEnv.getInitParameters().get("s3url"); if (s3Url != null) { + if (pyEnv.getInitParameters().containsKey("model_id")) { + throw new IllegalArgumentException("model_id and s3url could not both set!"); + } + logger.info("S3 url found, start downloading from {}", s3Url); - downloadS3(s3Url); + String downloadDir = getDownloadDir().toString(); + downloadS3(s3Url, downloadDir); + // point model_id to download directory + pyEnv.addParameter("model_id", downloadDir); } if (pyEnv.isMpiMode()) { @@ -312,18 +327,18 @@ private void createAllPyProcesses(int mpiWorkers) { logger.info("{} model loaded in {} ms.", modelName, duration); } - private void downloadS3(String url) { - if (pyEnv.getInitParameters().containsKey("model_id")) { - throw new IllegalArgumentException("model_id and s3url could not both set!"); + private Path getDownloadDir() throws IOException { + // SageMaker model_dir are readonly, default to use temp directory + Path tmp = Files.createTempDirectory("download").toAbsolutePath(); + String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", tmp.toString()); + if ("default".equals(downloadDir)) { + downloadDir = modelDir.toAbsolutePath().toString(); } - // TODO: Workaround on SageMaker readonly disk + return Paths.get(downloadDir); + } + + private void downloadS3(String url, String downloadDir) { try { - Path tmp = Files.createTempDirectory("download").toAbsolutePath(); - String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", tmp.toString()); - if ("default".equals(downloadDir)) { - downloadDir = modelDir.toAbsolutePath().toString(); - } - pyEnv.addParameter("model_id", downloadDir); String[] commands; if (Files.exists(Paths.get("/opt/djl/bin/s5cmd"))) { if (!url.endsWith("*")) { diff --git a/engines/python/src/test/resources/huggingface/serving.properties b/engines/python/src/test/resources/huggingface/serving.properties index a9293de8a..8c14ebd27 100644 --- a/engines/python/src/test/resources/huggingface/serving.properties +++ b/engines/python/src/test/resources/huggingface/serving.properties @@ -1,4 +1,5 @@ engine=Python #option.entryPoint=djl_python.huggingface +#option.entryPoint=https://raw.githubusercontent.com/deepjavalibrary/djl-serving/master/engines/python/setup/djl_python/huggingface.py option.model_id=distilbert-base-cased-distilled-squad option.task=question-answering diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py index 4920551c1..c197a00d1 100644 --- a/tests/integration/llm/prepare.py +++ b/tests/integration/llm/prepare.py @@ -199,7 +199,7 @@ def build_ds_handler_model(model): ) options = ds_handler_list[model] options["engine"] = "DeepSpeed" - options["option.entryPoint"] = "djl_python.deepspeed" + # options["option.entryPoint"] = "djl_python.deepspeed" write_properties(options) @@ -241,7 +241,7 @@ def build_ft_handler_model(model): ) options = ft_handler_list[model] options["engine"] = "FasterTransformer" - options["option.entryPoint"] = "djl_python.fastertransformer" + # options["option.entryPoint"] = "djl_python.fastertransformer" write_properties(options) @@ -275,7 +275,7 @@ def builder_ft_handler_aot_model(model): ) options = ft_model_list[model] options["engine"] = "FasterTransformer" - options["entryPoint"] = "djl_python.fastertransformer" + # options["entryPoint"] = "djl_python.fastertransformer" options["option.save_mp_checkpoint_path"] = "/opt/ml/model/partition-test" write_properties(options)