[Handler] add dynamic batching to transformers neuronx (deepjavalibra…
Qing Lan authored Sep 14, 2023
1 parent 2f4e987 commit 296d759
Showing 3 changed files with 125 additions and 58 deletions.
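Context for the main change: before this commit, the handler required each request to carry exactly `batch_size` prompts; with dynamic batching, the new `parse_input` merges several concurrent requests (each with one or more prompts and identical `parameters`) into a single flat prompt list, runs one forward pass, and then splits the generated text back out per request. The following is a minimal standalone sketch of that merge/scatter bookkeeping. It mirrors the `input_size`/offset logic added in this commit but deliberately avoids the real `djl_python` `Input`/`Output` API; `merge_requests`, `scatter_results`, `serve_batch`, and `run_model` are hypothetical helpers introduced only for illustration.

# Standalone sketch of the dynamic-batching bookkeeping added in this commit.
# run_model() is a hypothetical stand-in for the real Neuron model call.
from typing import Callable

def merge_requests(requests: list[dict]):
    """Flatten per-request prompts into one list, tracking each request's size."""
    input_data, input_size = [], []
    shared_params = requests[0].get("parameters", {})
    for req in requests:
        if req.get("parameters", {}) != shared_params:
            # Mirrors the commit: all batched requests must share the same parameters.
            raise ValueError("all input batches must have the same parameters")
        prompts = req["inputs"]
        prompts = prompts if isinstance(prompts, list) else [prompts]
        input_data.extend(prompts)
        input_size.append(len(prompts))
    return input_data, input_size, shared_params

def scatter_results(predictions: list[str], input_size: list[int]):
    """Split the flat prediction list back into per-request chunks."""
    results, offset = [], 0
    for size in input_size:
        results.append(predictions[offset:offset + size])
        offset += size
    return results

def serve_batch(requests: list[dict],
                run_model: Callable[[list[str], dict], list[str]]):
    input_data, input_size, params = merge_requests(requests)
    predictions = run_model(input_data, params)  # one forward pass for all requests
    return scatter_results(predictions, input_size)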
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/stable_diffusion_inf2.py
@@ -276,7 +276,7 @@ def load_compiled(self, saved_dir):
unet_filename = os.path.join(saved_dir, 'unet.pt')
self.pipeline.unet.unetwrap = torch.jit.load(unet_filename)

def infer(self, inputs: Input):
def inference(self, inputs: Input):
try:
content_type = inputs.get_property("Content-Type")
if content_type == "application/json":
179 changes: 123 additions & 56 deletions engines/python/setup/djl_python/transformers-neuronx.py
@@ -13,7 +13,6 @@
import tempfile
import os
import logging
import torch

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers_neuronx import dtypes
@@ -26,6 +25,7 @@
from transformers_neuronx.bloom.model import BloomForSampling
from transformers_neuronx.module import save_pretrained_split
from djl_python import Input, Output
from djl_python.encode_decode import decode, encode
from djl_python.stable_diffusion_inf2 import StableDiffusionService
from djl_python.streaming_utils import StreamingUtils

@@ -213,62 +213,129 @@ def initialize(self, properties):
self.model)
self.initialized = True

def infer(self, inputs):
try:
input_map = inputs.get_as_json()
input_text = input_map.pop("inputs", input_map)
parameters = input_map.pop("parameters", {})
if isinstance(input_text, str):
input_text = [input_text]
if len(input_text) != self.batch_size:
raise ValueError(
f"{self.batch_size} batch size not equal to {len(input_text)} prompt size"
)
outputs = Output()
model_kwargs = {}

if self.enable_streaming:
outputs.add_property("content-type", "application/jsonlines")
if self.enable_streaming == "huggingface":
outputs.add_stream_content(
StreamingUtils.use_hf_default_streamer(
self.model, self.tokenizer, input_text, None,
**model_kwargs))
def parse_input(self, inputs):
input_data = []
input_size = []
parameters = []
errors = {}
batch = inputs.get_batches()
first = True
for i, item in enumerate(batch):
try:
content_type = item.get_property("Content-Type")
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
if first:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
stream_generator = StreamingUtils.get_stream_generator(
"transformers-neuronx")
model_kwargs["engine"] = "transformers-neuronx"
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer,
input_text, "cpu", **model_kwargs))
return outputs

encoded_inputs = self.tokenizer.batch_encode_plus(
input_text, return_tensors="pt", padding=True)
use_sample = parameters.pop("use_sample", None)
if use_sample:
# TODO: Watch transformer-neuronx release for fix on gpt-neox generate functionality
output_tokens = self.model.sample(
encoded_inputs.input_ids,
sequence_length=self.n_positions,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
**parameters)
param = input_map.pop("parameters", {})
if parameters[0] != param:
logging.warning(
f"expected param: {parameters}, actual: {param}")
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
if isinstance(_inputs, list):
input_data.extend(_inputs)
input_size.append(len(_inputs))
else:
input_data.append(_inputs)
input_size.append(1)
except Exception as e: # pylint: disable=broad-except
logging.exception(f"Parse input failed: {i}")
errors[i] = str(e)

return input_data, input_size, parameters, errors, batch

def inference(self, inputs):
input_data, input_size, parameters, errors, batch = self.parse_input(
inputs)
parameters = parameters[0]

outputs = Output()
model_kwargs = {}

prompt_size = len(input_data)
if prompt_size > self.batch_size:
raise ValueError(
f"Batch size {prompt_size} beyond the max_batch size the model can support {self.batch_size}"
)

for i in range(prompt_size, self.batch_size):
input_data.append(self.tokenizer.eos_token)

# clean KV cache
self.model.reset_generation()
if self.enable_streaming:
if batch > 1:
raise NotImplementedError(
"Dynamic batch not supported for generic streaming")
outputs.add_property("content-type", "application/jsonlines")
if self.enable_streaming == "huggingface":
outputs.add_stream_content(
StreamingUtils.use_hf_default_streamer(
self.model, self.tokenizer, input_data, None,
**model_kwargs))
else:
output_tokens = self.model.generate(
input_ids=encoded_inputs.input_ids,
attention_mask=encoded_inputs.attention_mask,
**parameters)
generated_text = self.tokenizer.batch_decode(
output_tokens, skip_special_tokens=True)

return Output().add([{
"generated_text": s
} for s in generated_text])

except Exception as e:
logging.exception("TransformerNeuronX inference failed")
outputs = Output().error((str(e)))
stream_generator = StreamingUtils.get_stream_generator(
"transformers-neuronx")
model_kwargs["engine"] = "transformers-neuronx"
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer, input_data,
"cpu", **model_kwargs))
return outputs

encoded_inputs = self.tokenizer.batch_encode_plus(input_data,
return_tensors="pt",
padding=True)
use_sample = parameters.pop("use_sample", None)
if use_sample:
# TODO: Watch transformer-neuronx release for fix on gpt-neox generate functionality
output_tokens = self.model.sample(
encoded_inputs.input_ids,
sequence_length=self.n_positions,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
**parameters)
else:
output_tokens = self.model.generate(
input_ids=encoded_inputs.input_ids,
attention_mask=encoded_inputs.attention_mask,
**parameters)
prediction = self.tokenizer.batch_decode(output_tokens,
skip_special_tokens=True)

# trim the input based on the actual size
prediction = prediction[:prompt_size]
prediction = [{"generated_text": s} for s in prediction]

offset = 0
for i, item in enumerate(batch):
content_type = item.get_property("Content-Type")
accept = item.get_property("Accept")
if not accept:
content_type = content_type if content_type else "application/json"
accept = content_type if content_type.startswith(
"tensor/") else "application/json"
elif "*/*" in accept:
accept = "application/json"

err = errors.get(i)
if err:
encode(outputs,
err,
accept,
key=inputs.get_content().key_at(i))
else:
encode(outputs,
prediction[offset:offset + input_size[i]],
accept,
key=inputs.get_content().key_at(i))
offset += input_size[i]

outputs.add_property("content-type", "application/json")

return outputs


@@ -286,4 +353,4 @@ def handle(inputs: Input):
# Model server makes an empty call to warm up the model on startup
return None

return _service.infer(inputs)
return _service.inference(inputs)
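One detail worth noting about the new `inference` path above: the compiled Neuron model still runs at a fixed batch size, so the merged prompt list is padded up to `self.batch_size` with the tokenizer's EOS token before generation, and the decoded output is trimmed back to the real prompt count afterwards. A hedged standalone sketch of that padding/trimming step follows; `pad_generate_trim`, `compiled_batch_size`, `eos_token`, and `generate_fn` are hypothetical names standing in for `self.batch_size`, `self.tokenizer.eos_token`, and the actual model call.

# Sketch only: the names below are illustrative stand-ins, not the handler's API.
def pad_generate_trim(prompts, compiled_batch_size, eos_token, generate_fn):
    prompt_count = len(prompts)
    if prompt_count > compiled_batch_size:
        raise ValueError(
            f"batch of {prompt_count} prompts exceeds the compiled batch size "
            f"{compiled_batch_size}")
    # Pad with EOS tokens so the batch matches the size the model was compiled for.
    padded = prompts + [eos_token] * (compiled_batch_size - prompt_count)
    decoded = generate_fn(padded)   # assumed to return one string per padded slot
    return decoded[:prompt_count]   # drop the padding slots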
2 changes: 1 addition & 1 deletion tests/integration/llm/client.py
@@ -607,6 +607,7 @@ def test_handler(model, model_spec):
for memory in memory_usage:
assert float(memory) / 1024.0 < spec["max_memory_per_gpu"][i]


def test_vllm_handler(model, model_spec):
if model not in model_spec:
raise ValueError(
@@ -629,7 +630,6 @@ def test_vllm_handler(model, model_spec):
) <= seq_length, "generated more tokens than max_new_tokens"



def test_ds_raw_model(model, model_spec):
if model not in model_spec:
raise ValueError(
