fix rotary_emb for llama (#8470)
EnflameGCU authored May 22, 2024
1 parent 70bffa8 commit 621118e
Showing 2 changed files with 22 additions and 10 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/transformers/llama/fusion_ops.py
@@ -57,11 +57,13 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
     assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-    cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    if get_env_device() != "gcu":
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
     elif get_env_device() == "gcu":
+        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
         query_states, key_states = core.eager._run_custom_op(
             "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
         )
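In short: the old call site unpacked three values from rotary_emb(...) on every device, even though the fused cos_sin table is only meaningful on the gcu path. The patch narrows the common path to (cos, sin) and puts the fused table behind the new get_fused_cos_sin accessor. A hedged sketch of the resulting call pattern (plain_rope and gcu_fused_rope are hypothetical stand-ins for the real custom ops, not PaddleNLP names):

def plain_rope(q, k, cos, sin, position_ids):
    ...  # stand-in for the non-gcu rope application (e.g. the npu fused_rope op)

def gcu_fused_rope(q, k, cos_sin, position_ids):
    ...  # stand-in for the fused_rotary_embedding_gcu custom op

def fusion_rope_sketch(rotary_emb, query_states, key_states, value_states, position_ids, device):
    kv_seq_len = key_states.shape[1]
    if device != "gcu":
        # Common path: forward() now only has to produce the cos/sin pair.
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        return plain_rope(query_states, key_states, cos, sin, position_ids)
    # gcu path: the pre-fused table is fetched through the new accessor,
    # so other devices never build or depend on cos_sin_table here.
    cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
    return gcu_fused_rope(query_states, key_states, cos_sin, position_ids)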
28 changes: 19 additions & 9 deletions paddlenlp/transformers/llama/modeling.py
@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            self.cos_sin_table.cast(x.dtype)
-            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
-            else self.cos_sin_table,
         )
 
+    def get_fused_cos_sin(self, x, seq_len=None):
+        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
+            return self.cos_sin_table.cast(x.dtype)
+        else:
+            return self.cos_sin_table
+
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
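The get_fused_cos_sin method added above only exposes the cached table with a dtype cast; the interesting part is the cache itself. As background, here is a minimal NumPy sketch of a rotary embedding that carries both the (cos, sin) pair and a fused table. The stacked layout of cos_sin_table is an illustrative assumption and is not claimed to match the gcu kernel's expected format:

import numpy as np

class RotaryEmbeddingSketch:
    """Toy cos/sin cache; the fused-table layout is an assumption, not PaddleNLP's."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        # Standard RoPE frequencies: inv_freq[i] = 1 / base**(2i / dim).
        inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
        t = np.arange(max_position_embeddings, dtype=np.float64)
        freqs = np.outer(t, inv_freq)                    # [max_pos, dim // 2]
        emb = np.concatenate([freqs, freqs], axis=-1)    # [max_pos, dim]
        self.cos_cached = np.cos(emb)[None, :, None, :]  # [1, max_pos, 1, dim]
        self.sin_cached = np.sin(emb)[None, :, None, :]
        # Hypothetical fused layout: cos and sin stacked on the last axis.
        self.cos_sin_table = np.concatenate([np.cos(emb), np.sin(emb)], axis=-1)

    def __call__(self, x, seq_len=None):
        # After this patch, the plain forward path returns only (cos, sin).
        cos = self.cos_cached[:, :seq_len].astype(x.dtype)
        sin = self.sin_cached[:, :seq_len].astype(x.dtype)
        return cos, sin

    def get_fused_cos_sin(self, x, seq_len=None):
        # The fused table is cast (and touched) only when a gcu-style caller asks.
        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
            return self.cos_sin_table.astype(x.dtype)
        return self.cos_sin_table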
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
+            scale_cos, scale_sin = self.cos_cached, self.sin_cached
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            scale_cos_sin.cast(x.dtype)
-            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
-            else scale_cos_sin,
         )
 
+    def get_fused_cos_sin(self, x, seq_len=None):
+        if seq_len > self.max_position_embeddings:
+            _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+        else:
+            scale_cos_sin = self.cos_sin_table
+        if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype:
+            return scale_cos_sin.cast(x.dtype)
+        else:
+            return scale_cos_sin
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
                 sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
             )
         else:
-            cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
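Putting the pieces together, a toy end-to-end check of the two-value contract the attention forward now relies on, reusing the sketches above (shapes are illustrative; these are not PaddleNLP classes):

import numpy as np

# Toy shapes: batch=2, seq=8, heads=4, head_dim=16.
q = np.random.randn(2, 8, 4, 16).astype(np.float32)
k = np.random.randn(2, 8, 4, 16).astype(np.float32)
v = np.random.randn(2, 8, 4, 16).astype(np.float32)

rope = RotaryEmbeddingSketch(dim=16, max_position_embeddings=32)
cos, sin = rope(v, seq_len=8)  # two values, matching the patched call site
q_rot, k_rot = apply_rotary_pos_emb_ref(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# Only a gcu-style fused kernel would ask for the combined table:
cos_sin = rope.get_fused_cos_sin(v, seq_len=8)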
