dpoconfig fix
abhishekkrthakur committed Sep 26, 2024
1 parent b810afd commit 2f50992
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion configs/llm_finetuning/llama3-8b-dpo-qlora.yml
@@ -33,4 +33,4 @@ params:
 hub:
   username: ${HF_USERNAME}
   token: ${HF_TOKEN}
-  push_to_hub: true
+  push_to_hub: false
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
@@ -41,7 +41,7 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
 
 logger = Logger().get_logger()
-__version__ = "0.8.20"
+__version__ = "0.8.21"
 
 
 def is_colab():
14 changes: 7 additions & 7 deletions src/autotrain/trainers/clm/train_clm_dpo.py
@@ -1,8 +1,8 @@
 import torch
 from peft import LoraConfig
-from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
+from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
 from transformers.trainer_callback import PrinterCallback
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 
 from autotrain import logger
 from autotrain.trainers.clm import utils
@@ -22,7 +22,11 @@ def train(config):
     training_args = utils.configure_training_args(config, logging_steps)
     config = utils.configure_block_size(config, tokenizer)
 
-    args = TrainingArguments(**training_args)
+    training_args["max_length"] = config.block_size
+    training_args["max_prompt_length"] = config.max_prompt_length
+    training_args["max_target_length"] = config.max_completion_length
+    training_args["beta"] = config.dpo_beta
+    args = DPOConfig(**training_args)
 
     logger.info("loading model config...")
     model_config = AutoConfig.from_pretrained(
@@ -103,13 +107,9 @@ def train(config):
     trainer = DPOTrainer(
         **trainer_args,
         ref_model=model_ref,
-        beta=config.dpo_beta,
         train_dataset=train_data,
         eval_dataset=valid_data if config.valid_split is not None else None,
         tokenizer=tokenizer,
-        max_length=config.block_size,
-        max_prompt_length=config.max_prompt_length,
-        max_target_length=config.max_completion_length,
         peft_config=peft_config if config.peft else None,
     )
 
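For context, the change above tracks an API shift in TRL: DPO-specific hyperparameters (beta and the sequence-length limits) are no longer passed as DPOTrainer keyword arguments and instead live on DPOConfig, which subclasses transformers.TrainingArguments. Below is a minimal sketch of the new pattern; it assumes a TRL release that ships DPOConfig, and the model and dataset names are illustrative placeholders, not values from this repository.

```python
# Minimal sketch of the DPOConfig pattern this commit adopts (assumes a TRL
# release that provides DPOConfig). Model/dataset names are illustrative only.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_data = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# DPO hyperparameters that were previously DPOTrainer kwargs now go here.
args = DPOConfig(
    output_dir="dpo-out",
    per_device_train_batch_size=1,
    beta=0.1,                # was DPOTrainer(beta=...)
    max_length=1024,         # was DPOTrainer(max_length=...)
    max_prompt_length=512,   # was DPOTrainer(max_prompt_length=...)
    max_target_length=1024,  # was DPOTrainer(max_target_length=...)
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # with None, TRL creates the reference model internally
    args=args,
    train_dataset=train_data,
    tokenizer=tokenizer,
)
trainer.train()
```

Because DPOConfig extends TrainingArguments, the trainer's existing training_args dict can be unpacked into it unchanged, which is why the diff only adds the four DPO-specific keys before constructing the config.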
