GPT2 should not store/compute cached activations during finetuning #2356
Conversation
Not sure which size of GPT-2 you're testing with, but gpt-2-simple uses gradient checkpointing when fine-tuning the 355M model, which is not the case for the 124M model with Transformers. That might be a useful test case.
I just tried this with gpt-2-medium on my poetry dataset and have the same memory error as before. Complete info below:
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
@thomwolf @LysandreJik: Curious about the status of this. It seems like the memory issue still exists with "run_lm_finetuning.py" and GPT-2. For instance, even a batch size of 1 doesn't prevent an OOM error when fine-tuning GPT-2 large with a sequence length of 1024 (despite using FP16). Is there anything we could do here (apart from gradient checkpointing) to lower memory usage, as Thomas listed in his first comment above? Thanks.
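For readers hitting the same OOM, here is a minimal sketch of the gradient-checkpointing workaround mentioned above. It relies on the `gradient_checkpointing_enable()` API from later `transformers` releases (this API did not exist when this thread was written), so treat it as illustrative rather than the fix proposed in this PR:

```python
# Sketch: trade compute for memory via gradient checkpointing.
# Assumes a recent transformers release with gradient_checkpointing_enable().
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2-large")
model.gradient_checkpointing_enable()  # recompute activations during backward
model.train()  # checkpointing only applies in training mode

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
inputs = tokenizer("Some fine-tuning text.", return_tensors="pt")

# Caching is incompatible with checkpointing; recent versions disable it
# automatically, but passing use_cache=False makes the intent explicit.
outputs = model(**inputs, labels=inputs["input_ids"], use_cache=False)
outputs.loss.backward()
```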
This PR tries to fix the large memory usage of GPT2 during fine-tuning.
Quick estimations
@LysandreJik compared memory usage with @minimaxir's gpt-2-simple (https://github.com/minimaxir/gpt-2-simple):
- Small model, batch size 4, sequence length 512 (roughly similar)
- Small model, batch size 4, sequence length increased to 1024
- Medium model, batch size 4, sequence length 512
Possible reason
Investigating our run_lm_finetuning script and the GPT2 model showed that we always compute/store cached hidden states (which are normally only useful for decoding). This PR attempts to fix this most probable source of large memory usage. It also cleans up the GPT2 codebase a bit.
I haven't tried it in a large-scale test yet.
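As a rough illustration of the change's intent, here is a minimal sketch using today's `transformers` API, where the `use_cache` flag (added after this PR) controls whether the per-layer key/value "past" tensors are built at all; during fine-tuning they are pure overhead:

```python
# Sketch: skip building cached key/value activations during fine-tuning.
# Assumes a transformers release that exposes the use_cache flag.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.config.use_cache = False  # never build the "past" tensors by default

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("A fine-tuning example.", return_tensors="pt")

# use_cache can also be overridden per forward call; the cache is only
# useful for incremental decoding, not for computing the LM loss.
outputs = model(**inputs, labels=inputs["input_ids"], use_cache=False)
outputs.loss.backward()
```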
cc @LysandreJik @arnicas