Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serving] support load entryPoint with url #566

Merged
merged 1 commit into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions engines/python/setup/djl_python/service_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
from importlib.machinery import SourceFileLoader


class ModelService(object):
Expand All @@ -31,14 +32,20 @@ def invoke_handler(self, function_name, inputs):
def load_model_service(model_dir, entry_point, device_id):
manifest_file = os.path.join(model_dir, "MAR-INF/MANIFEST.json")
if not os.path.exists(manifest_file):
entry_point_file = os.path.join(model_dir, entry_point)
if entry_point_file.endswith(".py"):
entry_point = entry_point[:-3]
if not os.path.exists(entry_point_file):
raise ValueError(
f"entry-point file not found {entry_point_file}.")
if os.path.isabs(entry_point):
if not os.path.exists(entry_point):
raise ValueError(f"entry-point file not found {entry_point}.")
module = SourceFileLoader("model", entry_point).load_module()
else:
if entry_point.endswith(".py"):
entry_point_file = os.path.join(model_dir, entry_point)
entry_point = entry_point[:-3]
if not os.path.exists(entry_point_file):
raise ValueError(
f"entry-point file not found {entry_point_file}.")

module = importlib.import_module(entry_point)

module = importlib.import_module(entry_point)
if module is None:
raise ValueError(
f"Unable to load entry_point {model_dir}/{entry_point}.py")
Expand Down
37 changes: 26 additions & 11 deletions engines/python/src/main/java/ai/djl/python/engine/PyModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.util.DownloadUtils;
import ai.djl.translate.Translator;
import ai.djl.util.Utils;

Expand All @@ -28,11 +29,13 @@
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -157,13 +160,25 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
throw new FileNotFoundException(".py file not found in: " + modelPath);
}
}
} else if (entryPoint.toLowerCase(Locale.ROOT).startsWith("http")) {
logger.info("downloading entryPoint file: {}", entryPoint);
Path modelFile = getDownloadDir().resolve("model.py");
DownloadUtils.download(new URL(entryPoint), modelFile, null);
entryPoint = modelFile.toAbsolutePath().toString();
}
pyEnv.setEntryPoint(entryPoint);

String s3Url = pyEnv.getInitParameters().get("s3url");
if (s3Url != null) {
if (pyEnv.getInitParameters().containsKey("model_id")) {
throw new IllegalArgumentException("model_id and s3url could not both set!");
}

logger.info("S3 url found, start downloading from {}", s3Url);
downloadS3(s3Url);
String downloadDir = getDownloadDir().toString();
downloadS3(s3Url, downloadDir);
// point model_id to download directory
pyEnv.addParameter("model_id", downloadDir);
}

if (pyEnv.isMpiMode()) {
Expand Down Expand Up @@ -312,18 +327,18 @@ private void createAllPyProcesses(int mpiWorkers) {
logger.info("{} model loaded in {} ms.", modelName, duration);
}

private void downloadS3(String url) {
if (pyEnv.getInitParameters().containsKey("model_id")) {
throw new IllegalArgumentException("model_id and s3url could not both set!");
private Path getDownloadDir() throws IOException {
// SageMaker model_dir are readonly, default to use temp directory
Path tmp = Files.createTempDirectory("download").toAbsolutePath();
String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", tmp.toString());
if ("default".equals(downloadDir)) {
downloadDir = modelDir.toAbsolutePath().toString();
}
// TODO: Workaround on SageMaker readonly disk
return Paths.get(downloadDir);
}

private void downloadS3(String url, String downloadDir) {
try {
Path tmp = Files.createTempDirectory("download").toAbsolutePath();
String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", tmp.toString());
if ("default".equals(downloadDir)) {
downloadDir = modelDir.toAbsolutePath().toString();
}
pyEnv.addParameter("model_id", downloadDir);
String[] commands;
if (Files.exists(Paths.get("/opt/djl/bin/s5cmd"))) {
if (!url.endsWith("*")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
engine=Python
#option.entryPoint=djl_python.huggingface
#option.entryPoint=https://raw.githubusercontent.com/deepjavalibrary/djl-serving/master/engines/python/setup/djl_python/huggingface.py
option.model_id=distilbert-base-cased-distilled-squad
option.task=question-answering
6 changes: 3 additions & 3 deletions tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def build_ds_handler_model(model):
)
options = ds_handler_list[model]
options["engine"] = "DeepSpeed"
options["option.entryPoint"] = "djl_python.deepspeed"
# options["option.entryPoint"] = "djl_python.deepspeed"
write_properties(options)


Expand Down Expand Up @@ -241,7 +241,7 @@ def build_ft_handler_model(model):
)
options = ft_handler_list[model]
options["engine"] = "FasterTransformer"
options["option.entryPoint"] = "djl_python.fastertransformer"
# options["option.entryPoint"] = "djl_python.fastertransformer"
write_properties(options)


Expand Down Expand Up @@ -275,7 +275,7 @@ def builder_ft_handler_aot_model(model):
)
options = ft_model_list[model]
options["engine"] = "FasterTransformer"
options["entryPoint"] = "djl_python.fastertransformer"
# options["entryPoint"] = "djl_python.fastertransformer"
options["option.save_mp_checkpoint_path"] = "/opt/ml/model/partition-test"
write_properties(options)

Expand Down