diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 61405f6c4e28..6098079d9084 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -64,7 +64,11 @@ from paddlenlp.utils.download import resolve_file_path from paddlenlp.utils.log import logger -__all__ = ["Qwen2ForCausalLMInferenceModel", "Qwen2ForCausalLMBlockInferenceModel"] +__all__ = [ + "Qwen2ForCausalLMInferenceModel", + "Qwen2ForCausalLMBlockInferenceModel", + "Qwen2VLForConditionalGenerationBlockInferenceModel", +] class FusedQwen2RMSNorm(nn.Layer): @@ -551,23 +555,28 @@ def set_state_dict(self, state_dict): self.transformer_block.init_weight() split_fn = split_param_func() self.embed_tokens.weight.set_value( - paddle.to_tensor(state_dict["qwen2.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) + paddle.to_tensor(state_dict[f"{self.base_model_prefix}.embed_tokens.weight"]).cast( + self.embed_tokens.weight.dtype + ) + ) + self.norm.weight.set_value( + paddle.to_tensor(state_dict[f"{self.base_model_prefix}.norm.weight"]).cast(self.norm.weight.dtype) ) - self.norm.weight.set_value(paddle.to_tensor(state_dict["qwen2.norm.weight"]).cast(self.norm.weight.dtype)) for idx in range(self.num_layers): + model_prefix = self.base_model_prefix + f".layers.{idx}" logger.info(f"set state for layer {idx}") - ln_scale = paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.weight".format(idx)]).cast( + ln_scale = paddle.to_tensor(state_dict[f"{model_prefix}.input_layernorm.weight"]).cast( self.transformer_block.ln_scales[idx].dtype ) self.transformer_block.ln_scales[idx].set_value(ln_scale) - if "qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys(): + if f"{model_prefix}.self_attn.qkv_proj.weight" in state_dict.keys(): concated_qkv_weight = paddle.to_tensor( np.concatenate( split_fn( - state_dict["qwen2.layers.{}.self_attn.qkv_proj.weight".format(idx)], + state_dict[f"{model_prefix}.self_attn.qkv_proj.weight"], is_qkv=True, num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, @@ -578,13 +587,13 @@ def set_state_dict(self, state_dict): else: unfused_state_dict = {} unfused_state_dict["self_attn.q_proj.weight"] = paddle.to_tensor( - state_dict["qwen2.layers.{}.self_attn.q_proj.weight".format(idx)] + state_dict[f"{model_prefix}.self_attn.q_proj.weight"] ) unfused_state_dict["self_attn.k_proj.weight"] = paddle.to_tensor( - state_dict["qwen2.layers.{}.self_attn.k_proj.weight".format(idx)] + state_dict[f"{model_prefix}.self_attn.k_proj.weight"] ) unfused_state_dict["self_attn.v_proj.weight"] = paddle.to_tensor( - state_dict["qwen2.layers.{}.self_attn.v_proj.weight".format(idx)] + state_dict[f"{model_prefix}.self_attn.v_proj.weight"] ) if "fp8" in self.quant_type: q_wgt_scale = self.transformer_block.weight_scales["q_weight_scale"][idx] @@ -658,30 +667,17 @@ def set_state_dict(self, state_dict): else: self.transformer_block.qkv_weights[idx].set_value(qkv_weight) - unfused_state_dict["qwen2.self_attn.q_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.q_proj.bias".format(idx) - ] - unfused_state_dict["qwen2.self_attn.k_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.k_proj.bias".format(idx) - ] - unfused_state_dict["qwen2.self_attn.v_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.v_proj.bias".format(idx) - ] + q_bias = state_dict[f"{model_prefix}.self_attn.q_proj.bias"] + k_bias = state_dict[f"{model_prefix}.self_attn.k_proj.bias"] + v_bias = state_dict[f"{model_prefix}.self_attn.v_proj.bias"] - concated_qkv_biases = np.concatenate( - [ - unfused_state_dict["qwen2.self_attn.q_proj.bias"], - unfused_state_dict["qwen2.self_attn.k_proj.bias"], - unfused_state_dict["qwen2.self_attn.v_proj.bias"], - ], - axis=-1, - ) + concated_qkv_biases = np.concatenate([q_bias, k_bias, v_bias], axis=-1) qkv_bias = paddle.to_tensor(concated_qkv_biases) self.transformer_block.qkv_biases[idx].set_value( qkv_bias.cast(self.transformer_block.qkv_biases[idx].dtype) ) - linear_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).cast( + linear_weight = paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.weight"]).cast( paddle.get_default_dtype() ) if self.use_weight_only: @@ -691,9 +687,7 @@ def set_state_dict(self, state_dict): elif "fp8" in self.quant_type: self.transformer_block.linear_weights[idx].copy_( paddle.cast( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]).transpose( - (1, 0) - ), + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.weight"]).transpose((1, 0)), "float8_e4m3fn", ), False, @@ -707,16 +701,14 @@ def set_state_dict(self, state_dict): if paddle.is_compiled_with_rocm(): self.transformer_block.linear_weights[idx].set_value( paddle.cast( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)]), + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.weight"]), w_dtype, ) ) else: self.transformer_block.linear_weights[idx].set_value( paddle.cast( - paddle.to_tensor( - state_dict["qwen2.layers.{}.self_attn.o_proj.weight".format(idx)] - ).transpose((1, 0)), + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.weight"]).transpose((1, 0)), w_dtype, ) ) @@ -726,22 +718,20 @@ def set_state_dict(self, state_dict): ) ffn_ln_scale = paddle.to_tensor( - state_dict["qwen2.layers.{}.post_attention_layernorm.weight".format(idx)], + state_dict[f"{model_prefix}.post_attention_layernorm.weight"], ) self.transformer_block.ffn_ln_scales[idx].set_value( ffn_ln_scale.cast(self.transformer_block.ffn_ln_scales[idx].dtype) ) - if "qwen2.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys(): + if f"{model_prefix}.mlp.gate_up_fused_proj.weight" in state_dict.keys(): concated_ffn1_weight = np.concatenate( - split_fn(state_dict["qwen2.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1 + split_fn(state_dict[f"{model_prefix}.mlp.gate_up_fused_proj.weight"]), axis=-1 ) else: - unfused_state_dict["mlp.gate_proj.weight"] = state_dict[ - "qwen2.layers.{}.mlp.gate_proj.weight".format(idx) - ] - unfused_state_dict["mlp.up_proj.weight"] = state_dict["qwen2.layers.{}.mlp.up_proj.weight".format(idx)] + unfused_state_dict["mlp.gate_proj.weight"] = state_dict[f"{model_prefix}.mlp.gate_proj.weight"] + unfused_state_dict["mlp.up_proj.weight"] = state_dict[f"{model_prefix}.mlp.up_proj.weight"] concated_ffn1_weight = np.concatenate( [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1 ) @@ -781,14 +771,14 @@ def set_state_dict(self, state_dict): ffn1_weight.cast(self.transformer_block.ffn1_weights[idx].dtype) ) - ffn2_weight = paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]) + ffn2_weight = paddle.to_tensor(state_dict[f"{model_prefix}.mlp.down_proj.weight"]) if self.use_weight_only: ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo) self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight) self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale) elif "fp8" in self.quant_type: self.transformer_block.ffn2_weights[idx].copy_( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.weight".format(idx)]) + paddle.to_tensor(state_dict[f"{model_prefix}.mlp.down_proj.weight"]) .transpose([1, 0]) .cast("float8_e4m3fn"), False, @@ -811,57 +801,57 @@ def set_state_dict(self, state_dict): if "fp8" not in self.quant_type and "a8w8" in self.quant_type: if self.shift_smooth_all_linears: if self.use_fake_parameter: - if "qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict: - state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)] = paddle.zeros( + if f"{model_prefix}.self_attn.o_proj.shift_bias" not in state_dict: + state_dict[f"{model_prefix}.self_attn.o_proj.shift_bias"] = paddle.zeros( shape=[ (self.num_attention_heads // self.config.tensor_parallel_degree) * (self.hidden_size // self.num_attention_heads) ], dtype=paddle.get_default_dtype(), ) - state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] = paddle.ones( + state_dict[f"{model_prefix}.self_attn.o_proj.smooth_weight"] = paddle.ones( shape=[ (self.num_attention_heads // self.config.tensor_parallel_degree) * (self.hidden_size // self.num_attention_heads) ], dtype=paddle.get_default_dtype(), ) - state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)] = paddle.zeros( + state_dict[f"{model_prefix}.mlp.down_proj.shift_bias"] = paddle.zeros( shape=[self.intermediate_size // self.config.tensor_parallel_degree], dtype=paddle.get_default_dtype(), ) - state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)] = paddle.ones( + state_dict[f"{model_prefix}.mlp.down_proj.smooth_weight"] = paddle.ones( shape=[self.intermediate_size // self.config.tensor_parallel_degree], dtype=paddle.get_default_dtype(), ) self.transformer_block.linear_shifts[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype( + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.shift_bias"]).astype( paddle.get_default_dtype() ) ) self.transformer_block.linear_smooths[idx].set_value( - paddle.to_tensor( - state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] - ).astype(paddle.get_default_dtype()) + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.smooth_weight"]).astype( + paddle.get_default_dtype() + ) ) self.transformer_block.ffn2_shifts[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype( + paddle.to_tensor(state_dict[f"{model_prefix}.mlp.down_proj.shift_bias"]).astype( paddle.get_default_dtype() ) ) self.transformer_block.ffn2_smooths[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype( + paddle.to_tensor(state_dict[f"{model_prefix}.mlp.down_proj.smooth_weight"]).astype( paddle.get_default_dtype() ) ) if self.shift: if self.use_fake_parameter: - if "qwen2.layers.{}.input_layernorm.bias".format(idx) not in state_dict: - state_dict["qwen2.layers.{}.input_layernorm.bias".format(idx)] = paddle.zeros( + if f"{model_prefix}.input_layernorm.bias" not in state_dict: + state_dict[f"{model_prefix}.input_layernorm.bias"] = paddle.zeros( shape=[self.hidden_size], dtype=paddle.get_default_dtype() ) - state_dict["qwen2.layers.{}.post_attention_layernorm.bias".format(idx)] = paddle.zeros( + state_dict[f"{model_prefix}.post_attention_layernorm.bias"] = paddle.zeros( [self.hidden_size], dtype=paddle.get_default_dtype() ) unfused_state_dict["self_attn.q_proj.bias"] = paddle.zeros( @@ -884,26 +874,22 @@ def set_state_dict(self, state_dict): ) else: unfused_state_dict["self_attn.q_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.q_proj.bias".format(idx) + f"{model_prefix}.self_attn.q_proj.bias" ] unfused_state_dict["self_attn.k_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.k_proj.bias".format(idx) + f"{model_prefix}.self_attn.k_proj.bias" ] unfused_state_dict["self_attn.v_proj.bias"] = state_dict[ - "qwen2.layers.{}.self_attn.v_proj.bias".format(idx) - ] - unfused_state_dict["mlp.gate_proj.bias"] = state_dict[ - "qwen2.layers.{}.mlp.gate_proj.bias".format(idx) - ] - unfused_state_dict["mlp.up_proj.bias"] = state_dict[ - "qwen2.layers.{}.mlp.up_proj.bias".format(idx) + f"{model_prefix}.self_attn.v_proj.bias" ] + unfused_state_dict["mlp.gate_proj.bias"] = state_dict[f"{model_prefix}.mlp.gate_proj.bias"] + unfused_state_dict["mlp.up_proj.bias"] = state_dict[f"{model_prefix}.mlp.up_proj.bias"] self.transformer_block.ln_biases[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.bias".format(idx)]) + paddle.to_tensor(state_dict[f"{model_prefix}.input_layernorm.bias"]) ) self.transformer_block.ffn_ln_biases[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.post_attention_layernorm.bias".format(idx)]) + paddle.to_tensor(state_dict[f"{model_prefix}.post_attention_layernorm.bias"]) ) concated_qkv_biases = np.concatenate( [ @@ -922,18 +908,18 @@ def set_state_dict(self, state_dict): if self.shift_smooth_all_linears: if self.use_fake_parameter: - if "qwen2.layers.{}.self_attn.o_proj.bias".format(idx) not in state_dict: - state_dict["qwen2.layers.{}.self_attn.o_proj.bias".format(idx)] = paddle.zeros( + if f"{model_prefix}.self_attn.o_proj.bias" not in state_dict: + state_dict[f"{model_prefix}.self_attn.o_proj.bias"] = paddle.zeros( [self.hidden_size], dtype=paddle.get_default_dtype() ) - state_dict["qwen2.layers.{}.mlp.down_proj.layer.bias".format(idx)] = paddle.zeros( + state_dict[f"{model_prefix}.mlp.down_proj.layer.bias"] = paddle.zeros( [self.hidden_size], dtype=paddle.get_default_dtype() ) self.transformer_block.linear_biases[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.bias".format(idx)]) + paddle.to_tensor(state_dict[f"{model_prefix}.self_attn.o_proj.bias"]) ) self.transformer_block.ffn2_biases[idx].set_value( - paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.layer.bias".format(idx)]) + paddle.to_tensor(state_dict[f"{model_prefix}.mlp.down_proj.layer.bias"]) ) def remove_padding(self, input_ids, seq_lens_this_time): @@ -1286,7 +1272,14 @@ def forward( kwargs["padding_offsets"] = padding_offset kwargs["max_input_length"] = self.max_seq_len - inputs_embeds = self.embed_tokens(ids_remove_padding) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(ids_remove_padding) + else: + assert len(inputs_embeds.shape) == 3 + # This is the case in the image-to-text model such as qwen2-vl, + # In the prefill phase, the language model is first fed with inputs_embeds instead of input_ids + # but in decoder phase, the language model is fed with input_ids just like normal text-to-text model. + inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]]) with dy2st_nocheck_guard_context(): hidden_states, _ = self.transformer_block( @@ -1425,6 +1418,7 @@ def get_cache_kvs_shape( def prepare_inputs_for_generation(self, **kwargs): # only last token for inputs_ids if cache is defined in kwargs input_ids = kwargs["input_ids"] + inputs_embeds = kwargs.get("inputs_embeds", None) src_mask = kwargs.get("src_mask", None) block_tables = kwargs.get("block_tables", None) @@ -1446,6 +1440,7 @@ def prepare_inputs_for_generation(self, **kwargs): model_inputs = { "input_ids": input_ids, + "inputs_embeds": inputs_embeds, "src_mask": src_mask, "rope_emb": rope_emb, "pre_caches": pre_caches, @@ -1466,6 +1461,7 @@ def prepare_inputs_for_generation(self, **kwargs): def forward( self, input_ids, + inputs_embeds=None, src_mask=None, pre_caches=None, caches=None, @@ -1483,6 +1479,7 @@ def forward( ): outputs = self.qwen2( input_ids, + inputs_embeds=inputs_embeds, src_mask=src_mask, caches=caches, rope_emb=rope_emb, @@ -1514,3 +1511,15 @@ def set_state_dict(self, state_dict): paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) ) self.qwen2.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + + +class Qwen2VLForConditionalGenerationBlockInferenceModel(Qwen2ForCausalLMBlockInferenceModel): + """ + NOTE: (changwenbin) This class inherits from Qwen2ForCausalLMBlockInferenceModel. + Used only for QWen2-VL's second part. + """ + + # NOTE: (changwenbin) This function corresponds to QWen2-VL's second part, only used for QWen2-VL. + def __init__(self, config): + super().__init__(config) + self.qwen2.base_model_prefix = "model" diff --git a/paddlenlp/transformers/auto/configuration.py b/paddlenlp/transformers/auto/configuration.py index ff89b81d5cc2..f2058a5ec389 100644 --- a/paddlenlp/transformers/auto/configuration.py +++ b/paddlenlp/transformers/auto/configuration.py @@ -231,6 +231,9 @@ def __init__(self, mapping): self._modules = {} def __getitem__(self, key): + # NOTE: (changwenbin) This is to enable the qwen2_vl language model to use qwen2 reasoning optimization + if key == "qwen2_vl": + key = "qwen2" if key in self._extra_content: return self._extra_content[key] if key not in self._mapping: