Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct attention mask dtype for Flax GPT2 #25636

Merged
merged 6 commits into from
Aug 25, 2023

Conversation

liutianlin0121
Copy link
Contributor

@liutianlin0121 liutianlin0121 commented Aug 21, 2023

What does this PR do?

Fixes #25634

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi
@ArthurZucker

@liutianlin0121 liutianlin0121 changed the title Correct attention mask dtype Correct attention mask dtype for Flax GPT2 Aug 21, 2023
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! Could you maybe add a test in the test_modelling_flax_gpt2.py to make sure this is tested? 😉 (taking inspiration from your minimal reproducer!)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@liutianlin0121
Copy link
Contributor Author

@ArthurZucker Sure! I added a test :-)

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modelling code changes LGTM - would be great to make the test a fast test by defining a tester function in the model tester, and then executing it in the model test

tests/models/gpt2/test_modeling_flax_gpt2.py Outdated Show resolved Hide resolved
@liutianlin0121
Copy link
Contributor Author

liutianlin0121 commented Aug 22, 2023

would be great to make the test a fast test by defining a tester function in the model tester, and then executing it in the model test

@sanchit-gandhi Good point! Done. Let me know if you have further suggestions. :-)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a test! 🤗

@ArthurZucker
Copy link
Collaborator

cc @sanchit-gandhi feel free to merge if it's alright with you!

@liutianlin0121
Copy link
Contributor Author

@sanchit-gandhi Hey thanks! I change to assertTrue.

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome - thanks @liutianlin0121!

@liutianlin0121
Copy link
Contributor Author

No problem! Feel free to merge it (it seems that I can't).

@ArthurZucker ArthurZucker merged commit 0040469 into huggingface:main Aug 25, 2023
@liutianlin0121 liutianlin0121 deleted the attention_mask branch August 25, 2023 15:44
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* Correct attention mask dtype

* reformat code

* add a test for boolean mask

* convert test to fast test

* delete unwanted print

* use assertTrue for testing
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* Correct attention mask dtype

* reformat code

* add a test for boolean mask

* convert test to fast test

* delete unwanted print

* use assertTrue for testing
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
* Correct attention mask dtype

* reformat code

* add a test for boolean mask

* convert test to fast test

* delete unwanted print

* use assertTrue for testing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Problem caused by boolean attention mask in pretrained_model.generate of Flax GPT2
4 participants