From 0c2aece54652d11c0b4f4269b68cf068bd278cbb Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Wed, 30 Aug 2023 11:13:46 +0800
Subject: [PATCH] fix format

---
 llm/predictor.py                      |  2 +-
 .../transformers/chatglm/modeling.py  | 25 ++++++-------
 .../transformers/generation_utils.py  | 10 +++---
 .../transformers/llama/modeling.py    | 35 ++++++++++---------
 4 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/llm/predictor.py b/llm/predictor.py
index 218c07bbdc38..8b5f3dfc58eb 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -417,7 +417,7 @@ def __init__(
 
         self.cache_kvs = [
             paddle.zeros(shape, dtype=dtype)
-            for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size)
+            for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size, config.max_length)
         ]
         self.pre_ids = paddle.full([config.max_batch_size, config.max_length], -1, dtype="int64")
         if "chatglm" in self.architectures:
diff --git a/paddlenlp/experimental/transformers/chatglm/modeling.py b/paddlenlp/experimental/transformers/chatglm/modeling.py
index 80d826e81039..9ee16c21fe8f 100644
--- a/paddlenlp/experimental/transformers/chatglm/modeling.py
+++ b/paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -294,18 +294,19 @@ def forward(
         new_cache = [None]
         hidden_states = self.input_layernorm(hidden_states)
 
-        hidden_states, new_cache = self.transformer_block(
-            input_ids,
-            hidden_states,
-            cum_offsets=cum_offsets,
-            padding_offset=padding_offset,
-            attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
-            caches=cache_kvs,
-            rotary_embs=paddle.cast(rotary_embeds, "float32"),
-            rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
-            seq_lens=seq_lens,
-            time_step=time_step,
-        )
+        with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+            hidden_states, new_cache = self.transformer_block(
+                input_ids,
+                hidden_states,
+                cum_offsets=cum_offsets,
+                padding_offset=padding_offset,
+                attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
+                caches=cache_kvs,
+                rotary_embs=paddle.cast(rotary_embeds, "float32"),
+                rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
+                seq_lens=seq_lens,
+                time_step=time_step,
+            )
         return (hidden_states, new_cache)
 
     @paddle.no_grad()
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index bc2144134d84..699b5f14f560 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -162,7 +162,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
         if cache is None:
             # encoder's generation
             model_kwargs["tgt_ids"] = paddle.where(just_decoder, model_kwargs["tgt_ids"], next_tokens)
-            if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
+            if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:
                 tgt_pos = model_kwargs["tgt_pos"]
                 new_position_id = tgt_pos[:, 0, :].clone()
                 new_block_id = tgt_pos[:, 1, :].clone()
@@ -182,7 +182,7 @@
                 )
         else:
             model_kwargs["tgt_ids"] = next_tokens
-            if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
+            if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:
                 tgt_pos = model_kwargs["tgt_pos"]
                 new_position_id = tgt_pos[:, 0, :].clone()
                 new_block_id = tgt_pos[:, 1, :].clone()
@@ -261,9 +261,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
            # compute next_tokens, use paddle.top_p_sampling
            logits = logits / temperature

-           _, next_tokens = top_p_sampling(probs, top_p)
+           _, next_tokens = top_p_sampling(probs, top_p, -1)

-           if self.model.config.tensor_parallel_degree > 1:
+           if self.config.tensor_parallel_degree > 1:
                paddle.distributed.broadcast(next_tokens, 0)

            model_kwargs = self.update_model_kwargs_for_generation(
@@ -275,7 +275,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                    batch_idx,
                    step_idx_ori,
                    "real_time_save.temp_ids",
-                   self.model.config.tensor_parallel_rank,
+                   self.config.tensor_parallel_rank,
                )

            return next_tokens, model_kwargs
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 469f6b8b9141..3a3a6c5549d1 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -25,7 +25,7 @@
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
 )
-from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM, LlamaPretrainedModel
+from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -200,18 +200,19 @@ def forward(
 
         new_rope = fused_get_rotary_embedding(input_ids, position_ids, self.head_dim_shape_tensor, 0, True)
 
-        hidden_states, _ = self.transformer_block(
-            input_ids,
-            hidden_states,
-            cum_offsets=cum_offsets,
-            padding_offset=padding_offset,
-            attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
-            caches=cache_kvs,
-            seq_lens=seq_lens,
-            rotary_embs=new_rope,
-            rotary_emb_dims=1,
-            time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
-        )
+        with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+            hidden_states, _ = self.transformer_block(
+                input_ids,
+                hidden_states,
+                cum_offsets=cum_offsets,
+                padding_offset=padding_offset,
+                attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
+                caches=cache_kvs,
+                seq_lens=seq_lens,
+                rotary_embs=new_rope,
+                rotary_emb_dims=1,
+                time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
+            )
         hidden_states = self.norm(hidden_states)
 
         if output_hidden_states:
@@ -289,7 +290,7 @@ def set_state_dict(self, state_dict):
         )
 
 
-class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaForCausalLM):
+class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaPretrainedModel):
     """
     Dynamic Batching for LLaMA Model with pretraining tasks on top.
     """
@@ -298,7 +299,7 @@ class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaForCausalLM)
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = LlamaInferenceModel(config)
+        self.llama = LlamaInferenceModel(config)
         self.lm_head = LlamaLMHead(config)
 
     @classmethod
@@ -384,7 +385,7 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.model(
+        outputs = self.llama(
             input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -430,4 +431,4 @@ def forward(
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
             self.lm_head.weight.set_value(state_dict["lm_head.weight"])
-        self.model.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+        self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})