[Attention Mask] Refactor all encoder-decoder attention mask #27086
Conversation
As you have mentioned in your PR description, I would personally be in favor of having `_` prefixes, just to ensure they aren't leveraged by third parties thinking they're public. They're prominently displayed in the forward pass of a significant number of models, so it's worth adding the prefix.
Very cool PR! This is much cleaner this way, IMO
@@ -1522,7 +1499,7 @@ def forward(
             layer_outputs = self.gradient_checkpointing_func(
                 decoder_layer.__call__,
                 hidden_states,
-                combined_attention_mask,
+                None,
Huh, that `combined_attention_mask` was pretty weird.
Thanks for refactoring all this code. It's so much cleaner and great to see all the repeated code being deleted 🙏
Just a few nit comments. My only request is that a handful of our biggest models have their slow tests run as a sense check before merging.
    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could add a check here that the passed `sliding_window` is positive if not None.
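For illustration, a minimal sketch of the suggested check (the class skeleton just mirrors the snippet above; the exact error message is an assumption):

```python
from typing import Optional


class AttentionMaskConverter:
    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        # Suggested validation: a sliding window only makes sense as a strictly positive integer.
        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"`sliding_window` must be a strictly positive integer if not None, got {self.sliding_window}"
            )
```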
-        attention_mask = self.attn_mask_converter.to_causal_4d(
-            batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-        )
+        attention_mask = prepare_4d_causal_attention_mask(
Agreed! :D
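For context, a hedged sketch of calling the shared helper the diff above switches to (the import path and signature reflect transformers' mask utilities around the time of this PR; the leading-underscore name was still under discussion and may differ between versions):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_key_values_length, hidden = 2, 5, 0, 8
inputs_embeds = torch.randn(batch_size, seq_length, hidden)
attention_mask_2d = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2d padding mask (or None)

attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask_2d,
    (batch_size, seq_length),
    inputs_embeds,               # only used for dtype/device
    past_key_values_length,
)
print(attention_mask.shape)  # torch.Size([2, 1, 5, 5])
```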
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
        context = mask_converter.sliding_window
        if mask_converter.is_causal and context is None:
I'm personally not a fan of having test functions with lots of if/else statements - it tends to lead to utility functions which try to handle everything and can be error prone.
Hmm, fair. I'd say, though, that for tests it's OK since only we maintainers look at them. For me, the if-else statements actually helped quite a bit to map out all the different scenarios that exist (causal mask + window, causal mask + no window, non-causal mask + window, non-causal mask + no window).
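An illustrative sketch of those four scenarios (not the actual test code; the sliding-window boundary convention used here is an assumption): it builds the boolean mask a converter with the given settings would be expected to allow.

```python
import torch


def build_expected_bool_mask(is_causal: bool, sliding_window, q_len: int, kv_len: int) -> torch.Tensor:
    """Return a [q_len, kv_len] mask that is True where attention is allowed."""
    # Align query positions with key positions (queries correspond to the last q_len keys).
    q_idx = torch.arange(q_len)[:, None] + (kv_len - q_len)
    k_idx = torch.arange(kv_len)[None, :]

    allowed = torch.ones(q_len, kv_len, dtype=torch.bool)
    if is_causal:
        allowed &= k_idx <= q_idx                  # causal: never attend to future keys
    if sliding_window is not None:
        allowed &= k_idx > q_idx - sliding_window  # window: only the last `sliding_window` keys
    return allowed


# causal + window, causal + no window, non-causal + window, non-causal + no window
for is_causal in (True, False):
    for window in (3, None):
        n_allowed = build_expected_bool_mask(is_causal, window, q_len=4, kv_len=6).sum().item()
        print(is_causal, window, n_allowed)
```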
Love it!
@require_torch
class AttentionMaskTester(unittest.TestCase):
I would also advocate for the `self.assert` methods that we use more now 😉, but it's a nit.
Hmm, I really think `assert ...` is much better to use in tests because the error message is much cleaner. E.g. if you do `self.assertTrue(expected_ids == predicted_ids)`, assuming each is a list, your error message will just be "the assertion is not True", whereas doing `assert expected_ids == predicted_ids` gives you a much better error message. OK to change for consistency, but I really think that just doing `assert ...` is better.
Yeah, but `self.assertListEqual()` should give you more info, no? I don't mind, I just want consistency!
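A small self-contained illustration of the error-message point above (hypothetical test, not from the PR):

```python
import unittest


class TinyExample(unittest.TestCase):
    def test_ids(self):
        expected_ids = [1, 2, 3]
        predicted_ids = [1, 2, 3]

        # self.assertTrue(expected_ids == predicted_ids)
        #   -> on a mismatch this only reports "False is not true"
        # assert expected_ids == predicted_ids
        #   -> under pytest, assertion rewriting prints both lists and their diff
        # self.assertListEqual(expected_ids, predicted_ids)
        #   -> also prints an element-wise diff of the two lists
        self.assertListEqual(expected_ids, predicted_ids)


if __name__ == "__main__":
    unittest.main()
```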
Thanks for the fast reviews everyone! Ran slow tests for:

Think this should be good enough!
@patrickvonplaten I know this PR has been merged, but I have a question regarding the PyTorch version of BART. The implementation assumes the decoder is always used in an autoregressive manner in the PyTorch version, unlike the Flax version. There could be cases of the decoder being used as an "encoder" with "cross attention". In this case, the autoregressive behaviour is not required. I think the default should be the autoregressive manner, but if
@DavidAkinpelu I think you linked the FlaxAttention class, not the FlaxDecoder class, above. In PT, the Attention class can also be used in non-causal mode, just like in Flax. If you want to use Bart in non-autoregressive mode, why don't you use BartEncoder?
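A hedged sketch of that suggestion (the checkpoint name is just an example): if no causal masking is wanted, run the bidirectional encoder stack directly instead of the decoder.

```python
from transformers import AutoTokenizer, BartModel

model = BartModel.from_pretrained("facebook/bart-base")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

inputs = tokenizer("Hello world", return_tensors="pt")
encoder = model.get_encoder()  # BartEncoder: bidirectional, no causal mask is applied
encoder_outputs = encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
hidden_states = encoder_outputs.last_hidden_state  # [batch, seq_len, hidden_size]
```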
@patrickvonplaten This paper got me thinking in that direction: Mores+.
…face#27086)

* [FA2 Bart] Add FA2 to all Bart-like
* better
* Refactor attention mask
* remove all customized atteniton logic
* format
* mass rename
* replace _expand_mask
* replace _expand_mask
* mass rename
* add pt files
* mass replace & rename
* mass replace & rename
* mass replace & rename
* mass replace & rename
* Update src/transformers/models/idefics/modeling_idefics.py
* fix more
* clean more
* fix more
* make style
* fix again
* finish
* finish
* finish
* finish
* finish
* finish
* finish
* finish
* finish
* finish
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* small fix mistral
* finish
* finish
* finish
* finish

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
This PR refactors the attention mask handling of all PT Seq2Seq models. While this is a nice quality-of-life improvement, it is also necessary in order to effectively add FA2 and SDPA to PT Seq2Seq models (without having to change 54+ files).
In a follow-up PR it'll be much easier to add FA2 to just Bart and the most important Bart-like models.
The PR slightly goes against the single-file policy, but attention masks are really the same across models and there is also only so much they can be (causal, non-causal, windowed). I think it doesn't really hurt readability, as the functions are very clearly defined (create a 4d attention mask from a 2d one).
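A minimal sketch of what "create a 4d attention mask from a 2d one" means (illustrative only; the real helpers introduced here handle causal, non-causal and windowed masks and live in a shared module rather than in each model file):

```python
import torch


def expand_2d_to_4d(attention_mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a [batch, key_len] padding mask of 1s/0s into an additive
    [batch, 1, 1, key_len] mask of 0 / large-negative values."""
    expanded = attention_mask_2d[:, None, None, :].to(dtype)  # [batch, 1, 1, key_len]
    inverted = 1.0 - expanded                                 # 1.0 where tokens are padding
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


mask_4d = expand_2d_to_4d(torch.tensor([[1, 1, 1, 0]]), torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 1, 4])
```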
For some very big exceptions (I found only one, which is LED, see the comment here), we can just write part of the mask creation separately, as is done.
I could also give the mask creation functions a `_` prefix to make it clearer that they are private methods in Transformers. Both keeping as is and changing are fine with me.

@amyeroberts @LysandreJik @ArthurZucker this is ready for a review!