Conflict between the latest version of transformers.Trainer and DPOTrainer.get_batch_samples #2275

Open
2 of 4 tasks
lucasdegeorge opened this issue Oct 24, 2024 · 6 comments
Labels
🐛 bug Something isn't working 🏋 DPO Related to DPO

Comments

@lucasdegeorge

System Info

Python version: 3.11.0
PyTorch version: 2.4.1 or 2.5.0
Transformers version: 4.46.0
TRL version: 0.11.4
PEFT version: 0.13.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Code to reproduce the error (with transformers==4.46.0 and trl==0.11.4):

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

data_dpo = load_dataset("CultriX/llama70B-dpo-dataset")
def preprocess_data_dpo(data_point):
    # Map each row to the prompt/chosen/rejected format expected by DPOTrainer.
    return {
        "prompt": data_point["question"],
        "chosen": data_point["chosen"],
        "rejected": data_point["rejected"],
    }
data_dpo = data_dpo['train'].shuffle(seed=42).map(preprocess_data_dpo)

training_args = DPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=True,
    save_total_limit=3,
    logging_steps=1,
    output_dir="output_dir",
    max_steps=200,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    report_to="tensorboard",
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=data_dpo,
    tokenizer=tokenizer,
)
trainer.train()

This error is raised:

Traceback (most recent call last):
  File "dpo_lora.py", line 70, in <module>
    trainer.train() 
    ^^^^^^^^^^^^^^^
  File ".venv/env/lib/python3.11/site-packages/transformers/trainer.py", line 2122, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File ".venv/env/lib/python3.11/site-packages/transformers/trainer.py", line 2426, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/env/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 1508, in get_batch_samples
    policy_output = model.generate(
                    ^^^^^^^^^^^^^^
AttributeError: 'generator' object has no attribute 'generate'

Expected behavior

The failure seems to be caused by commit 6ba31a8 of Oct 17, 2024, "Enable users to use their own loss functions + deal with prefetching for grad accum" (huggingface/transformers#34198), which landed in Transformers 4.46.0. It defines a new method, get_batch_samples, and calls it from _inner_training_loop.

But the subclass DPOTrainer already defines a get_batch_samples method with a different signature (and return value), so the new call site invokes the TRL method with the wrong arguments.

The error can be avoided by pinning transformers==4.45.2 with trl==0.11.4.
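
For illustration, a minimal sketch of the clash, reconstructed from the traceback above (bodies elided; this is how the two call sites line up, not verbatim library code):

# transformers 4.46.0, inside Trainer._inner_training_loop:
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)

# trl 0.11.4, where DPOTrainer overrides the same name with an unrelated signature:
def get_batch_samples(self, model, batch):
    policy_output = model.generate(...)  # `model` is now bound to `epoch_iterator`,
    ...                                  # a generator, hence the AttributeError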

@qgallouedec (Member)

The issue has been fixed by #2246.
TRL 0.11.4 is not compatible with Transformers 4.46.
We will release TRL 0.12 very soon.

@qgallouedec qgallouedec added 🐛 bug Something isn't working 🏋 DPO Related to DPO labels Oct 25, 2024
@swamymushini

What is the working fix for this issue now? Which library versions can we use as a temporary solution? Should we downgrade Transformers?

@bibhudutta-p

Yes, use the latest version of TRL and v4.45.2 of Transformers. This fixed it for me.
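
As a side note, a quick sanity check that the pinned versions are the ones actually imported (a hypothetical snippet, using the versions reported in this thread):

import transformers
import trl

# Versions reported to work together in this thread.
assert transformers.__version__ == "4.45.2", transformers.__version__
assert trl.__version__ == "0.11.4", trl.__version__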

@swamymushini

> Yes, use the latest version of TRL and v4.45.2 of Transformers. This fixed it for me.

You mean TRL 0.11.4?

@bibhudutta-p

yes

@swamymushini

> yes

Thanks a lot, it worked for me.
