Skip to content

Commit

Permalink
fix client tests and format python code (#663)
Browse files: browse the repository at this point in the history
  • Loading branch information
siddvenk authored Apr 25, 2023
1 parent 446a33e commit 8fb9c87
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 12 deletions.
11 changes: 7 additions & 4 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,16 @@ def _parse_properties(self, properties):
self.model_id_or_path = properties.get("model_id") or properties.get(
"model_dir")
self.task = properties.get("task")
self.data_type = get_torch_dtype_from_str(properties.get("dtype", default_dtype()))
self.data_type = get_torch_dtype_from_str(
properties.get("dtype", default_dtype()))
self.max_tokens = int(properties.get("max_tokens", 1024))
self.device = int(os.getenv("LOCAL_RANK", 0))
self.tensor_parallel_degree = int(
properties.get("tensor_parallel_degree", 1))
self.low_cpu_mem_usage = properties.get("low_cpu_mem_usage",
"true").lower() == "true"
self.enable_streaming = properties.get("enable_streaming",
"false").lower() == "true"
"false").lower() == "true"
if properties.get("deepspeed_config_path"):
with open(properties.get("deepspeed_config_path"), "r") as f:
self.ds_config = json.load(f)
Expand Down Expand Up @@ -271,7 +272,8 @@ def inference(self, inputs: Input):
try:
content_type = inputs.get_property("Content-Type")
model_kwargs = {}
if content_type is not None and content_type.startswith("application/json"):
if content_type is not None and content_type.startswith(
"application/json"):
json_input = inputs.get_as_json()
if isinstance(json_input, dict):
input_data = self.format_input_for_task(
Expand All @@ -297,7 +299,8 @@ def inference(self, inputs: Input):
with torch.no_grad():
output_tokens = self.model.generate(
**tokenized_inputs, **model_kwargs)
generated_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
generated_text = self.tokenizer.batch_decode(
output_tokens, skip_special_tokens=True)
outputs.add([{"generated_text": s} for s in generated_text])
return outputs

Expand Down
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def get_data(self, key=None):
return self.get_as_numpy(key)
if content_type == "tensor/npz":
return self.get_as_npz(key)
elif content_type is not None and content_type.startswith("application/json"):
elif content_type is not None and content_type.startswith(
"application/json"):
return self.get_as_json(key)
elif content_type is not None and content_type.startswith("text/"):
return self.get_as_string(key)
Expand Down
7 changes: 5 additions & 2 deletions engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,15 @@ def _has_met_stopping_criteria(not_eos_token_ids, current_token_count,
def _validate_inputs(model, inputs):
if not model.config.architectures:
## do best effort validation as there is no simple way to cover all the cases
logging.warning(f"Model config does not contain architectures field. Supported architectures: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}")
logging.warning(
f"Model config does not contain architectures field. Supported architectures: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}"
)
model_arch_supported = True
else:
model_arch_list = model.config.architectures
model_arch_supported = any(
model_arch.endswith(StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES)
model_arch.endswith(
StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES)
for model_arch in model_arch_list)
if not model_arch_supported:
assert False, f"model archs: {model_arch_list} is not in supported list: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}"
Expand Down
5 changes: 3 additions & 2 deletions serving/docker/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def download_model_from_s3(self):
if not model_id or not model_id.startswith("s3://"):
return

download_dir = os.environ.get("SERVING_DOWNLOAD_DIR",
get_download_dir(properties_manager.properties_dir, 'model'))
download_dir = os.environ.get(
"SERVING_DOWNLOAD_DIR",
get_download_dir(properties_manager.properties_dir, 'model'))

s3url = model_id
if Path("/opt/djl/bin/s5cmd").is_file():
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,19 @@
"stream_output": True,
},
"no-code/nomic-ai/gpt4all-j": {
"max_memory_per_gpu": [10.0, 10.0, 11.0, 12.0],
"max_memory_per_gpu": [10.0, 12.0],
"batch_size": [1, 4],
"seq_length": [16, 32],
"worker": 1
},
"no-code/databricks/dolly-v2-7b": {
"max_memory_per_gpu": [10.0, 10.0, 12.0, 12.0],
"max_memory_per_gpu": [10.0, 12.0],
"batch_size": [1, 4],
"seq_length": [16, 32],
"worker": 2,
},
"no-code/google/flan-t5-xl": {
"max_memory_per_gpu": [7.0, 7.0, 7.0, 7.0],
"max_memory_per_gpu": [7.0, 7.0],
"batch_size": [1, 4],
"seq_length": [16, 32],
"worker": 2
Expand Down

0 comments on commit 8fb9c87

Please sign in to comment.