From 01163e0fa061376e1af3270db9bd19044abe72b7 Mon Sep 17 00:00:00 2001
From: Deleter-D <867909454@qq.com>
Date: Fri, 26 Jul 2024 13:34:34 +0800
Subject: [PATCH] [DCU] fix llama inference bug on DCU

---
 .../experimental/transformers/fused_transformer_layers.py | 6 +-----
 paddlenlp/experimental/transformers/llama/modeling.py     | 4 ++--
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index be8bcad1c878..7c21f0a261b1 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -550,11 +550,7 @@ def init_weight_shape(self, config):
             if config.trans_qkvw
             else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         )
-        self.linear_weight_shape = (
-            [self.num_heads * self.head_dim, self.embed_dim]
-            if config.trans_qkvw
-            else [self.embed_dim, self.num_heads * self.head_dim]
-        )
+        self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
         self.ffn1_weight_shape = (
             [self.embed_dim, self.dim_feedforward * 2]
             if self.activation.endswith("glu")
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index ac226907a41e..d0cdbe9b9081 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -565,7 +565,7 @@ def __init__(self, config: LlamaConfig):
             use_neox_rotary_style=True,
             use_dynamic_cachekv_quant=config.use_cachekv_int8 == "dynamic",
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(True if not paddle.is_compiled_with_rocm() else False),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
         )

         self.set_transformer_block(transformer_config)
@@ -752,7 +752,7 @@ def set_state_dict(self, state_dict):
             unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
                 "llama.layers.{}.self_attn.v_proj.weight".format(idx)
             ]
-            if paddle.is_compiled_with_rocm():
+            if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
                 concated_qkv_weight = np.concatenate(
                     [
                         unfused_state_dict["self_attn.q_proj.weight"],
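
Note (not part of the patch): a minimal Python sketch of the trans_qkvw decision this change introduces, using a hypothetical helper name resolve_trans_qkvw. After the patch, QKV weights are left untransposed only on ROCm/DCU builds that use a8w8 quantization; every other build keeps the transposed layout. Independently, the patch fixes linear_weight_shape to [num_heads * head_dim, embed_dim] regardless of trans_qkvw.

    import paddle

    def resolve_trans_qkvw(quant_type: str) -> bool:
        # Mirrors the expression passed to FusedMultiTransformerConfig above:
        # transpose QKV weights unless this is a ROCm/DCU build using a8w8 quantization.
        if paddle.is_compiled_with_rocm() and quant_type == "a8w8":
            return False
        return True

The same condition now guards the manual q/k/v concatenation in set_state_dict, so the weight layout produced at load time stays consistent with the layout the fused transformer layer expects.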