
[bloom] Add kv cache support for flash attention & fix bugs #7735

Merged · 12 commits · Dec 29, 2023
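The core of the change: the KV cache is now concatenated before the flash-attention branch, in the [batch_size, seq_len, num_heads, head_dim] layout the flash kernel consumes directly, instead of the transposed layout used by the fallback matmul path. A minimal sketch of the layout equivalence, with toy shapes assumed (illustrative only, not the PR's code):

    import paddle

    batch, heads, head_dim, past_len, new_len = 2, 4, 8, 5, 1

    # cache and current step in the flash-attention layout [b, s, h, d]
    past_key = paddle.randn([batch, past_len, heads, head_dim])
    new_key = paddle.randn([batch, new_len, heads, head_dim])

    # new scheme: append along the sequence axis directly
    key_bshd = paddle.concat((past_key, new_key), axis=1)

    # old scheme: transpose key to [b, h, d, s] first, then append on axis=3
    key_bhds = paddle.concat(
        (past_key.transpose([0, 2, 3, 1]), new_key.transpose([0, 2, 3, 1])), axis=3
    )

    # same values, different layout
    assert paddle.allclose(key_bshd.transpose([0, 2, 3, 1]), key_bhds).item()
    print(key_bshd.shape)  # [2, 6, 4, 8]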
paddlenlp/transformers/bloom/modeling.py: 45 changes (15 additions, 30 deletions)
@@ -378,6 +378,20 @@ def forward(
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along the seq_length dimension (axis=1):
+            # - key: [batch_size, kv_length, self.num_heads, head_dim]
+            # - value: [batch_size, kv_length, self.num_heads, head_dim]
+            key_layer = paddle.concat((past_key, key_layer), axis=1)
+            value_layer = paddle.concat((past_value, value_layer), axis=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None

        version = paddle.version.full_version
        version_check = True
        if version != "0.0.0" and version <= "2.5.2":
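For orientation, a hedged sketch of how `present` is threaded back in as `layer_past` on the next decode step; `attn_step` below is a hypothetical stand-in for the cache prologue added above, not a function from the codebase:

    import paddle

    def attn_step(key, value, layer_past=None, use_cache=True):
        # mirror of the added prologue: grow the cache along the sequence axis
        if layer_past is not None:
            past_key, past_value = layer_past
            key = paddle.concat((past_key, key), axis=1)
            value = paddle.concat((past_value, value), axis=1)
        present = (key, value) if use_cache else None
        return key, value, present

    batch, heads, head_dim = 1, 2, 4
    cache = None
    for _ in range(3):  # decode three tokens, one at a time
        k = paddle.randn([batch, 1, heads, head_dim])
        v = paddle.randn([batch, 1, heads, head_dim])
        k_all, v_all, cache = attn_step(k, v, layer_past=cache)
    print(k_all.shape)  # [1, 3, 2, 4]: the cache grew along axis=1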
@@ -404,39 +418,10 @@ def forward(
            attn_output = attn_output.reshape([attn_output.shape[0], attn_output.shape[1], -1])
            output_tensor = self.dense(attn_output)

-            query_layer = query_layer.transpose([0, 2, 1, 3])
-            key_layer = key_layer.transpose([0, 2, 3, 1])
-            value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache:
-                present = (key_layer, value_layer)
-            else:
-                present = None
        else:

            query_layer = query_layer.transpose([0, 2, 1, 3])
            key_layer = key_layer.transpose([0, 2, 3, 1])
            value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache is True:
-                present = (key_layer, value_layer)
-            else:
-                present = None

            _, _, _, kv_length = key_layer.shape

            query_layer = query_layer.reshape([batch_size * self.num_heads, q_length, self.head_dim])
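On the fallback path, heads are folded into the batch dimension so the attention product can run as one batched matmul; a small sketch with assumed toy shapes:

    import paddle

    batch_size, num_heads, q_length, head_dim = 2, 4, 3, 8
    # after the transpose above the query is [b, h, s, d]; fold b and h together
    query_layer = paddle.randn([batch_size, num_heads, q_length, head_dim])
    query_layer = query_layer.reshape([batch_size * num_heads, q_length, head_dim])
    print(query_layer.shape)  # [8, 3, 8]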
@@ -949,7 +934,7 @@ def forward(
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[3]
+            past_key_values_length = past_key_values[0][0].shape[1]
        seq_length_with_past = seq_length_with_past + past_key_values_length

        if attention_mask is None:
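The shape[3] to shape[1] fix above follows from the new cache layout: keys are now cached as [batch_size, kv_length, num_heads, head_dim], so the past length lives on axis 1 (it sat on axis 3 in the old transposed key layout). A tiny sketch, toy shapes assumed:

    import paddle

    # one layer's cached key in the new layout [b, kv_len, h, d]
    past_key = paddle.zeros([2, 5, 4, 8])
    past_key_values = ((past_key, past_key),)

    past_key_values_length = past_key_values[0][0].shape[1]
    print(past_key_values_length)  # 5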