From da3f9ba80c5c153ad0d6a66c98e9d448c66e0c59 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 18 Sep 2023 08:41:06 +0000 Subject: [PATCH 1/9] support batch_size=1 --- llm/predictor.py | 28 +- .../experimental/transformers/__init__.py | 1 + .../transformers/fused_transformer_layers.py | 4 +- .../experimental/transformers/gpt/__init__.py | 15 + .../experimental/transformers/gpt/modeling.py | 524 ++++++++++++++++++ 5 files changed, 559 insertions(+), 13 deletions(-) create mode 100644 paddlenlp/experimental/transformers/gpt/__init__.py create mode 100644 paddlenlp/experimental/transformers/gpt/modeling.py diff --git a/llm/predictor.py b/llm/predictor.py index dc6da46cbe73..ac1832143b40 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -664,8 +664,6 @@ def create_predictor( LlamaForCausalLMInferenceModel as LlamaInferenceModel, ) - config.tensor_parallel_degree = tensor_parallel_degree - config.tensor_parallel_rank = tensor_parallel_rank config.quant_bits = -1 if predictor_args.quant_type.startswith("weight_only_int"): @@ -692,8 +690,6 @@ def create_predictor( BloomForCausalLMInferenceModel, ) - config.tensor_parallel_degree = tensor_parallel_degree - config.tensor_parallel_rank = tensor_parallel_rank model = BloomForCausalLMInferenceModel.from_pretrained( predictor_args.model_name_or_path, config=config, @@ -701,6 +697,18 @@ def create_predictor( ) cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) model.eval() + elif "gpt" in config.architectures[0].lower(): + # raise NotImplementedError() + from paddlenlp.experimental.transformers import ( + GPTForCausalLMInferenceModel, + ) + + model = GPTForCausalLMInferenceModel.from_pretrained( + predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype, + ) + model.eval() + else: + raise ValueError("the `model type` should be one of [llama, chatglm, gpt]") predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) @@ -710,25 +718,21 @@ def create_predictor( ) cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) - predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) elif "chatglm" in config.architectures[0].lower(): from paddlenlp.experimental.transformers import ( ChatGLMForCausalLMInferenceModel, ) - cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape( - config, predictor_args.batch_size - ) - predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) + cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) elif "bloom" in config.architectures[0].lower(): from paddlenlp.experimental.transformers import ( BloomForCausalLMInferenceModel, ) cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) - predictor = StaticInferencePredictor( - predictor_args, cache_kvs_shape=cache_kvs_shape, tokenizer=tokenizer - ) + else: + raise ValueError("the `model type` should be one of [llama, chatglm, bloom]") + predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) else: raise ValueError("the `mode` should be one of [dynamic, static]") return predictor diff --git a/paddlenlp/experimental/transformers/__init__.py b/paddlenlp/experimental/transformers/__init__.py index 007827ebc4b9..42169640d946 100644 --- a/paddlenlp/experimental/transformers/__init__.py +++ b/paddlenlp/experimental/transformers/__init__.py @@ -16,3 +16,4 @@ from .chatglm import * from .fused_transformer_layers import * from .llama import * +from .gpt import * diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index cc587ca376ad..796bf02ec289 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -186,8 +186,10 @@ def __init__( self.norm_type = norm_type if norm_type == "layernorm": self.norm_func = fused_layer_norm - else: + elif norm_type == "rmsnorm": self.norm_func = fused_rms_norm + else: + raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]") self.use_neox_rotary_style = use_neox_rotary_style self._norm_weight_dtype = "float32" if self.norm_type == "layernorm" else self._dtype diff --git a/paddlenlp/experimental/transformers/gpt/__init__.py b/paddlenlp/experimental/transformers/gpt/__init__.py new file mode 100644 index 000000000000..c2a7f656c636 --- /dev/null +++ b/paddlenlp/experimental/transformers/gpt/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling import * diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py new file mode 100644 index 000000000000..9520268fc22f --- /dev/null +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -0,0 +1,524 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import paddle +from paddle import nn +from paddle.distributed import fleet +from paddlenlp_ops import get_padding_offset + +from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedMultiTransformer, +) +from paddlenlp.experimental.transformers.generation_utils import ( + GenerationInferenceModel, +) +from paddlenlp.transformers import GPTConfig, GPTPretrainedModel +from paddlenlp.transformers.gpt.modeling import parallel_matmul +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import register_base_model + +__all__ = ["GPTInferenceModel", "GPTForCausalLMInferenceModel"] + + +class GPTEmbeddingsDyBatch(nn.Layer): + """ + Include embeddings from word and position embeddings. + """ + + def __init__( + self, + config, + ): + super(GPTEmbeddingsDyBatch, self).__init__() + + if config.tensor_parallel_degree > 1: + self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + ) + + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + ) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, position_ids): + inputs_embeddings = self.word_embeddings(input_ids) + + position_ids = paddle.slice(position_ids, axes=[1], starts=[0], ends=[input_ids.shape[0]]) + position_ids = position_ids.squeeze(0) + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeddings + position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +@register_base_model +class GPTInferenceModel(GPTPretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GPTDecoderLayer`] + Args: + config: GPTConfig + """ + + def __init__(self, config: GPTConfig): + super().__init__(config) + self.pad_token_id = config.pad_token_id + self.eos_token_id = config.eos_token_id + self.bos_token_id = config.bos_token_id + self.eol_token_id = config.eol_token_id + + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_layers = config.num_hidden_layers + self.max_position_embeddings = config.max_position_embeddings + + self.bias = paddle.tril( + paddle.ones([1, 1, config.max_position_embeddings, config.max_position_embeddings], dtype="int64") + ) + + self.embeddings = GPTEmbeddingsDyBatch(config) + + # get ring_id + ring_id = -1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + ring_id = model_parallel_group.id + except: + pass + + ln_scale_attrs = [paddle.ParamAttr(name="fusemt.{}.ln_scale".format(i)) for i in range(self.num_layers)] + ln_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ln_bias".format(i)) for i in range(self.num_layers)] + qkv_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.qkv_weight".format(i)) for i in range(self.num_layers)] + qkv_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.qkv_bias".format(i)) for i in range(self.num_layers)] + linear_weight_attrs = [ + paddle.ParamAttr(name="fusemt.{}.linear_weight".format(i)) for i in range(self.num_layers) + ] + linear_bias_attrs = [ + paddle.ParamAttr(name="fusemt.{}.linear_bias".format(i)) for i in range(self.num_layers) + ] + ffn_ln_scale_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers) + ] + ffn_ln_bias_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn_ln_bias".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn1_weight".format(i)) for i in range(self.num_layers) + ] + ffn1_bias_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(self.num_layers) + ] + ffn2_weight_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(self.num_layers) + ] + ffn2_bias_attrs = [ + paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(self.num_layers) + ] + self.transformer_block = FusedMultiTransformer( + config.hidden_size, + config.num_attention_heads, + 4 * config.hidden_size, + activation="gelu", + num_layers=self.num_layers, + nranks=config.tensor_parallel_degree, + ring_id=ring_id, + ln_scale_attrs=ln_scale_attrs, + ln_bias_attrs=ln_bias_attrs, + qkv_weight_attrs=qkv_weight_attrs, + qkv_bias_attrs=qkv_bias_attrs, + linear_weight_attrs=linear_weight_attrs, + linear_bias_attrs=linear_bias_attrs, + ffn_ln_scale_attrs=ffn_ln_scale_attrs, + ffn_ln_bias_attrs=ffn_ln_bias_attrs, + ffn1_weight_attrs=ffn1_weight_attrs, + ffn1_bias_attrs=ffn1_bias_attrs, + ffn2_weight_attrs=ffn2_weight_attrs, + ffn2_bias_attrs=ffn2_bias_attrs, + epsilon=1e-5, + norm_type="layernorm", + ) + self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5) + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def remove_padding(self, input_ids, seq_lens_this_time): + cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) + token_num = paddle.sum(seq_lens_this_time) + ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( + input_ids, cum_offsets_now, token_num, seq_lens_this_time + ) + return ids_remove_padding, padding_offset, cum_offsets + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + cache=None, + cache_kvs=None, + seq_len_encoder=None, + seq_len_decoder=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + **kwargs, + ): + cache = kwargs.get("cache", cache) + is_decoder = cache is not None + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = paddle.shape(input_ids) + input_ids = input_ids.reshape((-1, input_shape[-1])) + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + input_shape = paddle.shape(inputs_embeds)[:-1] + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # embed positions + if attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool) + + if not is_decoder: + ids_remove_padding, padding_offset, cum_offsets = self.remove_padding(input_ids, seq_len_encoder) + else: + ids_remove_padding = input_ids + padding_offset = None + cum_offsets = None + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids=ids_remove_padding, position_ids=position_ids) + + if cache is None: + cache = tuple([None] * self.num_layers) + + # TODO, use registered buffer + length = input_shape[-1] + if is_decoder: + cache_length = paddle.shape(attention_mask)[-1] - 1 + length = length + cache_length + else: + cache_length = 0 + causal_mask = self.bias[:, :, cache_length:length, :length] + + attention_mask = (1.0 - causal_mask) * -1e4 + + # The tensor returned by triu not in static graph. + attention_mask.stop_gradient = True + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + seq_lens = seq_len_decoder if is_decoder else seq_len_encoder + + hidden_states = inputs_embeds + + with paddle.base.framework._stride_in_no_check_dy2st_diff(): + hidden_states, _ = self.transformer_block( + input_ids, + hidden_states, + cum_offsets=cum_offsets, + padding_offset=padding_offset, + attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype), + caches=cache_kvs, + seq_lens=seq_lens, + time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None, + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, None, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + + @paddle.no_grad() + def set_state_dict(self, state_dict): + dtype = paddle.get_default_dtype() + + for k, v in state_dict.items(): + if k.startswith("gpt."): + k = str(k.split("gpt.")[1]) + if k.find("embeddings.word_embeddings.weight") >= 0: + self.embeddings.word_embeddings.weight.set_value(v.astype(dtype)) + elif k.find("embeddings.position_embeddings.weight") >= 0: + self.embeddings.position_embeddings.weight.set_value(v.astype(dtype)) + elif k.find("decoder.norm.weight") >= 0: + self.norm.weight.set_value(v.astype(dtype)) + elif k.find("decoder.norm.bias") >= 0: + self.norm.bias.set_value(v.astype(dtype)) + else: + if not k.startswith("decoder.layers."): + continue + idx = int(k.split(".")[2]) + if k.endswith("norm1.weight"): + self.transformer_block.ln_scales[idx].set_value(v.astype("float32")) + elif k.endswith("norm1.bias"): + self.transformer_block.ln_biases[idx].set_value(v.astype("float32")) + elif k.endswith("self_attn.qkv_proj.weight"): + self.transformer_block.qkv_weights[idx].set_value( + v.reshape( + [ + self.hidden_size, + self.num_attention_heads // self.config.tensor_parallel_degree, + 3, + self.hidden_size // self.num_attention_heads, + ] + ) + .transpose([2, 1, 3, 0]) + .reshape( + [ + self.num_attention_heads // self.config.tensor_parallel_degree * 3 * self.hidden_size // self.num_attention_heads, + self.hidden_size + ] + ) + .astype(dtype) + ) + elif k.endswith("self_attn.qkv_proj.bias"): + self.transformer_block.qkv_biases[idx].set_value( + v.reshape( + [ + self.num_attention_heads // self.config.tensor_parallel_degree, + 3, + self.hidden_size // self.num_attention_heads, + ] + ) + .transpose([1, 0, 2]) + .reshape( + [ + self.num_attention_heads // self.config.tensor_parallel_degree * 3 * self.hidden_size // self.num_attention_heads + ] + ) + .astype(dtype) + ) + elif k.endswith("self_attn.out_proj.weight"): + self.transformer_block.linear_weights[idx].set_value(v.astype(dtype)) + elif k.endswith("self_attn.out_proj.bias"): + self.transformer_block.linear_biases[idx].set_value(v.astype(dtype)) + elif k.endswith("norm2.weight"): + self.transformer_block.ffn_ln_scales[idx].set_value(v.astype("float32")) + elif k.endswith("norm2.bias"): + self.transformer_block.ffn_ln_biases[idx].set_value(v.astype("float32")) + elif k.endswith("linear1.weight"): + self.transformer_block.ffn1_weights[idx].set_value(v.astype(dtype)) + elif k.endswith("linear1.bias"): + self.transformer_block.ffn1_biases[idx].set_value(v.astype(dtype)) + elif k.endswith("linear2.weight"): + self.transformer_block.ffn2_weights[idx].set_value(v.astype(dtype)) + elif k.endswith("linear2.bias"): + self.transformer_block.ffn2_biases[idx].set_value(v.astype(dtype)) + else: + raise ValueError("Unknow weight {}".format(k)) + + +class GPTForCausalLMInferenceModel(GenerationInferenceModel, GPTPretrainedModel): + """ + Dynamic Batching for GPT Model with pretraining tasks on top. + """ + + def __init__(self, config): + super().__init__(config) + self.gpt = GPTInferenceModel(config) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, from_hf_hub: bool = False, subfolder: str | None = None, *args, **kwargs + ): + # TODO: Support safetensors loading. + kwargs["use_safetensors"] = False + return super().from_pretrained(pretrained_model_name_or_path, from_hf_hub, subfolder, *args, **kwargs) + + @classmethod + def get_cache_kvs_shape( + cls, config: GPTConfig, max_batch_size: int = None, max_length: int = None + ) -> list[list[int]]: + """get cache_kvs tensor for gpt model + + Args: + max_batch_size (int): the max batch size + max_length (int | None, optional): the max_length of cache_kvs. Defaults to None. + + Returns: + list[paddle.Tensor]: the list tensor shape for cache + """ + if max_length is None: + max_length = config.max_position_embeddings + + cache_kvs = [] + for _ in range(config.num_hidden_layers): + cache_kvs.append( + [ + 2, + max_batch_size, + config.num_attention_heads // max(config.tensor_parallel_degree, 1), + max_length, + config.hidden_size // config.num_attention_heads, + ] + ) + return cache_kvs + + def prepare_inputs_for_generation( + self, + input_ids, + cache_kvs, + seq_len_encoder, + seq_len_decoder, + tgt_ids, + tgt_pos, + tgt_generation_mask, + **kwargs, + ): + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) + cache = kwargs.get("cache", None) + if cache is not None: + input_ids = tgt_ids + position_ids = tgt_pos + attention_mask = (tgt_generation_mask - 1) * 1e4 + else: + attention_mask = (attention_mask - 1) * 1e4 + + model_inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "cache_kvs": cache_kvs, + "seq_len_encoder": seq_len_encoder, + "seq_len_decoder": seq_len_decoder, + "cache": cache, + } + return model_inputs + + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any( + input_ids == pad_token_id + ).numpy().item() + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids != pad_token_id).astype("int64") + else: + attention_mask = paddle.ones_like(input_ids, dtype="int64") + return paddle.unsqueeze(attention_mask, axis=[1, 2]) + + def forward( + self, + input_ids, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + cache=None, + cache_kvs=None, + seq_len_encoder=None, + seq_len_decoder=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache=cache, + cache_kvs=cache_kvs, + seq_len_encoder=seq_len_encoder, + seq_len_decoder=seq_len_decoder, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = parallel_matmul(hidden_states, self.gpt.embeddings.word_embeddings.weight, tensor_parallel_output=False) + + if not return_dict: + return (logits, outputs[1:]) + + return CausalLMOutputWithCrossAttentions( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + + @paddle.no_grad() + def set_state_dict(self, state_dict): + self.gpt.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 953297b22d32cf755e59d43ab4cd2a0d52604b3f Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 18 Sep 2023 12:20:19 +0000 Subject: [PATCH 2/9] support batch_size > 1 --- .../experimental/transformers/gpt/modeling.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index 9520268fc22f..a15669dd9fe9 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -64,11 +64,17 @@ def __init__( self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, input_ids, position_ids): + def forward(self, input_ids, position_ids, seq_lens): inputs_embeddings = self.word_embeddings(input_ids) - position_ids = paddle.slice(position_ids, axes=[1], starts=[0], ends=[input_ids.shape[0]]) - position_ids = position_ids.squeeze(0) + if position_ids is None: + position_ids = paddle.arange(input_ids.shape[0], dtype=input_ids.dtype) + pre_len = seq_lens[0] + seq_lens = seq_lens[1:] + for seq_len in seq_lens: + position_ids[pre_len:seq_len + pre_len] = position_ids[pre_len:seq_len + pre_len] - pre_len + pre_len += seq_len + position_embeddings = self.position_embeddings(position_ids) embeddings = inputs_embeddings + position_embeddings embeddings = self.dropout(embeddings) @@ -239,8 +245,10 @@ def forward( padding_offset = None cum_offsets = None - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids=ids_remove_padding, position_ids=position_ids) + seq_lens = seq_len_decoder if is_decoder else seq_len_encoder + if not is_decoder: + position_ids = None + inputs_embeds = self.embeddings(input_ids=ids_remove_padding, position_ids=position_ids, seq_lens=seq_lens) if cache is None: cache = tuple([None] * self.num_layers) @@ -262,8 +270,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - seq_lens = seq_len_decoder if is_decoder else seq_len_encoder - hidden_states = inputs_embeds with paddle.base.framework._stride_in_no_check_dy2st_diff(): From 0ebfa8ac10e4b4959b4ebf59f9a2903b13a5b960 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 18 Sep 2023 12:30:11 +0000 Subject: [PATCH 3/9] update --- llm/predictor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llm/predictor.py b/llm/predictor.py index 13ac3dae56c5..80ec5a201e9c 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -698,7 +698,6 @@ def create_predictor( cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) model.eval() elif "gpt" in config.architectures[0].lower(): - # raise NotImplementedError() from paddlenlp.experimental.transformers import ( GPTForCausalLMInferenceModel, ) @@ -708,7 +707,7 @@ def create_predictor( ) model.eval() else: - raise ValueError("the `model type` should be one of [llama, chatglm, gpt]") + raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]") predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) From eb54fc530eea8e9b8a6996cd71914c6e1ad1e381 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 08:21:47 +0000 Subject: [PATCH 4/9] support to_static --- llm/predictor.py | 8 +- llm/utils.py | 29 ++++++ .../transformers/generation_utils.py | 6 +- .../experimental/transformers/gpt/modeling.py | 97 ++----------------- 4 files changed, 47 insertions(+), 93 deletions(-) diff --git a/llm/predictor.py b/llm/predictor.py index 80ec5a201e9c..92243a8c5e61 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -729,8 +729,14 @@ def create_predictor( ) cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) + elif "gpt" in config.architectures[0].lower(): + from paddlenlp.experimental.transformers import ( + GPTForCausalLMInferenceModel, + ) + + cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) else: - raise ValueError("the `model type` should be one of [llama, chatglm, bloom]") + raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]") predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) else: raise ValueError("the `mode` should be one of [dynamic, static]") diff --git a/llm/utils.py b/llm/utils.py index 5d439cec1145..e46447383dd8 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -413,6 +413,35 @@ def dybatch_preprocess( for i in range(len(position_ids)): inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]])) inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos)) + if "gpt" in architectures: + input_ids = [] + position_ids = [] + if isinstance(texts, str): + texts = [texts] + + for text in texts: + tokens = tokenizer( + text, + return_tensors="np", + padding=False, + max_length=src_length, + return_attention_mask=False, + return_token_type_ids=False, + ) + input_ids.append(tokens["input_ids"][0]) + + inputs = {} + pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1] + inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True) + bs = inputs["input_ids"].shape[0] + max_len = max(map(len, input_ids)) + + position_ids = paddle.arange(sum(seq_len), dtype="int64") + pre_len = seq_len[0] + for length in seq_len[1:]: + position_ids[pre_len:length + pre_len] = position_ids[pre_len:length + pre_len] - pre_len + pre_len += length + inputs["position_ids"] = position_ids else: input_ids = [] position_ids = [] diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 54cd637eca11..d24716813528 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -111,10 +111,10 @@ def to_static(self, output_path: str, config: dict): precache_input_spec, ] if self.config["model_type"] and "chatglm" in self.config.model_type: - input_spec[2] = paddle.static.InputSpec( - shape=[None, None, None], dtype="int64", name="position_ids" - ) # position_ids + input_spec[2] = paddle.static.InputSpec(shape=[None, None, None], dtype="int64", name="position_ids") # position_ids input_spec[16] = paddle.static.InputSpec(shape=[None, 2, 1], dtype="int64", name="tgt_pos") # tgt_pos + elif self.config["model_type"] and "gpt" in self.config.model_type: + input_spec[2] = paddle.static.InputSpec(shape=[None], dtype="int64", name="position_ids") # position_ids model = paddle.jit.to_static(self.generate, input_spec=input_spec) paddle.jit.save( model, output_path, skip_prune_program=True diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index a15669dd9fe9..dbd0941f237c 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -25,7 +25,7 @@ GenerationInferenceModel, ) from paddlenlp.transformers import GPTConfig, GPTPretrainedModel -from paddlenlp.transformers.gpt.modeling import parallel_matmul +from paddlenlp.transformers.gpt.modeling import GPTEmbeddings, parallel_matmul from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -35,53 +35,6 @@ __all__ = ["GPTInferenceModel", "GPTForCausalLMInferenceModel"] -class GPTEmbeddingsDyBatch(nn.Layer): - """ - Include embeddings from word and position embeddings. - """ - - def __init__( - self, - config, - ): - super(GPTEmbeddingsDyBatch, self).__init__() - - if config.tensor_parallel_degree > 1: - self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - else: - self.word_embeddings = nn.Embedding( - config.vocab_size, - config.hidden_size, - ) - - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, - config.hidden_size, - ) - - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, position_ids, seq_lens): - inputs_embeddings = self.word_embeddings(input_ids) - - if position_ids is None: - position_ids = paddle.arange(input_ids.shape[0], dtype=input_ids.dtype) - pre_len = seq_lens[0] - seq_lens = seq_lens[1:] - for seq_len in seq_lens: - position_ids[pre_len:seq_len + pre_len] = position_ids[pre_len:seq_len + pre_len] - pre_len - pre_len += seq_len - - position_embeddings = self.position_embeddings(position_ids) - embeddings = inputs_embeddings + position_embeddings - embeddings = self.dropout(embeddings) - - return embeddings - - @register_base_model class GPTInferenceModel(GPTPretrainedModel): """ @@ -103,11 +56,7 @@ def __init__(self, config: GPTConfig): self.num_layers = config.num_hidden_layers self.max_position_embeddings = config.max_position_embeddings - self.bias = paddle.tril( - paddle.ones([1, 1, config.max_position_embeddings, config.max_position_embeddings], dtype="int64") - ) - - self.embeddings = GPTEmbeddingsDyBatch(config) + self.embeddings = GPTEmbeddings(config) # get ring_id ring_id = -1 @@ -221,23 +170,10 @@ def forward( ) if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = paddle.shape(input_ids) - input_ids = input_ids.reshape((-1, input_shape[-1])) - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - input_shape = paddle.shape(inputs_embeds)[:-1] - batch_size, seq_length, _ = inputs_embeds.shape - else: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: raise ValueError("You have to specify either input_ids or inputs_embeds") - # embed positions - if attention_mask is None: - attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool) - if not is_decoder: ids_remove_padding, padding_offset, cum_offsets = self.remove_padding(input_ids, seq_len_encoder) else: @@ -245,31 +181,14 @@ def forward( padding_offset = None cum_offsets = None - seq_lens = seq_len_decoder if is_decoder else seq_len_encoder - if not is_decoder: - position_ids = None - inputs_embeds = self.embeddings(input_ids=ids_remove_padding, position_ids=position_ids, seq_lens=seq_lens) - - if cache is None: - cache = tuple([None] * self.num_layers) - - # TODO, use registered buffer - length = input_shape[-1] - if is_decoder: - cache_length = paddle.shape(attention_mask)[-1] - 1 - length = length + cache_length - else: - cache_length = 0 - causal_mask = self.bias[:, :, cache_length:length, :length] - - attention_mask = (1.0 - causal_mask) * -1e4 - - # The tensor returned by triu not in static graph. - attention_mask.stop_gradient = True + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids=ids_remove_padding, position_ids=position_ids) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + seq_lens = seq_len_decoder if is_decoder else seq_len_encoder + hidden_states = inputs_embeds with paddle.base.framework._stride_in_no_check_dy2st_diff(): From 486adef9b5b320e7417e4816ddbf82a53fe7db1d Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 08:37:57 +0000 Subject: [PATCH 5/9] add benchmark for dybatch_preprocess --- llm/predictor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/llm/predictor.py b/llm/predictor.py index 92243a8c5e61..654725177172 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -357,6 +357,7 @@ def _preprocess(self, source): self.architectures, top_p=self.config.top_p, temperature=self.config.temperature, + benchmark=self.config.benchmark, ) for i in range(inputs["input_ids"].shape[0]): length = inputs["seq_len_encoder"][i][0] @@ -375,6 +376,7 @@ def _preprocess(self, source): self.architectures, top_p=self.config.top_p, temperature=self.config.temperature, + benchmark=self.config.benchmark, ) for i in range(inputs["input_ids"].shape[0]): length = inputs["seq_len_encoder"][i][0] @@ -439,6 +441,7 @@ def _preprocess(self, source): top_p=self.config.top_p, temperature=self.config.temperature, pre_caches_length=pre_caches_length, + benchmark=self.config.benchmark, ) for i in range(inputs["input_ids"].shape[0]): @@ -703,7 +706,9 @@ def create_predictor( ) model = GPTForCausalLMInferenceModel.from_pretrained( - predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype, + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, ) model.eval() else: @@ -722,7 +727,9 @@ def create_predictor( ChatGLMForCausalLMInferenceModel, ) - cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size) + cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape( + config, predictor_args.batch_size + ) elif "bloom" in config.architectures[0].lower(): from paddlenlp.experimental.transformers import ( BloomForCausalLMInferenceModel, From 54a4ed6fc1bf4b564ec0d4f786d4da8e94c54ca6 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 08:55:26 +0000 Subject: [PATCH 6/9] fix code style --- .../transformers/generation_utils.py | 4 +- .../experimental/transformers/gpt/modeling.py | 64 ++++++++----------- 2 files changed, 29 insertions(+), 39 deletions(-) diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index d24716813528..42a86f489ca1 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -111,7 +111,9 @@ def to_static(self, output_path: str, config: dict): precache_input_spec, ] if self.config["model_type"] and "chatglm" in self.config.model_type: - input_spec[2] = paddle.static.InputSpec(shape=[None, None, None], dtype="int64", name="position_ids") # position_ids + input_spec[2] = paddle.static.InputSpec( + shape=[None, None, None], dtype="int64", name="position_ids" + ) # position_ids input_spec[16] = paddle.static.InputSpec(shape=[None, 2, 1], dtype="int64", name="tgt_pos") # tgt_pos elif self.config["model_type"] and "gpt" in self.config.model_type: input_spec[2] = paddle.static.InputSpec(shape=[None], dtype="int64", name="position_ids") # position_ids diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index dbd0941f237c..66d333ea6722 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -54,6 +54,7 @@ def __init__(self, config: GPTConfig): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_layers = config.num_hidden_layers + self.max_position_embeddings = config.max_position_embeddings self.embeddings = GPTEmbeddings(config) @@ -74,27 +75,15 @@ def __init__(self, config: GPTConfig): linear_weight_attrs = [ paddle.ParamAttr(name="fusemt.{}.linear_weight".format(i)) for i in range(self.num_layers) ] - linear_bias_attrs = [ - paddle.ParamAttr(name="fusemt.{}.linear_bias".format(i)) for i in range(self.num_layers) - ] + linear_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.linear_bias".format(i)) for i in range(self.num_layers)] ffn_ln_scale_attrs = [ paddle.ParamAttr(name="fusemt.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers) ] - ffn_ln_bias_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn_ln_bias".format(i)) for i in range(self.num_layers) - ] - ffn1_weight_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn1_weight".format(i)) for i in range(self.num_layers) - ] - ffn1_bias_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(self.num_layers) - ] - ffn2_weight_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(self.num_layers) - ] - ffn2_bias_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(self.num_layers) - ] + ffn_ln_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn_ln_bias".format(i)) for i in range(self.num_layers)] + ffn1_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_weight".format(i)) for i in range(self.num_layers)] + ffn1_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(self.num_layers)] + ffn2_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(self.num_layers)] + ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(self.num_layers)] self.transformer_block = FusedMultiTransformer( config.hidden_size, config.num_attention_heads, @@ -120,7 +109,6 @@ def __init__(self, config: GPTConfig): ) self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5) - def get_input_embeddings(self): return self.embeddings.word_embeddings @@ -155,19 +143,11 @@ def forward( cache = kwargs.get("cache", cache) is_decoder = cache is not None - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -202,7 +182,7 @@ def forward( seq_lens=seq_lens, time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None, ) - + hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -219,7 +199,6 @@ def forward( cross_attentions=None, ) - @paddle.no_grad() def set_state_dict(self, state_dict): dtype = paddle.get_default_dtype() @@ -256,8 +235,12 @@ def set_state_dict(self, state_dict): .transpose([2, 1, 3, 0]) .reshape( [ - self.num_attention_heads // self.config.tensor_parallel_degree * 3 * self.hidden_size // self.num_attention_heads, - self.hidden_size + self.num_attention_heads + // self.config.tensor_parallel_degree + * 3 + * self.hidden_size + // self.num_attention_heads, + self.hidden_size, ] ) .astype(dtype) @@ -274,7 +257,11 @@ def set_state_dict(self, state_dict): .transpose([1, 0, 2]) .reshape( [ - self.num_attention_heads // self.config.tensor_parallel_degree * 3 * self.hidden_size // self.num_attention_heads + self.num_attention_heads + // self.config.tensor_parallel_degree + * 3 + * self.hidden_size + // self.num_attention_heads ] ) .astype(dtype) @@ -428,9 +415,11 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - + hidden_states = outputs[0] - logits = parallel_matmul(hidden_states, self.gpt.embeddings.word_embeddings.weight, tensor_parallel_output=False) + logits = parallel_matmul( + hidden_states, self.gpt.embeddings.word_embeddings.weight, tensor_parallel_output=False + ) if not return_dict: return (logits, outputs[1:]) @@ -443,7 +432,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @paddle.no_grad() def set_state_dict(self, state_dict): self.gpt.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 7dccaf178d3147c3b8cb57943fb5e68a30a3ee56 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 09:00:29 +0000 Subject: [PATCH 7/9] fix code style --- paddlenlp/experimental/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/experimental/transformers/__init__.py b/paddlenlp/experimental/transformers/__init__.py index 42169640d946..1b1950f16b27 100644 --- a/paddlenlp/experimental/transformers/__init__.py +++ b/paddlenlp/experimental/transformers/__init__.py @@ -15,5 +15,5 @@ from .bloom import * from .chatglm import * from .fused_transformer_layers import * -from .llama import * from .gpt import * +from .llama import * From c7ec8bf9740b261f436682fdcde135502a5b328b Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 11:25:34 +0000 Subject: [PATCH 8/9] fix comment --- .../experimental/transformers/gpt/modeling.py | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index 66d333ea6722..d41ea16b65f2 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -30,7 +30,10 @@ BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) -from paddlenlp.transformers.model_utils import register_base_model +from paddlenlp.transformers.model_utils import ( + dy2st_nocheck_guard_context, + register_base_model, +) __all__ = ["GPTInferenceModel", "GPTForCausalLMInferenceModel"] @@ -68,22 +71,46 @@ def __init__(self, config: GPTConfig): except: pass - ln_scale_attrs = [paddle.ParamAttr(name="fusemt.{}.ln_scale".format(i)) for i in range(self.num_layers)] - ln_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ln_bias".format(i)) for i in range(self.num_layers)] - qkv_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.qkv_weight".format(i)) for i in range(self.num_layers)] - qkv_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.qkv_bias".format(i)) for i in range(self.num_layers)] + ln_scale_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.norm1.weight".format(i)) for i in range(self.num_layers) + ] + ln_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.norm1.bias".format(i)) for i in range(self.num_layers) + ] + qkv_weight_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.self_attn.qkv_proj.weight".format(i)) + for i in range(self.num_layers) + ] + qkv_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.self_attn.qkv_proj.bias".format(i)) + for i in range(self.num_layers) + ] linear_weight_attrs = [ - paddle.ParamAttr(name="fusemt.{}.linear_weight".format(i)) for i in range(self.num_layers) + paddle.ParamAttr(name="gpt.decoder.layers.{}.self_attn.out_proj.weight".format(i)) + for i in range(self.num_layers) + ] + linear_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.self_attn.out_proj.bias".format(i)) + for i in range(self.num_layers) ] - linear_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.linear_bias".format(i)) for i in range(self.num_layers)] ffn_ln_scale_attrs = [ - paddle.ParamAttr(name="fusemt.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers) + paddle.ParamAttr(name="gpt.decoder.layers.{}.norm2.weight".format(i)) for i in range(self.num_layers) + ] + ffn_ln_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.norm2.bias".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.linear1.weight".format(i)) for i in range(self.num_layers) + ] + ffn1_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.linear1.bias".format(i)) for i in range(self.num_layers) + ] + ffn2_weight_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.linear2.weight".format(i)) for i in range(self.num_layers) + ] + ffn2_bias_attrs = [ + paddle.ParamAttr(name="gpt.decoder.layers.{}.linear2.bias".format(i)) for i in range(self.num_layers) ] - ffn_ln_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn_ln_bias".format(i)) for i in range(self.num_layers)] - ffn1_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_weight".format(i)) for i in range(self.num_layers)] - ffn1_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(self.num_layers)] - ffn2_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(self.num_layers)] - ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(self.num_layers)] self.transformer_block = FusedMultiTransformer( config.hidden_size, config.num_attention_heads, @@ -171,7 +198,7 @@ def forward( hidden_states = inputs_embeds - with paddle.base.framework._stride_in_no_check_dy2st_diff(): + with dy2st_nocheck_guard_context(): hidden_states, _ = self.transformer_block( input_ids, hidden_states, From d847287b02f04d93f9d91348c942c9fcedae2930 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 19 Sep 2023 13:56:22 +0000 Subject: [PATCH 9/9] update --- llm/utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/llm/utils.py b/llm/utils.py index e46447383dd8..2ace51032ce9 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -394,6 +394,7 @@ def dybatch_preprocess( benchmark: bool = False, ): """Pre-process generation inputs.""" + inputs = {} if "chatglm" in architectures: input_ids = [] position_ids = [] @@ -403,7 +404,6 @@ def dybatch_preprocess( input_ids.append(tokens["input_ids"][0]) position_ids.append(tokens["position_ids"][0]) - inputs = {} pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][0] inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True) bs = inputs["input_ids"].shape[0] @@ -413,9 +413,8 @@ def dybatch_preprocess( for i in range(len(position_ids)): inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]])) inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos)) - if "gpt" in architectures: + elif "gpt" in architectures: input_ids = [] - position_ids = [] if isinstance(texts, str): texts = [texts] @@ -430,7 +429,6 @@ def dybatch_preprocess( ) input_ids.append(tokens["input_ids"][0]) - inputs = {} pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1] inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True) bs = inputs["input_ids"].shape[0] @@ -439,12 +437,11 @@ def dybatch_preprocess( position_ids = paddle.arange(sum(seq_len), dtype="int64") pre_len = seq_len[0] for length in seq_len[1:]: - position_ids[pre_len:length + pre_len] = position_ids[pre_len:length + pre_len] - pre_len + position_ids[pre_len : length + pre_len] = position_ids[pre_len : length + pre_len] - pre_len pre_len += length inputs["position_ids"] = position_ids else: input_ids = [] - position_ids = [] if isinstance(texts, str): texts = [texts] @@ -459,7 +456,6 @@ def dybatch_preprocess( ) input_ids.append(tokens["input_ids"][0]) - inputs = {} pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1] inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True) bs = inputs["input_ids"].shape[0]