fix rotary_emb for llama
EnflameGCU committed May 22, 2024
1 parent 87e4c4f commit dd1e52b
Showing 2 changed files with 22 additions and 10 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/transformers/llama/fusion_ops.py
@@ -57,11 +57,13 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
     assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-    cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    if get_env_device() != "gcu":
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
     elif get_env_device() == "gcu":
+        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
         query_states, key_states = core.eager._run_custom_op(
             "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
         )
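
In short, this hunk makes the rotary lookup device-dependent: every backend except gcu keeps taking the (cos, sin) pair from rotary_emb(...), while the gcu branch now asks for the pre-fused table via rotary_emb.get_fused_cos_sin(...) right before launching its custom op. Below is a minimal sketch of that dispatch, assuming only the interface shown in this diff; _rope_ref is an unfused, plain-Paddle stand-in for the custom ops, not the real kernels.

import paddle

def _rotate_half(x):
    # Same trick as rotate_half in modeling.py: negate and swap the two halves.
    x1, x2 = paddle.split(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

def _rope_ref(states, cos, sin):
    # Unfused reference for what the "fused_rope" custom op computes.
    return states * cos + _rotate_half(states) * sin

def apply_rotary_dispatch(rotary_emb, query_states, key_states, value_states, position_ids, device):
    # Mirrors the post-change branching in fusion_rope.
    kv_seq_len = key_states.shape[1]
    if device != "gcu":
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if device == "npu":
        query_states = _rope_ref(query_states, cos, sin)
        key_states = _rope_ref(key_states, cos, sin)
    elif device == "gcu":
        # gcu fetches the pre-fused table explicitly; the real path hands it to
        # "fused_rotary_embedding_gcu", which is not reproduced in this sketch.
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
    return query_states, key_states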
28 changes: 19 additions & 9 deletions paddlenlp/transformers/llama/modeling.py
@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            self.cos_sin_table.cast(x.dtype)
-            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
-            else self.cos_sin_table,
         )
 
+    def get_fused_cos_sin(self, x, seq_len=None):
+        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
+            return self.cos_sin_table.cast(x.dtype)
+        else:
+            return self.cos_sin_table
+
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
+            scale_cos, scale_sin = self.cos_cached, self.sin_cached
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-            scale_cos_sin.cast(x.dtype)
-            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
-            else scale_cos_sin,
         )
 
+    def get_fused_cos_sin(self, x, seq_len=None):
+        if seq_len > self.max_position_embeddings:
+            _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
+        else:
+            scale_cos_sin = self.cos_sin_table
+        if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype:
+            return scale_cos_sin.cast(x.dtype)
+        else:
+            return scale_cos_sin
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
             sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
         )
     else:
-        cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
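
Taken together, the modeling.py changes split the rotary-embedding interface: forward() now returns only the (cos, sin) pair, and the optional fused table moves behind a dedicated get_fused_cos_sin() method, which is why the attention layer's unpacking at the end of the diff drops back to two values. A simplified, self-contained sketch of that contract follows; the cache construction and shapes are illustrative, not the PaddleNLP implementation, and cos_sin_table stays None here because only device-specific backends (such as gcu) populate it.

import paddle

class RotaryEmbeddingSketch(paddle.nn.Layer):
    # Illustrative only: mirrors the forward()/get_fused_cos_sin() split from this commit.
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="float32") / dim))
        t = paddle.arange(max_position_embeddings, dtype="float32")
        freqs = t.unsqueeze(-1) * inv_freq.unsqueeze(0)        # [max_pos, dim // 2]
        emb = paddle.concat([freqs, freqs], axis=-1)           # [max_pos, dim]
        self.cos_cached = emb.cos().unsqueeze(axis=[0, 2])     # [1, max_pos, 1, dim]
        self.sin_cached = emb.sin().unsqueeze(axis=[0, 2])
        # Device-specific fused table; only populated by backends (e.g. gcu) that use it.
        self.cos_sin_table = None

    def forward(self, x, seq_len=None):
        # Post-change contract: only the (cos, sin) pair is returned.
        cos = self.cos_cached[:, :seq_len]
        sin = self.sin_cached[:, :seq_len]
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )

    def get_fused_cos_sin(self, x, seq_len=None):
        # The fused table is requested explicitly, so non-gcu callers never touch it.
        if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
            return self.cos_sin_table.cast(x.dtype)
        return self.cos_sin_table

Caller side, matching the two call sites in this diff (value_states and kv_seq_len as in the code above):

    rope = RotaryEmbeddingSketch(dim=128)
    cos, sin = rope(value_states, seq_len=kv_seq_len)                       # all devices
    cos_sin = rope.get_fused_cos_sin(value_states, seq_len=kv_seq_len)      # gcu only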

