
Commit

format with black
DanGuge committed Dec 19, 2023
1 parent 291cd79 commit 29afaa9
Showing 1 changed file with 37 additions and 66 deletions.
103 changes: 37 additions & 66 deletions paddlenlp/experimental/transformers/qwen/modeling.py
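The edit is mechanical reformatting by the black code formatter: multi-line imports, calls, and comprehensions that fit within the configured line length are collapsed onto a single line, and trailing commas are added where calls stay split across lines. A minimal sketch of reproducing one of the collapsed imports with black's Python API follows; the 120-character line length is an assumption about this repository's black configuration, not something stated in the diff.

# Sketch only: assumes black is installed and the repo's black line length is 120.
import black

src = (
    "from paddlenlp.transformers.qwen.modeling import (\n"
    "    QWenLMHead,\n"
    "    QWenPretrainingCriterion\n"
    ")\n"
)

# black collapses the import because the one-line form fits within the line length.
print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
# -> from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion

The same collapsing applies to the fused_rms_norm call, the qkv_bias_attrs comprehension, the weight_quantize calls, and the paddle.to_tensor calls in the diff below, where removed lines are prefixed with "-" and added lines with "+".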
@@ -33,10 +33,7 @@
CausalLMOutputWithPast,
)
from paddlenlp.transformers import QWenPretrainedModel, QWenConfig
-from paddlenlp.transformers.qwen.modeling import (
-    QWenLMHead,
-    QWenPretrainingCriterion
-)
+from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion
from paddlenlp.transformers.model_utils import (
dy2st_nocheck_guard_context,
register_base_model,
@@ -54,15 +51,14 @@ def __init__(self, config):
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)

def forward(self, x):
-        result = paddle.incubate.nn.functional.fused_rms_norm(
-            x, self.weight, None, self.eps, begin_norm_axis=1
-        )
+        result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1)
if isinstance(result, tuple):
return result[0]
return result


@register_base_model
class QWenInferenceModel(QWenPretrainedModel):
def __init__(self, config: QWenConfig):
@@ -98,10 +94,7 @@ def __init__(self, config: QWenConfig):
)
for i in range(self.num_layers)
]
-        qkv_bias_attrs = [
-            paddle.ParamAttr(name="fuseqwen.{}.qkv_bias".format(i))
-            for i in range(self.num_layers)
-        ]
+        qkv_bias_attrs = [paddle.ParamAttr(name="fuseqwen.{}.qkv_bias".format(i)) for i in range(self.num_layers)]
out_proj_weight_attrs = [
paddle.ParamAttr(
name="fuseqwen.{}.out_proj_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0)
@@ -174,7 +167,7 @@ def __init__(self, config: QWenConfig):
self.transformer_block = FusedMultiTransformerBase(transformer_config)

self.ln_f = FusedQWenRMSNorm(config)

self.cache_kvs = None
self.head_dim_shape_tensor = paddle.ones((self.hidden_size // self.num_attention_heads), dtype="int8")

@@ -186,103 +179,84 @@ def set_input_embeddings(self, value):

@paddle.no_grad()
def set_state_dict(self, state_dict):
-        wte_weight = paddle.to_tensor(
-            state_dict["qwen.wte.weight"],
-            dtype=self.wte.weight.dtype
-        )
-        ln_f_weight = paddle.to_tensor(
-            state_dict["qwen.ln_f.weight"],
-            dtype=self.ln_f.weight.dtype
-        )
+        wte_weight = paddle.to_tensor(state_dict["qwen.wte.weight"], dtype=self.wte.weight.dtype)
+        ln_f_weight = paddle.to_tensor(state_dict["qwen.ln_f.weight"], dtype=self.ln_f.weight.dtype)
self.wte.weight.set_value(wte_weight)
self.ln_f.weight.set_value(ln_f_weight)

for idx in range(self.num_layers):
ln_scale = paddle.to_tensor(
state_dict["qwen.h.{}.ln_1.weight".format(idx)],
dtype=self.transformer_block.ln_scales[idx].dtype
state_dict["qwen.h.{}.ln_1.weight".format(idx)], dtype=self.transformer_block.ln_scales[idx].dtype
)
self.transformer_block.ln_scales[idx].set_value(ln_scale)


qkv_weight = paddle.to_tensor(
state_dict["qwen.h.{}.attn.c_attn.weight".format(idx)].transpose([1, 0]),
-                dtype=self.transformer_block.qkv_weights[idx].dtype
+                dtype=self.transformer_block.qkv_weights[idx].dtype,
)
if self.use_weight_only:
qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0])
-                qkv_quanted_weight, qkv_weight_scale = weight_quantize(
-                    qkv_weight, algo=self.quant_type
-                )
+                qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_type)
self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight)
self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale)
else:
self.transformer_block.qkv_weights[idx].set_value(qkv_weight)

qkv_bias = paddle.to_tensor(
state_dict["qwen.h.{}.attn.c_attn.bias".format(idx)],
-                dtype=self.transformer_block.qkv_biases[idx].dtype
+                dtype=self.transformer_block.qkv_biases[idx].dtype,
)
self.transformer_block.qkv_biases[idx].set_value(qkv_bias)

linear_weight = paddle.to_tensor(
state_dict["qwen.h.{}.attn.c_proj.weight".format(idx)],
-                dtype=self.transformer_block.linear_weights[idx].dtype
+                dtype=self.transformer_block.linear_weights[idx].dtype,
)
if self.use_weight_only:
-                linear_quanted_weight, linear_weight_scale = weight_quantize(
-                    linear_weight, algo=self.quant_type
-                )
+                linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_type)
self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
else:
self.transformer_block.linear_weights[idx].set_value(linear_weight)

ffn_ln_scale = paddle.to_tensor(
state_dict["qwen.h.{}.ln_2.weight".format(idx)],
dtype=self.transformer_block.ffn_ln_scales[idx].dtype
state_dict["qwen.h.{}.ln_2.weight".format(idx)], dtype=self.transformer_block.ffn_ln_scales[idx].dtype
)
self.transformer_block.ffn_ln_scales[idx].set_value(ffn_ln_scale)

up_weight = paddle.to_tensor(
state_dict["qwen.h.{}.mlp.w1.weight".format(idx)],
dtype=self.transformer_block.ffn1_weights[idx].dtype
state_dict["qwen.h.{}.mlp.w1.weight".format(idx)], dtype=self.transformer_block.ffn1_weights[idx].dtype
)
gate_weight = paddle.to_tensor(
state_dict["qwen.h.{}.mlp.w2.weight".format(idx)],
dtype=self.transformer_block.ffn1_weights[idx].dtype
state_dict["qwen.h.{}.mlp.w2.weight".format(idx)], dtype=self.transformer_block.ffn1_weights[idx].dtype
)
ffn1_weight = paddle.concat(x=[gate_weight, up_weight], axis=-1)
if self.use_weight_only:
-                ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(
-                    ffn1_weight, algo=self.quant_type
-                )
+                ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_type)
self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight)
self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale)
else:
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight)

ffn2_weight = paddle.to_tensor(
state_dict["qwen.h.{}.mlp.c_proj.weight".format(idx)],
-                dtype=self.transformer_block.ffn2_weights[idx].dtype
+                dtype=self.transformer_block.ffn2_weights[idx].dtype,
)
if self.use_weight_only:
-                ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(
-                    ffn2_weight, algo=self.quant_type
-                )
+                ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_type)
self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight)
self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale)
else:
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight)

def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
return ids_remove_padding, padding_offset, cum_offsets

def forward(
self,
input_ids=None,
@@ -303,12 +277,12 @@ def forward(
# kwargs["cache"] is used used to distinguish between encoder and decoder phase.
past_key_values = kwargs.get("cache", None)
is_decoder = past_key_values is not None

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand All @@ -319,17 +293,17 @@ def forward(
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

if past_key_values is None:
past_key_values = tuple([None] * self.config.num_hidden_layers)

if not is_decoder:
ids_remove_padding, padding_offset, cum_offsets = self.remove_padding(input_ids, seq_len_encoder)
else:
ids_remove_padding = input_ids
padding_offset = None
cum_offsets = None

if inputs_embeds is None:
inputs_embeds = self.wte(ids_remove_padding)
hidden_states = inputs_embeds
@@ -340,7 +314,7 @@ def forward(
all_self_attentions = () if output_attentions else None

seq_lens = seq_len_decoder if is_decoder else seq_len_encoder

position_offset = 0
if not is_decoder and pre_caches is not None:
position_offset = 128
@@ -364,32 +338,33 @@
rotary_emb_dims=1,
time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
)

hidden_states = self.ln_f(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)


class QWenForCausalLMInferenceModel(GenerationInferenceModel, QWenPretrainedModel):
def __init__(self, config: QWenConfig, **kwargs):
super(QWenForCausalLMInferenceModel, self).__init__(config)
self.qwen = QWenInferenceModel(config)
self.lm_head = QWenLMHead(config)
self.criterion = QWenPretrainingCriterion(config)

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

@@ -400,7 +375,7 @@ def from_pretrained(
# TODO: Support safetensors loading.
kwargs["use_safetensors"] = True
return super().from_pretrained(pretrained_model_name_or_path, from_hf_hub, subfolder, *args, **kwargs)

@classmethod
def get_cache_kvs_shape(
cls, config: QWenConfig, max_batch_size: int = None, max_length: int = None
@@ -416,7 +391,7 @@ def get_cache_kvs_shape(
"""
if max_length is None:
max_length = config.max_position_embeddings

cache_kvs = []
for _ in range(config.num_hidden_layers):
cache_kvs.append(
Expand Down Expand Up @@ -537,10 +512,6 @@ def forward(
@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
-            lm_head_weight = paddle.to_tensor(
-                state_dict["lm_head.weight"],
-                dtype=self.lm_head.weight.dtype
-            )
+            lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype)
self.lm_head.weight.set_value(lm_head_weight)
self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
