Update chatglm flash attention version check
w5688414 committed Dec 28, 2023
1 parent 887a201 commit a899fce
Showing 2 changed files with 18 additions and 25 deletions.
paddlenlp/transformers/bloom/modeling.py (11 changes: 6 additions & 5 deletions)
@@ -382,8 +382,8 @@ def forward(
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            # - key: [batch_size, self.num_heads, head_dim, kv_length]
-            # - value: [batch_size, self.num_heads, kv_length, head_dim]
+            # - key: [batch_size, kv_length, self.num_heads, head_dim]
+            # - value: [batch_size, kv_length, self.num_heads, head_dim]
             key_layer = paddle.concat((past_key, key_layer), axis=1)
             value_layer = paddle.concat((past_value, value_layer), axis=1)

@@ -394,12 +394,12 @@ def forward(

         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
             version_check = False
-        if self.config.use_flash_attention and version_check:
+        if version_check:
             query_states, key_states, value_states = query_layer, key_layer, value_layer

             attention_mask = attention_mask.cast(alibi.dtype) + alibi
@@ -411,6 +411,8 @@ def forward(
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
+                dropout_p=self.config.attention_dropout,
+                training=self.training,
                 is_causal=False,
             )
             attn_weights = None
@@ -434,7 +436,6 @@ def forward(
             attention_scores = baddbmm(
                 alibi, batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor
             )
-
             # change view to [batch_size, num_heads, q_length, kv_length]
             # attention_scores = matmul_result.reshape([batch_size, self.num_heads, q_length, kv_length])

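Both files make the same two-line adjustment to the version gate: the upgrade warning now fires only when use_flash_attention is enabled in the config, and the following branch then tests the resulting flag on its own. A standalone sketch of the updated check is given below; it is illustrative only, and the flash_attention_version_check helper name plus the use of the standard logging module are assumptions, not code from this commit.

import logging

import paddle

logger = logging.getLogger(__name__)


def flash_attention_version_check(use_flash_attention: bool) -> bool:
    # Mirrors the updated gate: warn, and disable the flash path, only when
    # flash attention is actually requested on a Paddle build older than 2.5.3.
    version = paddle.version.full_version
    version_check = True
    # "0.0.0" identifies a develop build, which is treated as new enough.
    if use_flash_attention and version != "0.0.0" and version <= "2.5.2":
        logger.warning(
            "PaddlePaddle version 2.5.3 or higher is required, please upgrade "
            "your PaddlePaddle to 2.5.3 or other higher version."
        )
        version_check = False
    return version_check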
paddlenlp/transformers/chatglm/modeling.py (32 changes: 12 additions & 20 deletions)
@@ -245,25 +245,26 @@ def forward(
         q_layer, k_layer, v_layer = paddle.split(mixed_layer, 3, axis=-1)
         # [s, b, n, h/n]
         q_layer, k_layer = self._core_attention(q_layer, k_layer, position_ids, rotary_embeds)
+
+        if cache is not None:
+            cache_k, cache_v = cache[0], cache[1]
+            # [s + c, b, n, h/n]
+            k_layer = paddle.concat([cache_k, k_layer], axis=0)
+            v_layer = paddle.concat([cache_v, v_layer], axis=0)
+
+        cache_kv = None
+        if use_cache:
+            cache_kv = (k_layer, v_layer)
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
             logger.warning(
                 "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
             )
             version_check = False
-        if self.config.use_flash_attention and version_check:
+        if version_check:
             # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
             # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)
-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)
-
             # [s, b, n, h/n] = > [batch_size, seq_len, num_heads, head_dim]
             q_layer = paddle.transpose(q_layer, [1, 0, 2, 3])
             k_layer = paddle.transpose(k_layer, [1, 0, 2, 3])
@@ -286,18 +287,9 @@ def forward(

             output, attention_probs = attn_output, attn_weights
         else:
-            if cache is not None:
-                cache_k, cache_v = cache[0], cache[1]
-                # [s + c, b, n, h/n]
-                k_layer = paddle.concat([cache_k, k_layer], axis=0)
-                v_layer = paddle.concat([cache_v, v_layer], axis=0)
-
             seq_length, batch_size, num_heads, hidden_size = k_layer.shape

-            cache_kv = None
-            if use_cache:
-                cache_kv = (k_layer, v_layer)

             attention_scale_coeff = float(layer_id) + 1.0
             if self.attention_scale:
                 # [s, b, n, h/n]
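Beyond the gate itself, the chatglm diff above hoists the KV-cache handling out of the flash-attention branch: the cached keys/values are concatenated and the cache_kv output is built once before the branch, rather than duplicated in both the flash and fallback paths. A minimal sketch of that shared step follows; the prepare_kv helper name is hypothetical and assumes the [seq_len, batch, num_heads, head_dim] layout noted in the diff's comments.

import paddle


def prepare_kv(k_layer, v_layer, cache=None, use_cache=False):
    # Concatenate cached keys/values (if any) and build the cache_kv output once,
    # before choosing between the flash-attention and standard attention paths.
    if cache is not None:
        cache_k, cache_v = cache[0], cache[1]
        # Tensors are laid out [seq_len, batch, num_heads, head_dim], so past and
        # current steps are concatenated along axis 0 (the sequence axis).
        k_layer = paddle.concat([cache_k, k_layer], axis=0)
        v_layer = paddle.concat([cache_v, v_layer], axis=0)

    cache_kv = (k_layer, v_layer) if use_cache else None
    return k_layer, v_layer, cache_kv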
