
[SFTTrainer] Introducing DataCollatorForCompletionOnlyLM #445

Merged: 6 commits into main on Jun 20, 2023

Conversation

younesbelkada (Contributor) commented Jun 16, 2023

What does this PR do?

Fixes: #426

This PR introduces the DataCollatorForCompletionOnlyLM data collator, which masks out all prompt tokens that come before the completion, similarly to what is done here: https://github.com/databrickslabs/dolly/blob/master/training/trainer.py#L48-L77

The goal of this data collator is to find where the response template tokens are located in the sequence and mask out all tokens before them, so that the loss is computed only on the completion.
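For intuition, here is a minimal sketch of the effect on the labels (toy token ids; illustrative only, not the collator's actual implementation):

import torch

# Toy example: [prompt ids] + [response template ids] + [completion ids]
input_ids = torch.tensor([11246, 4738, 2420, 198, 44386, 18261, 25, 198, 7, 8, 9])
response_end = 8  # index just past the response template tokens (toy value)

labels = input_ids.clone()
labels[:response_end] = -100  # -100 is ignored by PyTorch's cross-entropy loss
print(labels)
# tensor([-100, -100, -100, -100, -100, -100, -100, -100, 7, 8, 9])
# Loss is now computed only on the completion tokens [7, 8, 9].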

Currently the API looks as follows:

Handy reproducible snippet
from datasets import load_dataset
from trl import SFTTrainer
from trl.trainer import DataCollatorForCompletionOnlyLM
import transformers

dataset = load_dataset("tatsu-lab/alpaca", split="train")

model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer.pad_token = tokenizer.eos_token

def formatting_prompts_func(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i]
        input_text = examples["input"][i]
        response = examples["output"][i]

        if len(input_text) >= 2:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
            
            ### Instruction:
            {instruction}
            
            ### Input:
            {input_text}
            
            ### Response:
            {response}
            '''
        else:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
            
            ### Instruction:
            {instruction}
            
            ### Response:
            {response}
            '''
        output_text.append(text)

    return output_text

response_template = "### Response:\n"
data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False)

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=data_collator,
    max_seq_length=1024,
)

trainer.train()

Currently, for some reason the data collator cannot find the response template tokens, because of the issue I describe in the snippet below:

import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer.pad_token = tokenizer.eos_token

print(tokenizer("### Response:"))
>>> {'input_ids': [21017, 18261, 25], 'attention_mask': [1, 1, 1]}
print(tokenizer("some random text\n ### Response:"))
>>> {'input_ids': [11246, 4738, 2420, 198, 44386, 18261, 25], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

As you can see, the first token of ### Response: (21017) is replaced by 44386.

EDIT: the issue appeared to be quite straightforward: one needs to change the response template to " ### Response:" (with a leading space) instead of "### Response:", since the tokenizer encodes " ###" differently from "###".
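To make the fix concrete (same GPT-Neo tokenizer as in the snippet above; the ids follow from the outputs shown):

print(tokenizer(" ### Response:"))
>>> {'input_ids': [44386, 18261, 25], 'attention_mask': [1, 1, 1]}

With the leading space, the template tokenizes the same way it does mid-sequence, so the collator can find it.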

cc @vwxyzjn

HuggingFaceDocBuilderDev commented Jun 16, 2023

The documentation is not available anymore as the PR was closed or merged.

vwxyzjn (Contributor) commented Jun 16, 2023

Nice PR @younesbelkada!! There are two issues.

Prompt leading spaces

First, the response token issue can be resolved by removing the leading spaces from the prompts. Note that the \ in text = f'''\ is important.

        if len(input_text) >= 2:
            text = f'''\
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{response}
            '''

instead of

        if len(input_text) >= 2:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
            
            ### Instruction:
            {instruction}
            
            ### Input:
            {input_text}
            
            ### Response:
            {response}
            '''

Identifying the location of the response tokens

The reference implementation from dolly seems incorrect: given the formats that we have, its implementation would instead match ### Instruction:\n.


This is because it breaks as soon as the first token matches: '### Response:\n' is encoded as [21017, 18261, 25, 198] and '### Instruction:\n' as [21017, 46486, 25, 198], so since both start with 21017, the match fires on the instruction header first.

To resolve the issue, 864948f ensures all four tokens match.
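For illustration, full-sequence matching might look like the following sketch (hypothetical helper, not the exact code in 864948f):

def find_response_start(input_ids, template_ids):
    # Return the index just past the response template, or None if absent.
    # Matches the whole template token sequence instead of only its first
    # token, so '### Instruction:\n' (which shares the first token 21017)
    # is no longer a false positive for '### Response:\n'.
    n = len(template_ids)
    for i in range(len(input_ids) - n + 1):
        if list(input_ids[i : i + n]) == list(template_ids):
            return i + n
    return None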

I gave it a quick run, but wandb is not recording anything... https://wandb.ai/costa-huang/huggingface/runs/dcld1hg6/overview?workspace=user-costa-huang. Am I missing some configuration, like log_with="wandb"?

lvwerra (Member) commented Jun 19, 2023

Small nit: I would add it to the main init so one can import it via from trl import DataCollator....

younesbelkada marked this pull request as ready for review June 19, 2023 15:52
younesbelkada (Contributor, Author) commented Jun 19, 2023

Thanks a lot @vwxyzjn for digging deeper into that!
I made some small changes on top of your commit and added some tests to make sure the collator doesn't get broken by future commits.
I suggest we add a script that reproduces Stanford Alpaca in the examples/ folder (cf. #439) and, once we build that example, document how to use the collator in the documentation section. How does that sound?

younesbelkada requested review from vwxyzjn and lvwerra June 19, 2023 15:56
vwxyzjn (Contributor) commented Jun 20, 2023


The changes LGTM. It would be great to add some docs and potentially have a Stanford Alpaca example; if we have the bandwidth, we can probably run some tracked experiments and make the tracked metrics and HF models available.

younesbelkada (Contributor, Author) left a comment:

Thanks, makes sense! I was thinking maybe we can merge this PR and do a follow-up PR to add the Stanford Alpaca reproduction, as @Lyken17 was interested in diving into it (#439).

younesbelkada merged commit 7705daa into main Jun 20, 2023
younesbelkada deleted the add-alpaca-dc branch June 20, 2023 15:51
BramVanroy (Contributor) commented Jun 21, 2023

Hello

Is there an example somewhere of how to use this new collator? I see in the source code that it inherits from DataCollatorForLanguageModeling, but that has mlm set to True by default and an mlm probability of 0.15. So to use the data collator for completion-only training, should we initialize it like so, crucially disabling mlm?

DataCollatorForCompletionOnlyLM("### Response:\n", tokenizer, mlm=False)

EDIT: I see now that the new collator overrides the torch_call method, so MLM masking is never applied. But I don't think that is intuitive for the user, because self.mlm = True. Maybe DataCollatorForCompletionOnlyLM could also pass mlm=False to the init of super? That would make things clearer.

younesbelkada (Contributor, Author):

Yes, you are right: we should probably set the default mlm to False and leave the option to change it to True for superusers. Do you want to open a PR for that? The changes would be very minimal.
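Something like this minimal sketch, perhaps (signature inferred from the usage above, not the actual trl code):

from transformers import DataCollatorForLanguageModeling

class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    def __init__(self, response_template, *args, mlm: bool = False, **kwargs):
        # Default mlm to False so that self.mlm reflects what the collator
        # actually does (causal-LM label masking, no random token masking).
        super().__init__(*args, mlm=mlm, **kwargs)
        self.response_template = response_template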

younesbelkada (Contributor, Author):

Regarding the documentation: we're currently working on reproducing Stanford Alpaca using this collator, but for now you should just create the data collator in your main script and pass it via the data_collator argument of SFTTrainer's init.

BramVanroy (Contributor):

I can do a PR tomorrow.

Another issue that I encountered: in some cases the response template is not present because the text is too long and the tokenizer truncates it. What should happen in those cases? Maybe a preprocessing function should filter those cases out beforehand?

younesbelkada (Contributor, Author):

Yeah, I imagine we could indeed have a preprocessing function inside SFTTrainer that takes care of that. It would be really great if you could add that to the PR as well. Otherwise, happy to do it!
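As a rough sketch, such a filter could look like this (hypothetical helper, assuming a dataset with a pre-formatted "text" column; not part of trl):

def response_survives_truncation(example, tokenizer, response_template, max_seq_length=1024):
    # Keep only examples whose response template is still present after truncation.
    input_ids = tokenizer(example["text"], truncation=True, max_length=max_seq_length)["input_ids"]
    template_ids = tokenizer(response_template, add_special_tokens=False)["input_ids"]
    n = len(template_ids)
    return any(input_ids[i : i + n] == template_ids for i in range(len(input_ids) - n + 1))

# dataset = dataset.filter(lambda ex: response_survives_truncation(ex, tokenizer, " ### Response:\n"))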

MatousAc commented Oct 11, 2023

At the very top, we are warned to add a space to our response template, otherwise the tokenizer will not produce the same tokens. I found this to be slightly insufficient, as I also had to prefix my response template with <s>. Only when I included the leading separator both in the argument to the data collator AND in my training prompts (with no space before <s>) was the collator able to find the right token sequence and mask properly.

Hope this helps someone else.
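In code, that workaround might look roughly like the following (hypothetical template and model; the template string must match the training prompts verbatim):

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed Llama-style tokenizer with an <s> BOS token

# Include the leading separator in BOTH the training prompts and the
# collator argument, with no space before <s>:
response_template = "<s> ### Response:\n"  # hypothetical template
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False)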

Successfully merging this pull request may close: How to Instruction Tune with SFTTrainer? (#426)