From 2f50992f6a52bb09b42a98855f5881231ceadd91 Mon Sep 17 00:00:00 2001
From: abhishekkrthakur
Date: Thu, 26 Sep 2024 14:38:42 +0200
Subject: [PATCH] dpoconfig fix

---
 configs/llm_finetuning/llama3-8b-dpo-qlora.yml |  2 +-
 src/autotrain/__init__.py                      |  2 +-
 src/autotrain/trainers/clm/train_clm_dpo.py    | 14 +++++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/configs/llm_finetuning/llama3-8b-dpo-qlora.yml b/configs/llm_finetuning/llama3-8b-dpo-qlora.yml
index de786fb6bf..640f415439 100644
--- a/configs/llm_finetuning/llama3-8b-dpo-qlora.yml
+++ b/configs/llm_finetuning/llama3-8b-dpo-qlora.yml
@@ -33,4 +33,4 @@ params:
 hub:
   username: ${HF_USERNAME}
   token: ${HF_TOKEN}
-  push_to_hub: true
\ No newline at end of file
+  push_to_hub: false
\ No newline at end of file

diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py
index a51aa1de32..512f58a797 100644
--- a/src/autotrain/__init__.py
+++ b/src/autotrain/__init__.py
@@ -41,7 +41,7 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
 
 logger = Logger().get_logger()
-__version__ = "0.8.20"
+__version__ = "0.8.21"
 
 
 def is_colab():

diff --git a/src/autotrain/trainers/clm/train_clm_dpo.py b/src/autotrain/trainers/clm/train_clm_dpo.py
index 738697c43e..2cea0a9885 100644
--- a/src/autotrain/trainers/clm/train_clm_dpo.py
+++ b/src/autotrain/trainers/clm/train_clm_dpo.py
@@ -1,8 +1,8 @@
 import torch
 from peft import LoraConfig
-from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
+from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
 from transformers.trainer_callback import PrinterCallback
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 
 from autotrain import logger
 from autotrain.trainers.clm import utils
@@ -22,7 +22,11 @@ def train(config):
     training_args = utils.configure_training_args(config, logging_steps)
     config = utils.configure_block_size(config, tokenizer)
 
-    args = TrainingArguments(**training_args)
+    training_args["max_length"] = config.block_size
+    training_args["max_prompt_length"] = config.max_prompt_length
+    training_args["max_target_length"] = config.max_completion_length
+    training_args["beta"] = config.dpo_beta
+    args = DPOConfig(**training_args)
 
     logger.info("loading model config...")
     model_config = AutoConfig.from_pretrained(
@@ -103,13 +107,9 @@ def train(config):
     trainer = DPOTrainer(
         **trainer_args,
         ref_model=model_ref,
-        beta=config.dpo_beta,
         train_dataset=train_data,
         eval_dataset=valid_data if config.valid_split is not None else None,
         tokenizer=tokenizer,
-        max_length=config.block_size,
-        max_prompt_length=config.max_prompt_length,
-        max_target_length=config.max_completion_length,
         peft_config=peft_config if config.peft else None,
     )
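
Note: the patch above moves the DPO-specific settings (max_length, max_prompt_length, max_target_length, beta) off the DPOTrainer call and onto DPOConfig, which replaces TrainingArguments. The snippet below is a minimal, hypothetical sketch of that call pattern, not code from this repository: it assumes a TRL release around 0.11 in which DPOConfig accepts these fields and DPOTrainer still takes a tokenizer argument (later releases rename it to processing_class); the tiny model and toy dataset are placeholders for illustration only.

# Hypothetical sketch of the DPOConfig-based API used after this patch.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "sshleifer/tiny-gpt2"  # placeholder model, illustration only
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers have no pad token by default

# Toy preference data in the prompt/chosen/rejected format DPOTrainer expects.
train_data = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?"],
        "chosen": ["4"],
        "rejected": ["5"],
    }
)

# Length limits and beta now live on DPOConfig instead of being passed to
# DPOTrainer directly, mirroring how the patch fills training_args above.
args = DPOConfig(
    output_dir="dpo-out",
    per_device_train_batch_size=1,
    max_steps=1,
    max_length=64,
    max_prompt_length=32,
    max_target_length=64,
    beta=0.1,
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with no PEFT config, TRL builds a frozen reference copy internally
    args=args,
    train_dataset=train_data,
    tokenizer=tokenizer,
)
trainer.train()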