
[Attention Mask] Refactor all encoder-decoder attention mask #27086

Merged (39 commits) on Oct 27, 2023

Conversation

patrickvonplaten (Contributor) commented Oct 26, 2023

What does this PR do?

This PR refactors the attention mask handling of all PT Seq2Seq models. While this is a nice quality-of-life improvement on its own, it is also necessary in order to effectively add FA2 and SDPA to PT Seq2Seq models (without having to change 54+ files).

In a follow-up PR it'll be much easier to add FA2 to just Bart and the most important Bart-like models.
The PR slightly goes against the single-file policy, but attention masks are really the same across models, and there is only so much they can be (causal, non-causal, windowed). I think it doesn't really hurt readability, as the functions are very clearly defined (create a 4d attention mask from a 2d one).

For the very big exceptions (I found only one, which is LED, see the comment here), we can just write part of the mask creation separately, as is done here.

I could also give the mask creation functions a _ prefix to make it clearer that they are private methods in Transformers. Either keeping them as is or changing them is fine with me.
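
For illustration, a minimal sketch of the kind of 2d-to-4d expansion these helpers perform (the function name and details below are illustrative, not the exact Transformers API):

import torch

def expand_2d_to_4d_mask(mask_2d, dtype, tgt_len=None, causal=False):
    # Illustrative only: turn a [batch, src_len] padding mask (1 = keep, 0 = pad)
    # into the additive [batch, 1, tgt_len, src_len] form attention layers expect.
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len

    # True where attention is allowed.
    allowed = mask_2d[:, None, None, :].to(torch.bool).expand(bsz, 1, tgt_len, src_len)
    if causal:
        # Queries may only look at keys at the same or earlier positions
        # (no past-key-values offset handled here, unlike the real helpers).
        allowed = allowed & torch.tril(torch.ones(tgt_len, src_len, dtype=torch.bool))

    # Additive mask: 0 where allowed, a large negative number where masked.
    mask_4d = torch.zeros(bsz, 1, tgt_len, src_len, dtype=dtype)
    return mask_4d.masked_fill(~allowed, torch.finfo(dtype).min)

# Example: batch of 2, the second sequence ends with one padding token.
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
print(expand_2d_to_4d_mask(mask, torch.float32, causal=True).shape)  # torch.Size([2, 1, 3, 3])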

@amyeroberts @LysandreJik @ArthurZucker this is ready for a review!

HuggingFaceDocBuilderDev commented Oct 26, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten changed the title from "[FA2 Bart] Add FA2 to all Bart-like" to "[Attention Mask] Refactor all encoder-decoder attention mask" on Oct 26, 2023
LysandreJik (Member) left a comment

As you have mentioned in your PR description, I would personally be in favor of having _ prefixes, just to ensure they aren't leveraged by third parties thinking they're public. They're prominently displayed in the forward pass of a significant number of models, so it's worth adding the prefix.

Very cool PR! This is much cleaner this way, IMO

@@ -1522,7 +1499,7 @@ def forward(
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
None,
Member comment:

Huh that combined attention_mask was pretty weird

(Resolved review threads on src/transformers/models/mpt/modeling_mpt.py and src/transformers/modeling_attn_mask_utils.py)
amyeroberts (Collaborator) left a comment

Thanks for refactoring all this code. It's so much cleaner, and it's great to see all the repeated code being deleted 🙏

Just a few nit comments. My only request is that a handful of our biggest models have their slow tests run as a sense check before merging.


def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
    self.is_causal = is_causal
    self.sliding_window = sliding_window
Collaborator comment: We could add a check here that the passed sliding_window is positive if not None.
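
For illustration, a hedged sketch of what such a validation could look like (the class name here is hypothetical, not the final implementation):

from typing import Optional

class AttnMaskConverterSketch:
    # Hypothetical mirror of the converter's __init__ with the suggested guard.
    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        if sliding_window is not None and sliding_window <= 0:
            raise ValueError(
                f"`sliding_window` must be a strictly positive integer if set, got {sliding_window}."
            )
        self.is_causal = is_causal
        self.sliding_window = sliding_window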

(Additional review threads on src/transformers/modeling_attn_mask_utils.py, since resolved or outdated)
attention_mask = self.attn_mask_converter.to_causal_4d(
batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
attention_mask = prepare_4d_causal_attention_mask(
Collaborator comment:

Agreed! :D

(Resolved review threads on tests/test_modeling_utils.py)
assert mask_4d.shape == (bsz, 1, q_len, kv_len)

context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
Collaborator comment:

I'm personally not a fan of having test functions with lots of if/else statements - it tends to lead to utility functions which try to handle everything and can be error prone.

PR author reply:

Hmm, fair. I'd say though that for tests it's ok, since only we maintainers look at them. For me, the if-else statements actually helped quite a bit to map out all the different scenarios that exist (causal mask + window, causal mask + no window, non-causal mask + window, non-causal mask + no window).
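
For illustration only (not the test code in this PR), one way to enumerate those four scenarios explicitly is with subtests:

import itertools
import unittest

import torch

class MaskScenarioSketch(unittest.TestCase):
    # Hypothetical sketch: spell out the four scenarios (causal / non-causal,
    # with / without a sliding window) via subtests instead of branching
    # inside one big helper.
    def test_mask_scenarios(self):
        seq_len = 6
        for is_causal, window in itertools.product([True, False], [None, 2]):
            with self.subTest(is_causal=is_causal, sliding_window=window):
                rows = torch.arange(seq_len)[:, None]
                cols = torch.arange(seq_len)[None, :]
                allowed = torch.ones(seq_len, seq_len, dtype=torch.bool)
                if is_causal:
                    allowed &= cols <= rows
                if window is not None:
                    allowed &= (rows - cols).abs() < window
                # Every query position should still be able to attend somewhere.
                self.assertTrue(bool(allowed.any(dim=-1).all()))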

ArthurZucker (Collaborator) left a comment

Love it!



@require_torch
class AttentionMaskTester(unittest.TestCase):
Collaborator comment:

I would also advocate for the self.assert* methods that we use more now 😉 but it's a nit

PR author reply:

Hmm, I really think assert ... is much better to use in tests because the error message is much cleaner. E.g. if you do:

self.assertTrue(expected_ids == predicted_ids)

assuming each is a list, your error message will just be "the assertion is not True",

whereas doing:

assert expected_ids == predicted_ids

gives you a much better error message. Ok to change for consistency, but I really think that just doing assert ... is better.

Collaborator reply:

Yeah, but self.assertListEqual() should give you more info, no?
I don't mind, but I just want consistency!
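
For reference, a small self-contained sketch (hypothetical test, not from this PR) that can be run to compare the failure output of the two styles:

import unittest

class AssertionStyleExample(unittest.TestCase):
    # Hypothetical test comparing the failure output of the two assertion styles.

    def test_plain_assert(self):
        # Under pytest's assertion rewriting this prints a full diff of the lists;
        # under the plain unittest runner it only raises a bare AssertionError.
        expected_ids, predicted_ids = [0, 1, 2], [0, 1, 3]
        assert expected_ids == predicted_ids

    def test_assert_list_equal(self):
        # assertListEqual reports the first differing element regardless of runner,
        # unlike assertTrue(expected_ids == predicted_ids), which only reports
        # "False is not true".
        expected_ids, predicted_ids = [0, 1, 2], [0, 1, 3]
        self.assertListEqual(expected_ids, predicted_ids)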

patrickvonplaten (PR author)

Thanks for the fast reviews everyone! I ran the slow tests

CUDA_VISIBLE_DEVICES="0" RUN_SLOW=1 pytest ...

for:

  • Mistral
  • Falcon
  • Llama
  • Whisper
  • Bart
  • LED (the exception mentioned above)
  • Bloom (uses a boolean mask)

Think this should be good enough!

DavidAkinpelu commented Oct 31, 2023

@patrickvonplaten I know this PR has been merged, but I have a question regarding the PyTorch version of BART. The PyTorch implementation assumes the decoder is always used in an autoregressive manner, unlike the Flax version. There could be cases of the decoder being used as an "encoder" or for "cross attention" only; in those cases the autoregressive behavior is not required. I think the default should stay autoregressive, but if is_decoder is set to False, the non-causal masking operation should be performed instead.

patrickvonplaten (PR author)

> non-causal masking operation should be performed instead.

@DavidAkinpelu I think you linked the FlaxAttention class, not the FlaxDecoder class, above. In PT the Attention class can also be used in a non-causal mode, just like in Flax. If you want to use BART in a non-autoregressive mode, why don't you use BartEncoder?
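
For illustration, a minimal sketch (the checkpoint name is only an example) of running just the encoder stack of a BART model, which applies no causal mask:

import torch
from transformers import AutoTokenizer, BartModel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")

inputs = tokenizer("A non-autoregressive use case.", return_tensors="pt")
encoder = model.get_encoder()  # the BartEncoder submodule
with torch.no_grad():
    encoder_outputs = encoder(**inputs)
print(encoder_outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)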

DavidAkinpelu commented Nov 12, 2023

@patrickvonplaten This paper, Mores+, got me thinking in that direction.

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
[Attention Mask] Refactor all encoder-decoder attention mask (huggingface#27086)

* [FA2 Bart] Add FA2 to all Bart-like

* better

* Refactor attention mask

* remove all customized attention logic

* format

* mass rename

* replace _expand_mask

* replace _expand_mask

* mass rename

* add pt files

* mass replace & rename

* mass replace & rename

* mass replace & rename

* mass replace & rename

* Update src/transformers/models/idefics/modeling_idefics.py

* fix more

* clean more

* fix more

* make style

* fix again

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* small fix mistral

* finish

* finish

* finish

* finish

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>