
Add support for gradient checkpointing for LLM fine-tuning #3613

Merged: 15 commits merged into master on Sep 15, 2023
Conversation

@arnavgarg1 (Contributor) commented Sep 15, 2023

This PR adds support in the finetune trainer to optionally enable gradient_checkpointing.

What is gradient checkpointing?

Gradient checkpointing works by recomputing the model's activations during the backward pass rather than storing them in memory during the forward pass. This trades compute for memory: the backward pass becomes more expensive because activations must be recomputed, but the memory footprint is reduced. The option is set to false by default because gradient checkpointing is not always beneficial and can sometimes slow down training.
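To make the trade-off concrete, here is a minimal, self-contained PyTorch sketch (not code from this PR; the toy layer stack and sizes are made up, and it assumes a recent PyTorch that accepts use_reentrant) contrasting a normal forward pass with a checkpointed one via torch.utils.checkpoint.checkpoint_sequential:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack of blocks standing in for transformer layers (illustrative only).
blocks = [torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()) for _ in range(8)]
model = torch.nn.Sequential(*blocks)
x = torch.randn(4, 1024, requires_grad=True)

# Normal forward pass: every intermediate activation is kept in memory for backward.
loss = model(x).sum()
loss.backward()

# Checkpointed forward pass: only activations at the 4 segment boundaries are stored;
# activations inside each segment are recomputed during backward, trading compute for memory.
model.zero_grad()
loss = checkpoint_sequential(model, 4, x, use_reentrant=False).sum()
loss.backward()
```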

How can you use gradient checkpointing in the config?

Gradient checkpointing is disabled by default. To enable it, set enable_gradient_checkpointing to true in the trainer section:

trainer:
    enable_gradient_checkpointing: true
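The same flag can also be passed when driving Ludwig from Python. Below is a hedged sketch: the base_model, feature names, and dataset path are placeholders, and the surrounding config keys reflect a typical LLM fine-tuning config rather than anything prescribed by this PR.

```python
# Illustrative sketch only: enabling gradient checkpointing through the Python API.
# The base model, feature names, and dataset path below are placeholders.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "meta-llama/Llama-2-7b-hf",          # placeholder base model
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "adapter": {"type": "lora"},
    "trainer": {
        "type": "finetune",
        "enable_gradient_checkpointing": True,          # the option added in this PR
    },
}

model = LudwigModel(config=config)
model.train(dataset="my_finetuning_data.csv")           # placeholder dataset
```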

When should I enable gradient checkpointing?

This is useful when training very large models that quickly run into out-of-memory errors during training. It is particularly helpful when doing non-quantization-based training (adapter-based or full fine-tuning).

@github-actions commented

Unit Test Results

6 files ±0, 6 suites ±0, 39m 8s ⏱️ +45s
31 tests ±0: 26 passed ✔️ ±0, 5 skipped 💤 ±0, 0 failed ±0
82 runs ±0: 66 passed ✔️ ±0, 16 skipped 💤 ±0, 0 failed ±0

Results for commit 1b3fe04. ± Comparison against base commit 6198693.

@justinxzhao (Contributor) left a comment

Nice description! Do you have a sense for which LLMs support gradient checkpointing and which ones don't?

ludwig/trainers/trainer.py: 2 resolved review threads
@arnavgarg1 (Contributor, Author) replied:

> Nice description! Do you have a sense for which LLMs support gradient checkpointing and which ones don't?

I think almost all models coming from the transformers package support it, but I have the additional check here just in case there is an edge case that hasn't been handled.
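For context, the transformers library exposes a per-model capability flag alongside the enable method, so a guard along these lines is possible (a sketch of the general pattern, not the exact check in ludwig/trainers/trainer.py; the model name is just a small example):

```python
# General pattern for guarding gradient checkpointing on a Hugging Face model;
# not the exact code from ludwig/trainers/trainer.py.
import logging
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # small example model

if getattr(model, "supports_gradient_checkpointing", False):
    # Recompute activations in backward instead of caching them in forward.
    model.gradient_checkpointing_enable()
    logger.info("Gradient checkpointing enabled for %s.", model.config.model_type)
else:
    logger.warning("Model does not support gradient checkpointing; continuing without it.")
```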

@arnavgarg1 merged commit 3edfb84 into master on Sep 15, 2023. 17 checks passed.
@arnavgarg1 deleted the lora_grad branch on September 15, 2023 at 16:56.