Feature/fsdp lora #435

Closed
danbider wants to merge 39 commits into mosaicml:main from danbider:feature/fsdp-lora
Changes from 11 commits

Commits (39)
93cdae8
attempt to fsdp wrap lora modules
danbider Jul 6, 2023
0ec0de1
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 6, 2023
20ab8b6
fsdp works by iterating over modules
danbider Jul 6, 2023
57659c5
merged remote
danbider Jul 6, 2023
d6cf053
cleaned up fsdp loop for peft
danbider Jul 7, 2023
a44b641
robust peft import
danbider Jul 7, 2023
d957d55
fsdp known issue deleted
danbider Jul 7, 2023
e5e012d
more info in tutorial about fsdp
danbider Jul 7, 2023
f7b5e70
conditioning on peft installation for cpu tests
danbider Jul 7, 2023
1cf348c
Merge branch 'main' into feature/fsdp-lora
codestar12 Jul 7, 2023
6a1c172
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 9, 2023
a3f370c
moved lora model building to ComposerHFCausalLM
danbider Jul 11, 2023
082f71e
formatting
danbider Jul 11, 2023
058951d
updated tutorial to move lora config under model config
danbider Jul 12, 2023
f57c84f
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 15, 2023
cc7a8f9
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 24, 2023
433ae51
Merge branch 'main' into feature/fsdp-lora
dakinggg Aug 1, 2023
7c68c19
merged upstream main, fixed conflicts
danbider Aug 15, 2023
4118367
added typecheck for peft model
danbider Aug 15, 2023
5db8c74
more pyright fixes
danbider Aug 15, 2023
3a3342f
more typechecking in training script
danbider Aug 15, 2023
622e51d
Merge branch 'main' into feature/fsdp-lora
danbider Aug 16, 2023
9ec0f69
pyright following main merge
danbider Aug 16, 2023
a4439c5
model_config instead of cfg.model
danbider Aug 16, 2023
0a9e542
Update TUTORIAL.md
danbider Aug 17, 2023
9bc0b50
Update llmfoundry/models/hf/hf_fsdp.py
danbider Aug 17, 2023
f2fd418
DDP tutorial edit
danbider Aug 17, 2023
050267f
edit fsdp stuff
danbider Aug 21, 2023
c0f5148
fixed popping
danbider Aug 30, 2023
1c47c23
eliminated bnb dep
danbider Aug 30, 2023
5b905b0
Merge branch 'feature/fsdp-lora' of https://github.com/danbider/llm-f…
danbider Aug 30, 2023
27d186d
Merge branch 'main' into feature/fsdp-lora
josejg Oct 21, 2023
2f59377
Update accelerate for peft
josejg Oct 23, 2023
5bc5240
Simplify LoRA validation logic
josejg Oct 23, 2023
79cf8d6
Proper import checking
josejg Oct 24, 2023
7f72c25
Fix indent
josejg Oct 30, 2023
02d949c
Prevent FSDP wrapping empty embedding LoRA attributes
josejg Oct 30, 2023
b955696
Merge branch 'main' into feature/fsdp-lora
josejg Oct 31, 2023
4a430bd
Fix bad indent
josejg Oct 31, 2023
22 changes: 20 additions & 2 deletions TUTORIAL.md
@@ -338,11 +338,29 @@ lora:
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: ['Wqkv']
target_modules: ['Wqkv', 'out_proj', 'down_proj', 'up_proj']
```
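For reference, a LoRA config like the one above corresponds roughly to the following peft calls. This is a minimal illustrative sketch, not llm-foundry's own model-building path; the model id and `task_type` are assumptions made for the example.

```python
# Illustrative sketch only: roughly what the LoRA config above amounts to when
# expressed directly against peft. The model id and task_type are assumptions
# for this example, not taken from this PR.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b', trust_remote_code=True)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=['Wqkv', 'out_proj', 'down_proj', 'up_proj'],
    task_type='CAUSAL_LM',
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights are trainable
```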
You can train LoRA models either using FSDP for further memory savings; in your `.yaml`, specify:
<!--pytest.mark.skip-->
```yaml
fsdp_config:
  use_orig_params: true
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
```

> **Collaborator** (review comment on `use_orig_params: true`): Can we confirm if this is necessary?
>
> **danbider** (author): will verify this tomorrow AM, good point
or default to DDP, as follows:
<!--pytest.mark.skip-->
```yaml
fsdp:
  {}
```

> **Collaborator** (review comment): I think to default DDP just leaving out the FSDP section entirely is a bit cleaner?
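On the earlier review question about whether `use_orig_params: true` is needed: the flag is relevant because a LoRA model mixes frozen base weights with trainable adapter weights, and FSDP's flat-parameter path generally expects uniform `requires_grad` within a wrapped unit. The snippet below is a sketch against raw PyTorch FSDP, not llm-foundry code, and assumes it is launched with `torchrun` on a machine with a CUDA GPU.

```python
# Sketch, not llm-foundry code: FSDP over a module that mixes frozen base
# weights with trainable LoRA-style adapters. Assumes a launch such as
# `torchrun --nproc_per_node=1 this_script.py` on a machine with a CUDA GPU.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group('nccl')
torch.cuda.set_device(dist.get_rank())

class FrozenPlusAdapter(nn.Module):
    """A frozen base weight plus a small trainable adapter, LoRA-style."""

    def __init__(self, d: int = 64, r: int = 8):
        super().__init__()
        self.base = nn.Linear(d, d)
        self.base.requires_grad_(False)            # frozen pretrained weight
        self.lora_A = nn.Linear(d, r, bias=False)  # trainable adapter
        self.lora_B = nn.Linear(r, d, bias=False)  # trainable adapter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_B(self.lora_A(x))

# FSDP's default flat-parameter path generally expects every parameter in a
# wrapped unit to share the same requires_grad; keeping the original
# parameters (use_orig_params=True) accommodates the frozen/trainable mix.
model = FSDP(FrozenPlusAdapter().cuda(), use_orig_params=True)
out = model(torch.randn(2, 64, device='cuda'))
print(out.shape)
dist.destroy_process_group()
```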

- In the current release, these features have Beta support.
- For efficiency, the MPT model concatenates the `Q`, `K`, and `V` matrices in each attention block into a single `Wqkv` matrix that is three times wider. Currently, LoRA supports a low-rank approximation to this `Wqkv` matrix.
- Known issue: PEFT / LoRA do not directly work with FSDP.

### Can I quantize these models and/or run on CPU?
- The LLM Foundry codebase does not directly have examples of quantization or limited-resource inference. But you can check out [GGML](https://github.com/ggerganov/ggml) (same library that powers llama.cpp), which has built support for efficiently running MPT models on CPU! You _can_ load your model in 8-bit precision for inference using the [bitsandbytes library](https://github.com/TimDettmers/bitsandbytes) and Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index) via `model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto", trust_remote_code=True)`, although we have not extensively benchmarked the performance (see the Hugging Face [quantization documentation](https://huggingface.co/docs/transformers/main/main_classes/quantization) for more detail).
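Written out, the 8-bit load described above looks like the following sketch. It assumes `bitsandbytes` and `accelerate` are installed and uses `mosaicml/mpt-7b` purely as an example checkpoint.

```python
# The 8-bit inference load described in the paragraph above; requires the
# bitsandbytes and accelerate packages. 'mosaicml/mpt-7b' is just an example id.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'mosaicml/mpt-7b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,        # bitsandbytes int8 weights
    device_map='auto',        # accelerate places layers across available devices
    trust_remote_code=True,   # needed for MPT's custom modeling code
)

inputs = tokenizer('MosaicML is', return_tensors='pt').to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```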
18 changes: 18 additions & 0 deletions llmfoundry/models/hf/hf_fsdp.py
@@ -5,12 +5,20 @@
# which is MIT licensed

import functools
import warnings
from typing import Any, Iterable, List

import torch
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder

try:
    from peft import LoraModel
    lora_model_type = LoraModel
except ImportError:
    lora_model_type = None
    warnings.warn('peft is not installed, LoraModel will not be available')


# helper functions
def rhasattr(obj: Any, attr: str):
@@ -182,6 +190,16 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
        tied_embeddings._fsdp_wrap = False  # type: ignore
        lm_head._fsdp_wrap = False  # type: ignore

    # applying ._fsdp_wrap = True for the LoRA modules
    # this is needed because added LoRA modules have requires_grad=True,
    # while the rest of the modules have requires_grad=False
    if lora_model_type is not None:  # peft is installed
        if isinstance(model.base_model,
                      lora_model_type):  # we have built a LoraModel
            for name, module in model_block.named_modules():
                if 'lora' in name:  # peft adds modules named with lora
                    module._fsdp_wrap = True

    # FSDP Wrap and Activation Checkpoint every model block
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, block_type)
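To make the intent of the `_fsdp_wrap = True` flags above concrete, here is a rough sketch of how a wrap decision could take that flag into account alongside the block-type rule from `fsdp_wrap_fn`. The real decision is made inside Composer's FSDP integration; the helper below is purely illustrative.

```python
# Illustrative only: how a wrap decision could combine the per-module
# `_fsdp_wrap` flag set above with the block-type rule from fsdp_wrap_fn.
# The real logic lives in Composer's FSDP integration, not in this function.
import torch.nn as nn

def should_fsdp_wrap(module: nn.Module, block_type: type) -> bool:
    if getattr(module, '_fsdp_wrap', False):
        # Explicitly flagged modules (e.g. LoRA adapters, which are the only
        # ones with requires_grad=True) get their own FSDP unit.
        return True
    # Otherwise fall back to wrapping every transformer block.
    return isinstance(module, block_type)
```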