Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix rotary_emb for llama #8470

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddlenlp/transformers/llama/fusion_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() != "gcu":
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)

Check warning on line 61 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L60-L61

Added lines #L60 - L61 were not covered by tests
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
elif get_env_device() == "gcu":
cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)

Check warning on line 66 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L66

Added line #L66 was not covered by tests
query_states, key_states = core.eager._run_custom_op(
"fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
)
Expand Down
28 changes: 19 additions & 9 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,14 @@
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
self.cos_sin_table.cast(x.dtype)
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
else self.cos_sin_table,
)

def get_fused_cos_sin(self, x, seq_len=None):
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype:
return self.cos_sin_table.cast(x.dtype)

Check warning on line 426 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L425-L426

Added lines #L425 - L426 were not covered by tests
else:
return self.cos_sin_table

Check warning on line 428 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L428

Added line #L428 was not covered by tests


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
Expand Down Expand Up @@ -482,19 +485,26 @@
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_position_embeddings:
scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len)

Check warning on line 488 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L488

Added line #L488 was not covered by tests
else:
scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
scale_cos, scale_sin = self.cos_cached, self.sin_cached

Check warning on line 490 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L490

Added line #L490 was not covered by tests
cos = scale_cos[:, :seq_len, :, ...]
sin = scale_sin[:, :seq_len, :, ...]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
scale_cos_sin.cast(x.dtype)
if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
else scale_cos_sin,
)

def get_fused_cos_sin(self, x, seq_len=None):
if seq_len > self.max_position_embeddings:
_, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)

Check warning on line 500 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L499-L500

Added lines #L499 - L500 were not covered by tests
else:
scale_cos_sin = self.cos_sin_table
if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype:
return scale_cos_sin.cast(x.dtype)

Check warning on line 504 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L502-L504

Added lines #L502 - L504 were not covered by tests
else:
return scale_cos_sin

Check warning on line 506 in paddlenlp/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/modeling.py#L506

Added line #L506 was not covered by tests


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
Expand Down Expand Up @@ -943,7 +953,7 @@
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Expand Down
Loading