From 294df07ea6c97512e9dfb898d9e71fdb4b04007c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Thu, 7 Sep 2023 10:05:33 +0800 Subject: [PATCH] [Paddle Inference]support miniGPT4's second part dy2st (#6905) * support miniGPT4 * remove some useless code * move some function to modeling.py commit * 1->self.config.bos_token_id * remove useless comment * huifu * move prepare_input_ids_for_generation to modeling * LlamaForMiniGPT4InferenceModel * use model_type --- llm/predictor.py | 26 ++-- .../transformers/generation_utils.py | 24 ++- .../transformers/llama/modeling.py | 146 +++++++++++++++++- 3 files changed, 184 insertions(+), 12 deletions(-) diff --git a/llm/predictor.py b/llm/predictor.py index 980c8c884b80..ed28b4be23ef 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -74,11 +74,13 @@ class PredictorArgument: inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"}) batch_size: int = field(default=1, metadata={"help": "The batch size of data."}) max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."}) - benchmark: bool = field( - default=False, - metadata={ - "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. " - }, + benchmark: bool = ( + field( + default=False, + metadata={ + "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. " + }, + ), ) @@ -573,13 +575,19 @@ def create_predictor( # TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) if "llama" in config.architectures[0].lower(): - from paddlenlp.experimental.transformers import ( - LlamaForCausalLMInferenceModel, - ) + if model_args.model_type == "llama-img2txt": + # we use llama for img2txt. + from paddlenlp.experimental.transformers import ( + LlamaForMiniGPT4InferenceModel as LlamaInferenceModel, + ) + else: + from paddlenlp.experimental.transformers import ( + LlamaForCausalLMInferenceModel as LlamaInferenceModel, + ) config.tensor_parallel_degree = tensor_parallel_degree config.tensor_parallel_rank = tensor_parallel_rank - model = LlamaForCausalLMInferenceModel.from_pretrained( + model = LlamaInferenceModel.from_pretrained( predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype ) model.eval() diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 4501a16336b5..0474dbfb1f39 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -85,6 +85,17 @@ def to_static(self, output_path: str, config: dict): model, output_path, skip_prune_program=True ) # Note(Zhengzekang): If we prune program it may cause some inference error. + @staticmethod + def prepare_input_ids_for_generation(bos_token_id, encoder_output=None): + batch_size = 1 + seq_len = 1 + if bos_token_id is None: + raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.") + if encoder_output is not None: + batch_size = encoder_output.shape[0] + seq_len = encoder_output.shape[1] + return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id + @paddle.no_grad() def generate( self, @@ -109,6 +120,7 @@ def generate( pre_ids=None, stop_nums=None, cache_kvs=[], + inputs_embeds=None, **model_kwargs, ): @@ -136,6 +148,7 @@ def generate( top_p=top_p, cache_kvs=cache_kvs, temperature=temperature, + inputs_embeds=inputs_embeds, **model_kwargs, ) return ret @@ -215,17 +228,23 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e def sample( self, - input_ids, - eos_token_id, + input_ids=None, + eos_token_id=None, cache_kvs=[], top_p=None, temperature=None, + inputs_embeds=None, **model_kwargs, ): step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1) batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1) + # let inputs_embeds enter into model_kwargs. + # because the code below directly use the model_kwargs as a parameter without using inputs_embeds. + model_kwargs["inputs_embeds"] = inputs_embeds + def _forward_(**args): + # cache_kvs is never empty because it is passed as a parameter in def sample. model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args) return self(**model_inputs) @@ -297,6 +316,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): ) step_idx_ori += 1 encoder_output = outputs + # gives it a value, means we will entered into decoder phase. model_kwargs["cache"] = 0 # decoder diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index 3a3a6c5549d1..49dbe08c37eb 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -33,7 +33,7 @@ ) from paddlenlp.transformers.model_utils import register_base_model -__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"] +__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForMiniGPT4InferenceModel"] class FusedLlamaRMSNorm(nn.Layer): @@ -149,6 +149,18 @@ def remove_padding(self, input_ids, seq_lens_this_time): ) return ids_remove_padding, padding_offset, cum_offsets + # This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py + @staticmethod + def prepare_input_ids_for_generation(bos_token_id, encoder_output=None): + batch_size = 1 + seq_len = 1 + if bos_token_id is None: + raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.") + if encoder_output is not None: + batch_size = encoder_output.shape[0] + seq_len = encoder_output.shape[1] + return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id + def forward( self, input_ids=None, @@ -165,9 +177,24 @@ def forward( return_dict=False, **kwargs, ): + # kwargs["cache"] is used used to distinguish between encoder and decoder phase. past_key_values = kwargs.get("cache", None) is_decoder = past_key_values is not None + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # genereate a fake input_ids according to inputs_embeds + # this is usually occurred in img2txt multimodal model when first enter into this forward function. + if input_ids is None and inputs_embeds is not None: + input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds) + if inputs_embeds is not None: + batch, seq_len, hidden_dim = inputs_embeds.shape + # merge batch and seq_len dimension. + inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim]) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -345,14 +372,19 @@ def prepare_inputs_for_generation( position_ids = kwargs.get("position_ids", None) attention_mask = kwargs.get("attention_mask", None) cache = kwargs.get("cache", None) + inputs_embeds = kwargs.get("inputs_embeds", None) if cache is not None: input_ids = tgt_ids position_ids = tgt_pos attention_mask = (tgt_generation_mask - 1) * 1e4 + # make inputs_embeds be none in decoder phase. + # in forward function, it will be assigned according to input_ids. + inputs_embeds = None else: attention_mask = (attention_mask - 1) * 1e4 model_inputs = { "input_ids": input_ids, + "inputs_embeds": inputs_embeds, "position_ids": position_ids, "attention_mask": attention_mask, "cache_kvs": cache_kvs, @@ -432,3 +464,115 @@ def set_state_dict(self, state_dict): if "lm_head.weight" in state_dict: self.lm_head.weight.set_value(state_dict["lm_head.weight"]) self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + + +class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel): + """ + This class is 99% like LlamaForCausalLMInferenceModel. + Used only for miniGPT4's second part. + """ + + # This function corresponds to miniGPT4's second part, only used in miniGPT4. + @paddle.no_grad() + def generate_text_with_image_features( + self, + image_features: paddle.Tensor, + first_input_ids: paddle.Tensor, + second_input_ids: paddle.Tensor, + attention_mask: paddle.Tensor, + position_ids=None, + penalty_score=None, + frequency_score=None, + presence_score=None, + min_length=None, + max_length=None, + temperature=None, + top_p=None, + eos_token_id=None, + seq_len_encoder=None, + seq_len_decoder=None, + step_idx=None, + stop_flags=None, + tgt_ids=None, + tgt_pos=None, + tgt_generation_mask=None, + pre_ids=None, + stop_nums=None, + cache_kvs=[], + inputs_embeds=None, + **generate_kwargs + ) -> paddle.Tensor: + + first_embeds = self.llama.embed_tokens(first_input_ids) + second_embeds = self.llama.embed_tokens(second_input_ids) + image_features = paddle.cast(image_features, dtype=first_embeds.dtype) + inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1) + + outputs = self.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + penalty_score=penalty_score, + frequency_score=frequency_score, + presence_score=presence_score, + min_length=min_length, + max_length=max_length, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + seq_len_encoder=seq_len_encoder, + seq_len_decoder=seq_len_decoder, + step_idx=step_idx, + stop_flags=stop_flags, + tgt_ids=tgt_ids, + tgt_pos=tgt_pos, + tgt_generation_mask=tgt_generation_mask, + pre_ids=pre_ids, + stop_nums=stop_nums, + cache_kvs=cache_kvs, + ) + return outputs + + # rewrite to_static function in generation_utils.py + def to_static(self, output_path: str, config: dict): + dtype = config.get("dtype", paddle.get_default_dtype()) + cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None)) + input_spec = [ + paddle.static.InputSpec( + shape=[None, None, None], dtype="float32", name="image_features" + ), # image_features + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"), # first_input_ids + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"), # second_input_ids + paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p + paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id + paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder + paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx + paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos + paddle.static.InputSpec( + shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask" + ), # tgt_generation_mask + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids + paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums + [ + paddle.static.InputSpec( + shape=shape, + dtype=dtype, + name="cache_kvs_{}".format(i), + ) + for i, shape in enumerate(cache_kvs_shapes) + ], # cache_kvs + ] + + model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec) + paddle.jit.save(model, output_path)