Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add q-galore optimizer #1752

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def parse_requirements():
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"q-galore-torch==1.0",
"torch-optimi==0.2.1",
],
},
Expand Down
11 changes: 9 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def __init__(
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
and self.args.alternate_optimizer
not in ["optimi_adamw", "q_galore_adamw8bit"]
):
return super().create_optimizer()

Expand Down Expand Up @@ -344,6 +345,12 @@ def create_optimizer(self):
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "q_galore_adamw8bit":
from q_galore_torch import QGaLoreAdamW8bit

self.optimizer = ( # pylint: disable=attribute-defined-outside-init
QGaLoreAdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -1436,7 +1443,7 @@ def build(self, total_num_steps):

trainer_kwargs = {}

if self.cfg.optimizer == "optimi_adamw":
if self.cfg.optimizer in ["optimi_adamw", "q_galore_adamw8bit"]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
Expand Down
5 changes: 4 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
Union[
OptimizerNames,
Literal["lion_pytorch", "optimi_adamw", "q_galore_adamw8bit"],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
Expand Down
42 changes: 42 additions & 0 deletions tests/e2e/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,45 @@ def test_optimi_adamw(self, temp_dir):

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_q_galore_adamw8bit(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "q_galore_adamw8bit",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
Loading