From 621118e4870bcce20b942329ba7a7ca7d3c94bc6 Mon Sep 17 00:00:00 2001
From: EnflameGCU <118410644+EnflameGCU@users.noreply.github.com>
Date: Wed, 22 May 2024 20:44:25 +0800
Subject: [PATCH] fix rotary_emb for llama (#8470)

---
 paddlenlp/transformers/llama/fusion_ops.py |  4 +++-
 paddlenlp/transformers/llama/modeling.py   | 28 +++++++++++++++++++---------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py
index f9cdf8547dfd..6009a80911d5 100644
--- a/paddlenlp/transformers/llama/fusion_ops.py
+++ b/paddlenlp/transformers/llama/fusion_ops.py
@@ -57,11 +57,13 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
     assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-    cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    if get_env_device() != "gcu":
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
     elif get_env_device() == "gcu":
+        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
         query_states, key_states = core.eager._run_custom_op(
             "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
         )
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 78fc26d08dbf..366f7ff3c083 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            self.cos_sin_table.cast(x.dtype)
-            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
-            else self.cos_sin_table,
         )

+    def get_fused_cos_sin(self, x, seq_len=None):
+        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
+            return self.cos_sin_table.cast(x.dtype)
+        else:
+            return self.cos_sin_table
+

 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
+            scale_cos, scale_sin = self.cos_cached, self.sin_cached
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            scale_cos_sin.cast(x.dtype)
-            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
-            else scale_cos_sin,
         )

+    def get_fused_cos_sin(self, x, seq_len=None):
+        if seq_len > self.max_position_embeddings:
+            _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+        else:
+            scale_cos_sin = self.cos_sin_table
+        if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype:
+            return scale_cos_sin.cast(x.dtype)
+        else:
+            return scale_cos_sin
+

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
                 sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
             )
         else:
-            cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
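
Note: the patch splits the rotary-embedding API in two. forward() now returns
only the (cos, sin) pair, and the fused cos/sin table consumed by the GCU
custom op is fetched separately via get_fused_cos_sin(), so non-GCU devices
never materialize or cast the table. Below is a minimal NumPy sketch of that
split; it is not the PaddleNLP code, and the interleaved cos_sin_table layout
is an assumption made only for illustration.

import numpy as np


class RotaryEmbeddingSketch:
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        # Standard RoPE angle table: one frequency per pair of head dims.
        inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
        t = np.arange(max_position_embeddings, dtype=np.float32)
        freqs = np.outer(t, inv_freq)                   # [max_pos, dim // 2]
        emb = np.concatenate([freqs, freqs], axis=-1)   # [max_pos, dim]
        self.cos_cached = np.cos(emb)[None, :, None, :]  # [1, max_pos, 1, dim]
        self.sin_cached = np.sin(emb)[None, :, None, :]
        # Hypothetical fused layout: cos and sin interleaved on the last axis.
        self.cos_sin_table = np.stack(
            [np.cos(freqs), np.sin(freqs)], axis=-1
        ).reshape(max_position_embeddings, dim)

    def forward(self, x, seq_len=None):
        # After the patch: only (cos, sin), no third element.
        cos = self.cos_cached[:, :seq_len, :, :]
        sin = self.sin_cached[:, :seq_len, :, :]
        return cos.astype(x.dtype), sin.astype(x.dtype)

    def get_fused_cos_sin(self, x, seq_len=None):
        # Mirrors the new method: hand out the fused table, cast on demand.
        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
            return self.cos_sin_table.astype(x.dtype)
        return self.cos_sin_table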
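
Caller-side, fusion_rope now dispatches on the device before choosing an entry
point. A hedged sketch of that control flow, with get_env_device() replaced by
a plain parameter and the custom ops elided as comments:

def apply_rope_dispatch(rotary_emb, query, key, value, kv_seq_len, device="gpu"):
    if device != "gcu":
        # Two return values now, matching the patched forward().
        cos, sin = rotary_emb.forward(value, seq_len=kv_seq_len)
        # ... on NPU this pair would feed the "fused_rope" custom op ...
        return query, key
    # Only on GCU is the fused table fetched, for fused_rotary_embedding_gcu.
    cos_sin = rotary_emb.get_fused_cos_sin(value, seq_len=kv_seq_len)
    # ... cos_sin (plus position_ids) would feed the GCU custom op ...
    return query, key


# Example wiring of the two sketches together:
rope = RotaryEmbeddingSketch(dim=64)
q = k = v = np.zeros((1, 8, 1, 64), dtype=np.float16)
apply_rope_dispatch(rope, q, k, v, kv_seq_len=8, device="gcu")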