
Feature/fsdp lora #435

Closed · wants to merge 39 commits

Conversation

@danbider (Contributor) commented Jul 7, 2023

Adds a fix to wrap trainable LoRA modules with FSDP.
Verified successful training on 8 GPUs.

Related, but not affecting this PR: @bcui19 and I discussed a small enhancement to Composer to accompany this PR, which would spare large untrained modules from being gathered by FSDP.
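For reference, a minimal sketch of what marking trainable LoRA modules for FSDP wrapping can look like, assuming Composer's convention of an `_fsdp_wrap` attribute on submodules; the helper and the name filter are illustrative only, not the exact code in this PR.

```python
# Sketch (not the PR's code): flag trainable LoRA adapter modules so that
# Composer's FSDP auto-wrap policy wraps them separately from frozen weights.
import torch.nn as nn


def mark_lora_modules_for_fsdp(model: nn.Module) -> None:
    for name, module in model.named_modules():
        # Hypothetical heuristic: LoRA adapter submodules contain 'lora_' in
        # their name and hold the model's only trainable parameters.
        has_trainable = any(p.requires_grad for p in module.parameters(recurse=False))
        if 'lora_' in name and has_trainable:
            module._fsdp_wrap = True  # ask Composer's wrapping logic to wrap this module
```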

@codestar12 (Contributor) left a comment

This looks good. I approve merging once tests pass.

@danbider requested a review from @codestar12 on July 10, 2023.
@danbider (Contributor, Author) commented:
  1. FSDP works well.
  2. Model configs are handled the same way as for every other model.
  3. The training script and hf_causal_lm scripts are now cleaner.
  4. The tutorial has been updated as well.

@palash04 commented:
Hey @danbider, with this PR can we do FSDP LoRA on LLaMA models? Basically, any model other than MPT?

@danbider (Contributor, Author) commented:
It should support any model, including MPT.

@alextrott16 (Contributor) commented:
I'm a fan of the re-design -- in particular, moving things into the hf_causal_lm builder and out of the train script. It makes everything more readable and more general. Also, I really appreciate seeing the addition to the TUTORIAL FAQs.

One area that gives me pause is integration with HF for models that have been LoRA-fied. Are there any gotchas when it comes to converting to HF (and uploading to HF) from a Composer checkpoint? Similarly, are there any gotchas when it comes to working with a HF model that is already LoRA-fied?

My sense is that, with the latter, everything should be OK as long as the right things are installed, but I'd like to get a sanity check on that.

The former seems like it will still be missing support. @dakinggg can probably add some insight here, because I'm worried that the code that actually modifies the model (in hf_causal_lm) won't be reflected in the model construction code that gets uploaded to the HF repo along with the model weights. Has anyone tested that workflow?
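On the "already LoRA-fied HF model" question, a minimal sketch of what having "the right things installed" looks like with the `peft` library; the model and adapter names below are placeholders, and this is independent of the code in this PR.

```python
# Sketch: load a HF causal LM that already has a LoRA adapter attached.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained('some-org/base-model')       # placeholder base model
model = PeftModel.from_pretrained(base, 'some-org/lora-adapter')         # placeholder adapter repo or local path
model.eval()
```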

@dakinggg (Collaborator) left a comment

Could you please also add some basic tests for the LoRA addition?
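A minimal sketch of the kind of test being asked for, assuming `peft`'s `LoraConfig` and `get_peft_model`; the tiny stand-in model and the assertions are illustrative, not the tests eventually added to the PR.

```python
# Sketch: apply a LoRA config to a small HF model and check that only the
# LoRA adapter parameters remain trainable.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model


def test_lora_only_adapters_trainable():
    model = AutoModelForCausalLM.from_pretrained('gpt2')  # small stand-in model
    lora_cfg = LoraConfig(r=2, lora_alpha=32, target_modules=['c_attn'],
                          lora_dropout=0.05, bias='none', task_type='CAUSAL_LM')
    model = get_peft_model(model, lora_cfg)

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    assert trainable, 'expected some trainable parameters'
    assert all('lora_' in n for n in trainable), 'only LoRA params should be trainable'
```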

TUTORIAL.md (outdated)
<!--pytest.mark.skip-->
```yaml
fsdp_config:
  use_orig_params: true
```
Collaborator:

Can we confirm if this is necessary?

Contributor (Author):

Good point, I will verify this tomorrow morning.

TUTORIAL.md (outdated), on the quoted line: "or default to DDP, as follows:"
Collaborator:

I think defaulting to DDP by just leaving out the FSDP section entirely is a bit cleaner?

llmfoundry/models/hf/hf_fsdp.py (outdated, resolved)
```python
    'lora',
    must_exist=False,
    default_value=None)
if lora_config is not None:
    if lora_config.get('rank', None) is not None:
```
Collaborator:

What is supposed to happen if a lora config is provided but rank is None? Should that be an error?
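If it should be an error, a hedged sketch of what the check could look like; the `rank` key follows the snippet above, and the message wording is illustrative.

```python
# Sketch: fail fast when a lora config is supplied without a rank.
if lora_config is not None and lora_config.get('rank', None) is None:
    raise ValueError(
        "A 'lora' config was provided but 'rank' is missing; "
        "set lora.rank (e.g. 8) or remove the lora section entirely.")
```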

danbider and others added 2 commits August 16, 2023 21:30
edit from daniel

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@danbider (Contributor, Author) commented:
> Could you please also add some basic tests for the LoRA addition?

Will add those now.

@samhavens (Contributor) commented Aug 18, 2023

Any chance we could include scripts/train/yamls/finetune/lora-mpt-7b.yaml? Something like:

```yaml
max_seq_len: 2048
global_seed: 17
dist_timeout: 5400

# Run Name
run_name: # If left blank, will be read from env var $RUN_NAME

model:
  name: hf_causal_lm
  pretrained: true
  pretrained_model_name_or_path: mosaicml/mpt-7b
  config_overrides:
    attn_config:
      attn_impl: triton
      attn_uses_sequence_id: false
  lora:
    # UPDATE these as needed
    args:
      r: 1
      lora_alpha: 32
      # target_modules: ["Wqkv", "out_proj", "up_proj", "down_proj"]
      target_modules: ["up_proj", "down_proj"]
      lora_dropout: 0.05
      bias: none
      task_type: "CAUSAL_LM"
# Tokenizer
tokenizer:
  name: mosaicml/mpt-7b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    hf_name: danbider/codegen
    split: train
    max_seq_len: ${max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    packing_ratio: 19.6 # concat examples
    shuffle: true
  drop_last: true
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

eval_loader:
  name: finetuning
  dataset:
    hf_name: danbider/codegen
    split: test
    max_seq_len: ${max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    packing_ratio: 19.6
    shuffle: true
  drop_last: true
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

# Optimization
# Based on MPT pretraining
scheduler:
  name: cosine_with_warmup
  t_warmup: 50ba
  alpha_f: 0.0

optimizer:
  name: decoupled_lionw
  lr: 1.0e-4  # lora needs higher LR
  betas:
  - 0.9
  - 0.95
  eps: 1.0e-8
  weight_decay: 1.0e-4

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 2ep
eval_interval: 1ep
eval_first: true
global_train_batch_size: 48

# System
seed: ${global_seed}
device_eval_batch_size: 4
device_train_microbatch_size: 1
precision: amp_bf16

# FSDP
# fsdp_config:
#   sharding_strategy: FULL_SHARD
#   mixed_precision: PURE
#   activation_checkpointing: true
#   activation_checkpointing_reentrant: false
#   activation_cpu_offload: false
#   limit_all_gathers: true
#   verbose: false

# leave out fsdp_config for DDP or single GPU

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

# loggers:
#   wandb: {}
# Checkpoint to local filesystem or remote object store
save_interval: 5000ba
save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoint
```

Any example we could point users to would be great, even if we plan on refining it later.

@germanjke commented:
Hi, I tested your branch and found some bugs:

  1. A bug in building LoRA (easy to fix)
  2. A bug in wrapping LoRA with FSDP (needs a fix)

I'm using your suggested config from TUTORIAL.md, with LLaMA-2 LoRA:

```yaml
model:
  name: hf_causal_lm
  pretrained: true
  ...
  lora:
    args:
      r: 16
      lora_alpha: 32
      target_modules: ["Wqkv", "out_proj", "up_proj", "down_proj"] # or any subset of these for MPT-7B
      lora_dropout: 0.05
      bias: none
      task_type: "CAUSAL_LM"
fsdp_config:
  use_orig_params: true
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
```

Let's start with bug №1:

You define model_config here, and everything is OK: the lora section is there.

Here you build the model, but the lora section is already gone, so lora_cfg is None here.

This happens because you pop the lora section here and never use it again, except to print it here.

The main branch has this function, and it builds the model with two separate configs: model_config and lora_cfg. You build it with a single model_config, but LoRA has been popped from it, so LoRA is never built here.

You need to either not pop it, or use the function from the main branch. After that, LoRA builds.

About bug №2:

We build LoRA here.

Later, you re-initialize your LoRA model, so we get an error here because of FSDP here.

For LLaMA 2 we can get this, but we can't get it for LLaMA 2 LoRA; because of that we get None here and raise here.

I think we need to rename this for LoRA LLaMA, or refactor it some other way.

Thanks! I hope we will soon be able to train LoRA models with FSDP 👍
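A minimal sketch of the fix described for bug №1, assuming an omegaconf-style mapping for the config; the variable names and the commented-out builder signature are illustrative, mirroring the main-branch approach rather than quoting it.

```python
# Option 1: read the lora section without removing it, so the model builder
# (which receives model_config) can still see it.
lora_cfg = model_config.get('lora', None)

# Option 2: keep popping it, but pass lora_cfg to the builder explicitly
# (hypothetical call, mirroring the two-config approach on main):
# lora_cfg = model_config.pop('lora', None)
# model = build_composer_peft_model(model_config, lora_cfg, tokenizer)
```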

@danbider (Contributor, Author) commented:
> Hi, I tested your branch and found some bugs: … (full report quoted above)

Thanks for this. Fixed the first one, we think; will take care of the second as well.

@dakinggg (Collaborator) commented:
For Jose:

Two main issues with this PR currently that I know of:
(1) I believe there was a bad merge with main. You may want to go back a commit and redo the merge.
(2) The run hangs at the start of training when using init_device: mixed + FSDP + LoRA. I don't know the root cause, but I would start by printing out the model on each rank, making sure it is wrapped the same way (and wrapped correctly), and checking that all ranks end up with non-meta weights before training starts.
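A minimal sketch of that kind of per-rank check, assuming a standard torch.distributed setup; the helper name is illustrative.

```python
# Sketch: print the wrapped model on each rank and flag any parameters that
# are still on the meta device before training starts.
import torch.distributed as dist


def debug_fsdp_wrapping(model):
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f'[rank {rank}] {model}')  # compare FSDP wrapping across ranks
    meta_params = [n for n, p in model.named_parameters() if p.device.type == 'meta']
    if meta_params:
        print(f'[rank {rank}] still-meta parameters: {meta_params}')
```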

@josejg force-pushed the feature/fsdp-lora branch from 1351637 to 5b905b0 on October 21, 2023.
@dakinggg (Collaborator) commented:
Closing in favor of #886

@dakinggg closed this on Jan 21, 2024.