Peft deepspeed resume (#1227)
* import deepspeed integration

* monkeypatch peft adapter with deepspeed for resume from checkpoint

* fix patch

* fix patches attempt 2

* make sure to set lora_model_dir

* skip pylint for deepspeed.utils

* pick up upstream fix in transformers

* remove monkeypatch for deepspeed/peft fix

* no need to set the lora_model_dir on resume

* unset load_in_*bit when using quant config

* guard before del

* better handling of load_in* kwargs
winglian authored Jan 31, 2024
1 parent 25e037f commit c67fb71
Showing 4 changed files with 35 additions and 26 deletions.
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers==4.37.0
+transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
src/axolotl/cli/train.py (7 changes: 4 additions & 3 deletions)
@@ -6,8 +6,9 @@
 from typing import Tuple
 
 import fire
-import transformers
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.hf_argparser import HfArgumentParser
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -27,7 +28,7 @@
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parser = HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
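
For context (not part of the diff): the CLI now imports HfArgumentParser directly rather than reaching through the transformers module attribute. A minimal sketch of the same parsing pattern, using a hypothetical DemoCliArgs dataclass in place of axolotl's TrainerCliArgs:

from dataclasses import dataclass, field

from transformers.hf_argparser import HfArgumentParser


@dataclass
class DemoCliArgs:
    """Hypothetical stand-in for axolotl's TrainerCliArgs."""

    debug: bool = field(default=False)
    inference: bool = field(default=False)


if __name__ == "__main__":
    parser = HfArgumentParser(DemoCliArgs)
    # return_remaining_strings=True hands back unknown tokens (e.g. dotted
    # config overrides) instead of raising on them.
    parsed_args, remaining = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    print(parsed_args, remaining)

Run with something like `--debug foo.bar=1` and the unrecognized token ends up in `remaining`, which mirrors how do_cli() forwards extra arguments on to the config loader.
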
src/axolotl/train.py (30 changes: 15 additions & 15 deletions)
@@ -57,6 +57,21 @@ def train(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     # Load the model and tokenizer
     msg = "loading model"
     if cfg.adapter:
@@ -79,21 +94,6 @@
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
     if cfg.unfrozen_parameters:
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
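
The block is not new logic; it is moved so the auto-resume lookup runs before the model is loaded and the chosen checkpoint is already known during model and adapter setup. A standalone sketch of that selection logic (the helper name is illustrative, not from the commit):

from pathlib import Path
from typing import Optional


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Sketch of the auto-resume lookup: the newest checkpoint-<step> dir wins."""
    candidates = [str(cp) for cp in Path(output_dir).glob("checkpoint-*")]
    if not candidates:
        return None
    # Compare the trailing step number numerically, so checkpoint-1500 beats
    # checkpoint-200 even though it sorts earlier as a string.
    return max(candidates, key=lambda path: int(path.split("-")[-1]))


# e.g. resume_from = find_latest_checkpoint("./outputs/my-run")
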
src/axolotl/utils/models.py (22 changes: 15 additions & 7 deletions)
@@ -473,6 +473,18 @@ def load_model(
             **bnb_config,
         )
 
+    if cfg.load_in_8bit and cfg.adapter is not None:
+        model_kwargs["load_in_8bit"] = True
+    if cfg.load_in_4bit and cfg.adapter is not None:
+        model_kwargs["load_in_4bit"] = True
+
+    # no longer needed per https://github.com/huggingface/transformers/pull/26610
+    if "quantization_config" in model_kwargs or cfg.gptq:
+        if "load_in_8bit" in model_kwargs:
+            del model_kwargs["load_in_8bit"]
+        if "load_in_4bit" in model_kwargs:
+            del model_kwargs["load_in_4bit"]
+
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
         if not cfg.sample_packing:
@@ -506,8 +518,6 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
 
@@ -575,8 +585,6 @@ def load_model(
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -608,8 +616,6 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
            )
@@ -678,7 +684,9 @@ def load_model(
     skip_prepare_model_for_kbit_training = False
 
     if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
-        from deepspeed.utils import set_z3_leaf_modules
+        from deepspeed.utils import (  # pylint: disable=no-name-in-module
+            set_z3_leaf_modules,
+        )
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
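
Aside, a sketch under stated assumptions rather than code from the commit: under DeepSpeed ZeRO-3, marking MixtralSparseMoeBlock as a leaf makes ZeRO gather each MoE block's parameters as one unit, which avoids stalls when only some experts receive tokens on a given rank. The import path for is_deepspeed_zero3_enabled is assumed from recent transformers releases.

from deepspeed.utils import set_z3_leaf_modules  # pylint: disable=no-name-in-module
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


def mark_moe_leaf_modules(model):
    """Sketch: tell ZeRO-3 to treat each sparse-MoE block as a single leaf."""
    if is_deepspeed_zero3_enabled():
        # Without this, ranks whose experts see no tokens can hang while
        # waiting on partitioned expert weights that are never gathered.
        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
    return model
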
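
Taken together, the models.py changes move load_in_8bit / load_in_4bit off the from_pretrained calls and into model_kwargs, then strip them whenever an explicit quantization_config (or GPTQ) is in play, since transformers after PR 26610 derives a BitsAndBytesConfig from those flags itself and the two sources would conflict. A simplified sketch of that precedence logic; the cfg-style variables and model name are stand-ins, not axolotl code:

from transformers import BitsAndBytesConfig

# Stand-ins for cfg.* values; only the precedence logic mirrors the commit.
adapter = "qlora"
load_in_8bit = False
load_in_4bit = True
use_gptq = False

model_kwargs = {}
if adapter == "qlora" and load_in_4bit:
    # QLoRA path: an explicit 4-bit quantization config goes into model_kwargs.
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

if load_in_8bit and adapter is not None:
    model_kwargs["load_in_8bit"] = True
if load_in_4bit and adapter is not None:
    model_kwargs["load_in_4bit"] = True

# Drop the bare load_in_* flags when a quantization config (or GPTQ) is
# present, so from_pretrained sees only one source of quantization settings
# (see the PR 26610 comment in the diff above).
if "quantization_config" in model_kwargs or use_gptq:
    model_kwargs.pop("load_in_8bit", None)
    model_kwargs.pop("load_in_4bit", None)

# e.g. transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-125m", **model_kwargs)

With the flags folded into model_kwargs, all three from_pretrained call sites in load_model simply forward **model_kwargs instead of repeating the load_in_* expressions.
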
