From 296d7599c9b6792d1db5e385da3b7f070d2e6a46 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Wed, 13 Sep 2023 22:41:19 -0700 Subject: [PATCH] [Handler] add dynamic batching to transformers neuronx (#1076) --- .../setup/djl_python/stable_diffusion_inf2.py | 2 +- .../setup/djl_python/transformers-neuronx.py | 179 ++++++++++++------ tests/integration/llm/client.py | 2 +- 3 files changed, 125 insertions(+), 58 deletions(-) diff --git a/engines/python/setup/djl_python/stable_diffusion_inf2.py b/engines/python/setup/djl_python/stable_diffusion_inf2.py index 34dec6285de..6b78e903765 100644 --- a/engines/python/setup/djl_python/stable_diffusion_inf2.py +++ b/engines/python/setup/djl_python/stable_diffusion_inf2.py @@ -276,7 +276,7 @@ def load_compiled(self, saved_dir): unet_filename = os.path.join(saved_dir, 'unet.pt') self.pipeline.unet.unetwrap = torch.jit.load(unet_filename) - def infer(self, inputs: Input): + def inference(self, inputs: Input): try: content_type = inputs.get_property("Content-Type") if content_type == "application/json": diff --git a/engines/python/setup/djl_python/transformers-neuronx.py b/engines/python/setup/djl_python/transformers-neuronx.py index 3f0c0433912..7646da8b2bf 100644 --- a/engines/python/setup/djl_python/transformers-neuronx.py +++ b/engines/python/setup/djl_python/transformers-neuronx.py @@ -13,7 +13,6 @@ import tempfile import os import logging -import torch from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer from transformers_neuronx import dtypes @@ -26,6 +25,7 @@ from transformers_neuronx.bloom.model import BloomForSampling from transformers_neuronx.module import save_pretrained_split from djl_python import Input, Output +from djl_python.encode_decode import decode, encode from djl_python.stable_diffusion_inf2 import StableDiffusionService from djl_python.streaming_utils import StreamingUtils @@ -213,62 +213,129 @@ def initialize(self, properties): self.model) self.initialized = True - def infer(self, inputs): - try: - input_map = inputs.get_as_json() - input_text = input_map.pop("inputs", input_map) - parameters = input_map.pop("parameters", {}) - if isinstance(input_text, str): - input_text = [input_text] - if len(input_text) != self.batch_size: - raise ValueError( - f"{self.batch_size} batch size not equal to {len(input_text)} prompt size" - ) - outputs = Output() - model_kwargs = {} - - if self.enable_streaming: - outputs.add_property("content-type", "application/jsonlines") - if self.enable_streaming == "huggingface": - outputs.add_stream_content( - StreamingUtils.use_hf_default_streamer( - self.model, self.tokenizer, input_text, None, - **model_kwargs)) + def parse_input(self, inputs): + input_data = [] + input_size = [] + parameters = [] + errors = {} + batch = inputs.get_batches() + first = True + for i, item in enumerate(batch): + try: + content_type = item.get_property("Content-Type") + input_map = decode(item, content_type) + _inputs = input_map.pop("inputs", input_map) + if first: + parameters.append(input_map.pop("parameters", {})) + first = False else: - stream_generator = StreamingUtils.get_stream_generator( - "transformers-neuronx") - model_kwargs["engine"] = "transformers-neuronx" - outputs.add_stream_content( - stream_generator(self.model, self.tokenizer, - input_text, "cpu", **model_kwargs)) - return outputs - - encoded_inputs = self.tokenizer.batch_encode_plus( - input_text, return_tensors="pt", padding=True) - use_sample = parameters.pop("use_sample", None) - if use_sample: - # TODO: Watch transformer-neuronx 
release for fix on gpt-neox generate functionality - output_tokens = self.model.sample( - encoded_inputs.input_ids, - sequence_length=self.n_positions, - pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id, - **parameters) + param = input_map.pop("parameters", {}) + if parameters[0] != param: + logging.warning( + f"expected param: {parameters}, actual: {param}") + raise ValueError( + "In order to enable dynamic batching, all input batches must have the same parameters" + ) + if isinstance(_inputs, list): + input_data.extend(_inputs) + input_size.append(len(_inputs)) + else: + input_data.append(_inputs) + input_size.append(1) + except Exception as e: # pylint: disable=broad-except + logging.exception(f"Parse input failed: {i}") + errors[i] = str(e) + + return input_data, input_size, parameters, errors, batch + + def inference(self, inputs): + input_data, input_size, parameters, errors, batch = self.parse_input( + inputs) + parameters = parameters[0] + + outputs = Output() + model_kwargs = {} + + prompt_size = len(input_data) + if prompt_size > self.batch_size: + raise ValueError( + f"Batch size {prompt_size} beyond the max_batch size the model can support {self.batch_size}" + ) + + for i in range(prompt_size, self.batch_size): + input_data.append(self.tokenizer.eos_token) + + # clean KV cache + self.model.reset_generation() + if self.enable_streaming: + if batch > 1: + raise NotImplementedError( + "Dynamic batch not supported for generic streaming") + outputs.add_property("content-type", "application/jsonlines") + if self.enable_streaming == "huggingface": + outputs.add_stream_content( + StreamingUtils.use_hf_default_streamer( + self.model, self.tokenizer, input_data, None, + **model_kwargs)) else: - output_tokens = self.model.generate( - input_ids=encoded_inputs.input_ids, - attention_mask=encoded_inputs.attention_mask, - **parameters) - generated_text = self.tokenizer.batch_decode( - output_tokens, skip_special_tokens=True) - - return Output().add([{ - "generated_text": s - } for s in generated_text]) - - except Exception as e: - logging.exception("TransformerNeuronX inference failed") - outputs = Output().error((str(e))) + stream_generator = StreamingUtils.get_stream_generator( + "transformers-neuronx") + model_kwargs["engine"] = "transformers-neuronx" + outputs.add_stream_content( + stream_generator(self.model, self.tokenizer, input_data, + "cpu", **model_kwargs)) + return outputs + + encoded_inputs = self.tokenizer.batch_encode_plus(input_data, + return_tensors="pt", + padding=True) + use_sample = parameters.pop("use_sample", None) + if use_sample: + # TODO: Watch transformer-neuronx release for fix on gpt-neox generate functionality + output_tokens = self.model.sample( + encoded_inputs.input_ids, + sequence_length=self.n_positions, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + **parameters) + else: + output_tokens = self.model.generate( + input_ids=encoded_inputs.input_ids, + attention_mask=encoded_inputs.attention_mask, + **parameters) + prediction = self.tokenizer.batch_decode(output_tokens, + skip_special_tokens=True) + + # trim the input based on the actual size + prediction = prediction[:prompt_size] + prediction = [{"generated_text": s} for s in prediction] + + offset = 0 + for i, item in enumerate(batch): + content_type = item.get_property("Content-Type") + accept = item.get_property("Accept") + if not accept: + content_type = content_type if content_type else "application/json" + accept = content_type 
if content_type.startswith( + "tensor/") else "application/json" + elif "*/*" in accept: + accept = "application/json" + + err = errors.get(i) + if err: + encode(outputs, + err, + accept, + key=inputs.get_content().key_at(i)) + else: + encode(outputs, + prediction[offset:offset + input_size[i]], + accept, + key=inputs.get_content().key_at(i)) + offset += input_size[i] + + outputs.add_property("content-type", "application/json") + return outputs @@ -286,4 +353,4 @@ def handle(inputs: Input): # Model server makes an empty call to warm up the model on startup return None - return _service.infer(inputs) + return _service.inference(inputs) diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py index 4c41a5463e3..db3d5029539 100644 --- a/tests/integration/llm/client.py +++ b/tests/integration/llm/client.py @@ -607,6 +607,7 @@ def test_handler(model, model_spec): for memory in memory_usage: assert float(memory) / 1024.0 < spec["max_memory_per_gpu"][i] + def test_vllm_handler(model, model_spec): if model not in model_spec: raise ValueError( @@ -629,7 +630,6 @@ def test_vllm_handler(model, model_spec): ) <= seq_length, "generated more tokens than max_new_tokens" - def test_ds_raw_model(model, model_spec): if model not in model_spec: raise ValueError(
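
The heart of the dynamic-batching change above is bookkeeping: parse_input flattens every client payload into one prompt list and records each request's size, inference pads that list with the tokenizer's eos_token up to the fixed compiled batch size, and after decoding the padded rows are trimmed away and the remaining predictions are sliced back out per request. Below is a minimal, self-contained sketch of just that pad/trim/slice flow under stated assumptions: fake_generate, requests, EOS, and BATCH_SIZE are illustrative stand-ins, not DJL or transformers-neuronx APIs; the real handler uses the Neuron model, tokenizer, and Input/Output objects shown in the patch.

    # Sketch of the pad -> generate -> trim -> slice bookkeeping used by the
    # handler above. All names here are hypothetical stand-ins for illustration.
    EOS = "</s>"        # stands in for tokenizer.eos_token
    BATCH_SIZE = 4      # stands in for the compiled (fixed) neuron batch size

    def fake_generate(prompts):
        # placeholder for model.sample()/model.generate() + tokenizer.batch_decode()
        return [p + " ...generated" for p in prompts]

    # two client requests arrive in one server batch: one single prompt, one list of two
    requests = [["hello"], ["a", "b"]]

    input_data, input_size = [], []
    for payload in requests:            # parse_input: flatten payloads, record sizes
        input_data.extend(payload)
        input_size.append(len(payload))

    prompt_size = len(input_data)
    assert prompt_size <= BATCH_SIZE    # the handler raises ValueError otherwise
    input_data += [EOS] * (BATCH_SIZE - prompt_size)   # pad up to the fixed batch size

    prediction = fake_generate(input_data)[:prompt_size]   # trim the padded rows

    # slice the flat prediction list back into one response per original request
    offset = 0
    responses = []
    for size in input_size:
        responses.append([{"generated_text": s}
                          for s in prediction[offset:offset + size]])
        offset += size

    print(responses)
    # [[{'generated_text': 'hello ...generated'}],
    #  [{'generated_text': 'a ...generated'}, {'generated_text': 'b ...generated'}]]

The padding step exists because the Neuron-compiled model only accepts its fixed batch size: shorter server batches are filled out with eos_token prompts, generated along with the real ones, and then discarded by the prompt_size trim before the per-request slicing.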