Enable users to use their own loss functions + deal with prefetching for grad accum #34198
Conversation
LGTM, IMO a regression test on the grad norms could be fairly nice!
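A rough idea of what such a regression test could look like (a sketch only; make_tiny_model and tiny_dataset are assumed helpers, the tolerance is arbitrary, and this is not the test that was merged): train the same tiny model with and without gradient accumulation, keeping the effective batch size fixed, and compare the logged grad norms.

from transformers import Trainer, TrainingArguments, set_seed

def train_and_get_grad_norms(grad_accum, batch_size):
    set_seed(42)  # identical init and data order for both runs
    args = TrainingArguments(
        output_dir="tmp_grad_accum_test",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        max_steps=10,
        logging_steps=1,
        report_to="none",
    )
    # make_tiny_model() and tiny_dataset are placeholders for a small test model and dataset
    trainer = Trainer(model=make_tiny_model(), args=args, train_dataset=tiny_dataset)
    trainer.train()
    return [log["grad_norm"] for log in trainer.state.log_history if "grad_norm" in log]

# 8 samples per optimizer step either way, so the grad norms should roughly match
norms_full = train_and_get_grad_norms(grad_accum=1, batch_size=8)
norms_accum = train_and_get_grad_norms(grad_accum=4, batch_size=2)
for full, accum in zip(norms_full, norms_accum):
    assert abs(full - accum) < 1e-2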
src/transformers/trainer.py
Outdated
self.state.num_input_tokens_seen += (
    torch.sum(
        self.accelerator.gather(
            torch.tensor(
                inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
            )
        )
    )
    .cpu()
    .item()
)
let's make this more readable!
clean
did this one 🫠
you can split in 3-4 lines 🎐
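One possible way to split it up along the lines the reviewers suggest (a sketch only, not necessarily the exact code that was merged):

input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().cpu().item()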
src/transformers/trainer.py
Outdated
if (self.label_smoother is not None or self.compute_loss is not None) and "labels" in inputs:
    labels = inputs.pop("labels")
mmmm if people don't pass a loss, we won't use the model's default?
We will, it stays in inputs and gets passed to the model's forward().
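In other words, the control flow under discussion looks roughly like this (a simplified sketch; compute_loss_fn stands in for the custom-loss attribute, whose name was still being settled in this PR):

if (self.label_smoother is not None or self.compute_loss_fn is not None) and "labels" in inputs:
    # a custom loss (or label smoothing) will be computed outside the model
    labels = inputs.pop("labels")
else:
    # "labels" stays in inputs, so the model's forward() computes its default loss
    labels = None

outputs = model(**inputs)

if labels is not None:
    if self.compute_loss_fn is not None:
        loss = self.compute_loss_fn(outputs, labels, num_items_in_batch=num_items_in_batch)
    else:
        loss = self.label_smoother(outputs, labels)
else:
    loss = outputs.loss  # the model's default loss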
src/transformers/trainer.py
Outdated
# For now we don't support object detection
try:
    num_items_in_batch = sum(
        [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
    )
I already quickly discussed this with Zach, so this is a more general question to the other reviewers:
Would this line work for all the different task types we support? Specifically, can we always skip the first item in the sequence, i.e. is the [..., 1:] part valid?
For causal autoregressive models it works, but it won't work for other ones.
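For context: in causal LMs the loss compares logits[..., :-1, :] against labels[..., 1:] (the model shifts internally), so the first label position never contributes to the loss, which is what the [..., 1:] accounts for. A toy illustration (made-up values):

import torch

labels = torch.tensor([[101, 7592, 2088, -100, -100]])  # one sequence; -100 marks ignored/padding positions
num_items = labels[..., 1:].ne(-100).sum().item()       # 2 supervised positions for this sequence
print(num_items)  # 2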
src/transformers/trainer.py
Outdated
self.state.num_input_tokens_seen += (
    torch.sum(
        self.accelerator.gather(
            torch.tensor(
                inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
            )
        )
    )
    .cpu()
    .item()
)
you can split in 3-4 lines 🎐
Just a denominator change in the test case
Feel free to merge!
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…for grad accum (huggingface#34198)

* bookmark
* Bookmark
* Bookmark
* Actually implement
* Pass in kwarg explicitly
* Adjust for if we do or don't have labels
* Bookmark fix for od
* bookmark
* Fin
* closer
* Negate accelerate grad accum div
* Fixup not training long enough
* Add in compute_loss to take full model output
* Document
* compute_loss -> compute_loss_fn
* Add a test
* Refactor
* Refactor
* Uncomment tests
* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
What does this PR do?
In conjunction with #34191, this PR solves the other half of what's needed:

* Allowing users to pass in their own compute_loss function to the Trainer
* Prefetching gradient_accumulation_steps worth of data each complete step and marking how many samples were seen (num_items_in_batch), which can be passed to a loss function if it takes in num_items_seen (name TBD)

A bit of feedback is needed and we need to coordinate: should this be called num_items_in_batch and then be passed through to the loss functions as such? Or is there a better name we can think of?

Fixes huggingface/trl#2175
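A rough sketch of the prefetch-and-count mechanism described above (variable names follow the diff snippets in this conversation; this is an illustration, not the merged implementation):

batch_samples = []
for _ in range(self.args.gradient_accumulation_steps):
    try:
        batch_samples.append(next(epoch_iterator))
    except StopIteration:
        break

# Token counting assumes causal-LM style labels; tasks without flat "labels"
# tensors (e.g. object detection) fall back to None
try:
    num_items_in_batch = sum(
        batch["labels"][..., 1:].ne(-100).sum().item() for batch in batch_samples
    )
except (TypeError, AttributeError, KeyError):
    num_items_in_batch = None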
Before submitting

* Did you read the contributor guideline, Pull Request section?
* Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
* Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik @ArthurZucker