Correct attention mask dtype for Flax GPT2 #25636
Conversation
Nice catch! Could you maybe add a test in `test_modeling_flax_gpt2.py` to make sure this is tested? 😉 (taking inspiration from your minimal reproducer!)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Force-pushed from 2f1e4cb to 1a9d3dd (compare)
@ArthurZucker Sure! I added a test :-)
Modelling code changes LGTM - it would be great to make the test a fast test by defining a tester function in the model tester, and then executing it in the model test.
Force-pushed from 1a9d3dd to aa34af0 (compare)
@sanchit-gandhi Good point! Done. Let me know if you have further suggestions. :-)
Force-pushed from aa34af0 to 7f966d7 (compare)
Thanks for adding a test! 🤗
cc @sanchit-gandhi feel free to merge if it's alright with you!
Force-pushed from 42b2414 to 9d97878 (compare)
@sanchit-gandhi Hey, thanks! I changed it to assertTrue.
Awesome - thanks @liutianlin0121!
No problem! Feel free to merge it (it seems that I can't). |
* Correct attention mask dtype
* Reformat code
* Add a test for boolean mask
* Convert test to fast test
* Delete unwanted print
* Use assertTrue for testing
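The commit list above centers on casting a user-supplied boolean attention mask to the 32-bit integer dtype (`"i4"`) that Flax modeling code expects. A minimal sketch of the idea with NumPy (illustrative shapes only, not the actual Transformers code):

```python
import numpy as np

# A user passes a boolean attention mask (True = attend, False = padding).
bool_mask = np.array([[True, True, True, False]])

# The fix amounts to an up-front cast to the expected integer dtype ("i4"),
# which preserves the masking pattern: True -> 1, False -> 0.
int_mask = bool_mask.astype("i4")
print(int_mask.dtype)     # int32
print(int_mask.tolist())  # [[1, 1, 1, 0]]

# During generation the mask is typically extended with ones for each newly
# generated token; concatenating integer arrays keeps a consistent dtype,
# whereas mixing a boolean mask with integer updates can fail.
extended = np.concatenate([int_mask, np.ones((1, 1), dtype="i4")], axis=-1)
print(extended.tolist())  # [[1, 1, 1, 0, 1]]
```

The cast is cheap and makes the mask dtype uniform regardless of what the caller provides.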
What does this PR do?
Fixes #25634: Problem caused by boolean attention mask in `pretrained_model.generate` of Flax GPT2.
Before submitting
- Did you read the contributor guideline, in particular the Pull Request section?
- Was this discussed/approved via a GitHub issue? Link: Problem caused by boolean attention mask in `pretrained_model.generate` of Flax GPT2 #25634
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
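The conversation mentions converting the reproducer into a fast test that uses `assertTrue`. A hypothetical sketch of such a test (stand-alone with `unittest` and NumPy; the real test lives in the Transformers model tester and runs the Flax model):

```python
import unittest
import numpy as np

class BooleanAttentionMaskTest(unittest.TestCase):
    """Hypothetical fast test in the spirit of this PR: a boolean attention
    mask and its integer-cast counterpart should mark the same positions."""

    def test_bool_mask_matches_int_mask(self):
        bool_mask = np.array([[True, True, False]])
        int_mask = np.array([[1, 1, 0]], dtype="i4")

        # assertTrue, as adopted in the final revision of the PR.
        self.assertTrue(np.array_equal(bool_mask.astype("i4"), int_mask))

# Run the test directly; in Transformers it would be collected by pytest.
result = unittest.TestResult()
BooleanAttentionMaskTest("test_bool_mask_matches_int_mask").run(result)
print(result.wasSuccessful())  # True
```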
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sanchit-gandhi
@ArthurZucker