[NEW Feature] Add hook-based refined_recompute support #9396
Changes from 16 commits
62fc783
9d2632a
56203a1
506c6bf
f394bbc
11075fc
997cf5f
1dad92d
246a913
f348cc1
9d34d2f
da6c9cb
6b6654d
ed4addb
7b8d1c6
7064804
b8671f1
44b2389
418a259
9f5e306
@@ -51,6 +51,7 @@
 except:
     flash_attention = None
 
+from paddlenlp.transformers.refined_recompute import no_recompute
Review comment: Why is it called no_recompute? The name feels a bit odd.
Review comment: Renaming it to skip_recompute would also work.
Review comment: recompute(func, xxxxx) vs no_recompute(func, xxxxxx)
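To make the comparison in the last comment concrete, here is a hedged usage sketch; `attention_fn` and the tensors are placeholders invented for illustration, and only the `no_recompute(func, *args, enable=...)` call shape is taken from this diff:

```python
# Sketch contrasting the two call shapes discussed above; not library code.
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute  # existing Paddle recompute API
from paddlenlp.transformers.refined_recompute import no_recompute  # added by this PR

def attention_fn(q, k, v):
    # toy stand-in for a real fused attention kernel
    return paddle.matmul(F.softmax(paddle.matmul(q, k, transpose_y=True)), v)

q = paddle.randn([2, 8, 16])
k = paddle.randn([2, 8, 16])
v = paddle.randn([2, 8, 16])
q.stop_gradient = False

out_recomputed = recompute(attention_fn, q, k, v)            # re-run in backward to save memory
out_kept = no_recompute(attention_fn, q, k, v, enable=True)  # excluded from recomputation
```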
 from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
 
@@ -174,6 +175,7 @@
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    skip_recompute=False,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@
         attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
 
         if hasattr(F, "flashmask_attention"):
-            attn_output = F.flashmask_attention(
+            attn_output = no_recompute(
+                F.flashmask_attention,
                 query_states,
                 key_states,
                 value_states,
                 startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
                 causal=True,
+                enable=skip_recompute,
             )
         else:
-            attn_output = F.flash_attention_with_sparse_mask(
+            attn_output = no_recompute(
+                F.flash_attention_with_sparse_mask,
                 query_states,
                 key_states,
                 value_states,
                 attn_mask_start_row_indices=attn_mask_startend_row_indices,
                 is_causal=True,
+                enable=skip_recompute,
            )
     else:
-        attn_output = F.scaled_dot_product_attention(
+        attn_output = no_recompute(
+            F.scaled_dot_product_attention,
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             is_causal=query_states.shape[1] != 1,
+            enable=skip_recompute,
         )
     attn_weights = None
 
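For context, the hunk above routes every fused attention call through `no_recompute` with `enable=skip_recompute`. A minimal sketch of that calling contract, assuming the hook details stay inside the library, could look like this (an illustration, not the PR's implementation):

```python
# Illustration only: the real no_recompute in paddlenlp.transformers.refined_recompute
# is hook-based; the hook machinery is omitted here.
def no_recompute_sketch(func, *args, enable=False, **kwargs):
    if not enable:
        # Plain call: the surrounding recompute region recomputes this op in backward.
        return func(*args, **kwargs)
    # enable=True: the real code installs hooks so the outputs of `func` are saved
    # during forward instead of being recomputed, trading memory for backward speed.
    return func(*args, **kwargs)
```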
Review comment: Will this configuration info be passed through to downstream tasks?
Review comment: Is _set_unsavable_keys needed?
Review comment: Not needed. zhonghui is quite familiar with this usage; I looked at the implementation and it can meet the requirement: 1) llmmetaclass is added, and 2) LlmMetaConfig.set_llm_config(model_config, training_args) is called.
@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
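As a rough sketch of the propagation path described in that comment (the class body and field names below are hypothetical; only the set_llm_config call pattern comes from the comment):

```python
from types import SimpleNamespace

# Hypothetical stand-in for LlmMetaConfig: @llmmetaclass registers LLM-related
# fields on TrainingArguments, and set_llm_config copies them onto the model
# config so they reach the model. Field names here are assumptions.
class DemoLlmMetaConfig:
    llm_meta_fields = ("recompute", "refined_recompute")

    @classmethod
    def set_llm_config(cls, model_config, training_args):
        for name in cls.llm_meta_fields:
            if hasattr(training_args, name):
                setattr(model_config, name, getattr(training_args, name))

# Usage sketch; the value format for refined_recompute is hypothetical.
args = SimpleNamespace(recompute=True, refined_recompute="flash_attn:-1")
config = SimpleNamespace()
DemoLlmMetaConfig.set_llm_config(config, args)
assert config.refined_recompute == "flash_attn:-1"
```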