
add annotation for the fusion
FeixLiu committed Jan 31, 2024
1 parent 896c01f commit c79b7e2
11 changes: 11 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -774,6 +774,17 @@ def forward(

if self.fuse_attention_qkv:
mix_layer = self.qkv_proj(hidden_states)
# NOTE on GQA attention fusion (compatible with MHA and MQA):
# The qkv_proj weight has shape [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
# After the projection, mix_layer has shape [b, s, hidden_size + 2 * num_kv_heads * head_dim].
# Reshape mix_layer to [b, s, num_kv_heads, (num_groups + 2) * head_dim],
# where num_groups = num_q_heads // num_kv_heads.
# Split mix_layer along the last axis into three sections of sizes
# [num_groups * head_dim, head_dim, head_dim] to obtain q, k and v respectively.
# q then has shape [b, s, num_kv_heads, num_groups * head_dim];
# k and v have shape [b, s, num_kv_heads, head_dim].
# Under MHA, q is already laid out correctly since num_kv_heads == num_q_heads,
# but under GQA or MQA, q must be further reshaped to [b, s, num_q_heads, head_dim].
if self.reshard_layer is not None:
if self.sequence_parallel:
assert self.seq_length % self.config.sep_parallel_degree == 0
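
To make the shapes in the new annotation concrete, here is a minimal sketch of the reshape-and-split, assuming a small hypothetical configuration; the sizes and variable names below are illustrative and not taken from modeling.py:

import paddle

# Hypothetical sizes for illustration only.
b, s = 2, 8                       # batch size, sequence length
num_q_heads, num_kv_heads = 8, 2  # GQA: 8 query heads share 2 kv heads
head_dim = 16
hidden_size = num_q_heads * head_dim
num_groups = num_q_heads // num_kv_heads  # 4 query heads per kv head

# Fused projection output: [b, s, hidden_size + 2 * num_kv_heads * head_dim]
mix_layer = paddle.randn([b, s, hidden_size + 2 * num_kv_heads * head_dim])

# Group the fused features per kv head: [b, s, num_kv_heads, (num_groups + 2) * head_dim]
mix_layer = paddle.reshape(mix_layer, [b, s, num_kv_heads, (num_groups + 2) * head_dim])

# Split the last axis into q, k and v sections.
q, k, v = paddle.split(mix_layer, [num_groups * head_dim, head_dim, head_dim], axis=-1)

# For GQA/MQA, flatten the query groups back into per-query-head layout.
q = paddle.reshape(q, [b, s, num_q_heads, head_dim])

print(q.shape, k.shape, v.shape)  # [2, 8, 8, 16] [2, 8, 2, 16] [2, 8, 2, 16]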
