Commit

gqa fuse attention qkv
FeixLiu committed Jan 24, 2024
1 parent 082dc52 commit 7b04017
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions paddlenlp/transformers/llama/modeling.py
@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
self.head_dim = self.hidden_size // config.num_attention_heads

self.num_key_value_heads = config.num_key_value_heads
assert config.num_attention_heads % config.num_key_value_heads == 0
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

self.max_position_embeddings = config.max_position_embeddings
self.seq_length = config.seq_length
self.sequence_parallel = config.sequence_parallel

self.fuse_attention_qkv = config.fuse_attention_qkv
if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
raise ValueError(
f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
)

self.kv_indices = None
# Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
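The constructor changes above reduce to simple head bookkeeping: every key/value head is shared by a group of num_attention_heads // num_key_value_heads query heads, and GQA/MQA is active whenever the two head counts differ. A minimal standalone sketch of that arithmetic, using illustrative numbers that are not taken from this diff:

# Head bookkeeping behind the GQA changes (illustrative values, not from the diff).
num_attention_heads = 64       # query heads
num_key_value_heads = 8        # shared KV heads
hidden_size = 8192
head_dim = hidden_size // num_attention_heads            # 128

assert num_attention_heads % num_key_value_heads == 0    # each KV head serves a whole group
num_key_value_groups = num_attention_heads // num_key_value_heads  # 8 query heads per KV head
gqa_or_mqa = num_attention_heads != num_key_value_heads            # True here

print(num_key_value_groups, gqa_or_mqa, head_dim)        # 8 True 128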
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.num_key_value_heads % config.tensor_parallel_degree == 0:
self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
else:
if self.fuse_attention_qkv:
# TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
raise ValueError(
f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
)
logger.warning(
f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
)
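The new guard only rejects the fused QKV path when the KV heads cannot be evenly sharded across tensor-parallel ranks. A small sketch of that rule (the helper name and numbers are hypothetical, for illustration only):

# Sketch of the divisibility rule the new guard enforces (assumed values).
def can_fuse_qkv_under_tp(num_key_value_heads: int, tensor_parallel_degree: int) -> bool:
    """Fused QKV is kept only when every rank gets a whole number of KV heads."""
    return num_key_value_heads % tensor_parallel_degree == 0

print(can_fuse_qkv_under_tp(8, 4))   # True  -> 2 KV heads per rank, fusion allowed
print(can_fuse_qkv_under_tp(8, 3))   # False -> this commit raises ValueError instead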
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.fuse_attention_qkv:
self.qkv_proj = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
has_bias=False,
gather_output=False,
)
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.hidden_size,
3 * self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
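Both projection variants (ColumnParallelLinear and nn.Linear) get the same resizing: the fused output width drops from 3 * hidden_size to one full-width query block plus two smaller KV blocks. A sketch of that width arithmetic, with illustrative numbers rather than values from this diff:

# Output width of the fused QKV projection before and after this change
# (illustrative numbers; the real values come from LlamaConfig).
num_attention_heads, num_key_value_heads, head_dim = 64, 8, 128
hidden_size = num_attention_heads * head_dim                          # 8192

mha_fused_width = 3 * hidden_size                                     # 24576: Q, K, V all full width
gqa_fused_width = hidden_size + 2 * num_key_value_heads * head_dim    # 8192 + 2048 = 10240

# gqa_fused_width is what the diff passes as the second argument to the projection layer.
print(mha_fused_width, gqa_fused_width)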
@@ -776,7 +779,11 @@ def forward(
assert self.seq_length % self.config.sep_parallel_degree == 0
mix_layer = paddle.reshape_(
mix_layer,
[-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
],
)
# [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
mix_layer = self.reshard_layer(
Expand All @@ -785,15 +792,26 @@ def forward(
concat_axis=1,
)
mix_layer = paddle.reshape_(
mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
) # [bs, seq_len, num_key_value_heads / k, (num_key_value_groups + 2) * head_dim], k is sep degree
else:
if self.sequence_parallel:
target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
target_shape = [
-1,
self.seq_length,
self.num_key_value_heads,
(self.num_key_value_groups + 2) * self.head_dim,
]
else:
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
mix_layer = paddle.reshape_(mix_layer, target_shape)
query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
query_states, key_states, value_states = paddle.split(
mix_layer,
num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
axis=-1,
)
if self.gqa_or_mqa:
query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
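The split logic above groups the fused output by KV head, carves each group into num_key_value_groups query slices plus one key and one value slice, and then flattens the queries back out to one slot per attention head. A runnable sketch of that shape flow in plain Paddle, with small assumed dimensions that are not from this diff:

import paddle

# Shapes assumed for illustration only (not from the diff).
bs, seq_len = 2, 16
num_heads, num_kv_heads, head_dim = 8, 2, 4
groups = num_heads // num_kv_heads                        # 4 query heads per KV head

# Output of the fused QKV projection: [bs, seq, hidden + 2 * kv_heads * head_dim]
mix_layer = paddle.randn([bs, seq_len, num_heads * head_dim + 2 * num_kv_heads * head_dim])

# Group the fused width by KV head, then carve off Q, K, V per group.
mix_layer = paddle.reshape(mix_layer, [0, 0, num_kv_heads, (groups + 2) * head_dim])
q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[groups * head_dim, head_dim, head_dim],
    axis=-1,
)

# Queries are flattened back out to one slot per attention head.
q = paddle.reshape(q, [0, 0, num_heads, head_dim])

print(q.shape, k.shape, v.shape)  # [2, 16, 8, 4] [2, 16, 2, 4] [2, 16, 2, 4]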
@@ -807,11 +825,19 @@ def forward(
)
key_states = paddle.reshape(
key_states,
[-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_key_value_heads * self.head_dim,
],
)
value_states = paddle.reshape(
value_states,
[-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_key_value_heads * self.head_dim,
],
)
query_states = self.reshard_layer(
query_states,
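In the sequence-parallel branch, the key and value tensors are now reshaped with num_key_value_heads * head_dim instead of num_heads * head_dim, so the KV tensors keep their smaller GQA width while only the sequence dimension is resharded (the reshard then trades the sequence split for a head split, per the comment earlier in the diff). A shape-bookkeeping sketch with assumed numbers:

# Shape bookkeeping for the sep-parallel branch (illustrative numbers only).
bs, seq_length, sep_parallel_degree = 2, 1024, 4
num_heads, num_kv_heads, head_dim = 64, 8, 128

q_shape  = [bs, seq_length // sep_parallel_degree, num_heads * head_dim]     # [2, 256, 8192]
kv_shape = [bs, seq_length // sep_parallel_degree, num_kv_heads * head_dim]  # [2, 256, 1024]

# Conceptually, reshard_layer then maps
#   [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
print(q_shape, kv_shape)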
