[SFTTrainer] Introducing DataCollatorForCompletionOnlyLM #445
Conversation
Nice PR @younesbelkada!! There are two issues.

**Prompt leading spaces**

First, the response-token issue can be resolved by removing the leading spaces from the prompt. Note that the prompt should be built as

```python
if len(input_text) >= 2:
    text = f'''\
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{response}
'''
```

instead of

```python
if len(input_text) >= 2:
    text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction:
    {instruction}

    ### Input:
    {input_text}

    ### Response:
    {response}
    '''
```

**Identify the location of response tokens**

The reference implementation from dolly seems incorrect: it stops searching as soon as the first token of the response template matches, even if the remaining template tokens do not. To resolve the issue, 864948f ensures all four tokens match.

I gave it a quick run, but wandb is not recording anything... https://wandb.ai/costa-huang/huggingface/runs/dcld1hg6/overview?workspace=user-costa-huang. Am I missing some configuration like the …
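For readers skimming the thread, here is a minimal sketch of the "all tokens must match" idea described above. It is not the PR's actual code; `find_response_start` and `mask_prompt` are hypothetical helpers, and -100 is the usual ignore index for the cross-entropy loss:

```python
from typing import List, Optional


def find_response_start(input_ids: List[int], response_token_ids: List[int]) -> Optional[int]:
    """Return the index right after the response template, or None if it is absent.

    Unlike a first-token-only check, a position is accepted only when *all*
    template tokens match, which is the fix described above.
    """
    n = len(response_token_ids)
    for idx in range(len(input_ids) - n + 1):
        if input_ids[idx : idx + n] == response_token_ids:
            return idx + n  # the completion starts right after the template
    return None


def mask_prompt(labels: List[int], response_token_ids: List[int]) -> List[int]:
    """Replace every label before the completion with the ignore index (-100)."""
    start = find_response_start(labels, response_token_ids)
    if start is None:
        return [-100] * len(labels)  # no response found; ignore the whole example
    return [-100] * start + labels[start:]
```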
Small nit: I would add it to the main init so one can import it via `from trl import DataCollatorForCompletionOnlyLM`.
Thanks a lot @vwxyzjn for digging deeper into that!
The changes LGTM. It would be great to add some docs and potentially have a Stanford Alpaca example; if we have the bandwidth we can probably run some tracked experiments and make the tracked metrics and HF models available.
Hello! Is there an example somewhere of how to use this new collator? I see in the source code that it inherits from `DataCollatorForLanguageModeling`, so should it be used like `DataCollatorForCompletionOnlyLM("### Response:\n", tokenizer, mlm=False)`? EDIT: I see now that the new collator overrides the `torch_call` method, so mlm is never done. But I don't think that is intuitive for the user because …
Yes, you are right, we should probably set the default to `mlm=False`.
Regarding the documentation, we're currently working on reproducing Stanford Alpaca using that collator, but for now you should just create the data collator in your main script and pass it as a positional argument to the SFTTrainer's init.
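A minimal end-to-end sketch of that suggestion, assuming a toy dataset and the constructor form quoted earlier in the thread; the model name and argument names (`dataset_text_field`, keyword `tokenizer`/`mlm`) are assumptions and may differ across TRL versions:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

model_name = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_name)

# Toy dataset using the prompt format discussed above.
dataset = Dataset.from_dict(
    {"text": ["### Instruction:\nSay hi.\n\n### Response:\nHi there!"]}
)

# Create the collator in the main script...
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Response:\n",
    tokenizer=tokenizer,
    mlm=False,
)

# ...and pass it to the SFTTrainer's init.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",   # assumed field name for the raw text column
    data_collator=collator,      # mask prompts; compute loss on completions only
    max_seq_length=128,
)
trainer.train()
```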
I can do a PR tomorrow. Another issue that I encountered: in some cases you do not have the Response because the text is too long and the tokenizer truncates it. What should happen in those cases? Maybe a preprocessing function should already filter those cases out?
Yeah, I imagine we can have a preprocessing function inside …
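A sketch of what such a preprocessing step could look like, written here as a standalone `datasets.filter` call rather than anything built into the trainer; the `text` field name, the template string, and `max_seq_length` are placeholders:

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
max_seq_length = 512                               # assumed training sequence length
response_template_ids = tokenizer.encode("### Response:", add_special_tokens=False)


def contains(haystack, needle):
    """True if `needle` occurs as a contiguous subsequence of `haystack`."""
    n = len(needle)
    return any(haystack[i : i + n] == needle for i in range(len(haystack) - n + 1))


def response_survives_truncation(example):
    # Keep only examples whose response template is still present after truncation.
    ids = tokenizer.encode(example["text"], add_special_tokens=False)[:max_seq_length]
    return contains(ids, response_template_ids)


dataset = Dataset.from_dict(
    {"text": ["short prompt\n\n### Response:\nok", "a very long prompt ..."]}
)
dataset = dataset.filter(response_survives_truncation)
```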
At the very top, we are warned to add a space in our response template, otherwise the tokenizer will not produce the same tokens. I found this to be slightly insufficient, as I had to also prefix my response template with a … Hope this helps someone else.
What does this PR do?
Fixes: #426
This PR introduces `DataCollatorForCompletionOnlyLM`, a data collator that masks out all the prompt tokens that come before the completion, similarly to what is done here: https://github.com/databrickslabs/dolly/blob/master/training/trainer.py#L48-L77

The goal of this data collator is to find where the target response template token is located in the sentence, and to mask out all the tokens before the response token so that the loss is computed only on the completions.
Currently the API looks as follows:
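A minimal sketch of that API, based on how the collator is used elsewhere in this thread (a response template string, the tokenizer, and `mlm=False`); the model name and prompt are purely illustrative:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForCompletionOnlyLM(
    response_template="### Response:\n",  # the marker the collator searches for
    tokenizer=tokenizer,
    mlm=False,
)

# Collating one tokenized example: labels before the response template are set
# to -100 so the loss is computed only on the completion tokens.
example = tokenizer("### Instruction:\nSay hi.\n\n### Response:\nHi there!")
batch = collator([example])
print(batch["labels"])
```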
Handy reproducible snippet
Currently, for some reason the data collator cannot find the response token because of the issue I describe in the snippet below:
As you can see, the first token of `### Response:` (21017) is replaced by 44386.

EDIT: the issue appeared to be quite straightforward: one needs to change the response template to ` ### Response:` (with a leading space) instead of `### Response:`, since the tokenizer will encode ` ###` differently from `###`.

cc @vwxyzjn
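A small way to reproduce the encoding difference described in the EDIT, assuming a GPT-2-style BPE tokenizer for illustration (the exact ids 21017 and 44386 depend on the tokenizer used in the original snippet):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer for illustration

without_space = tokenizer.encode("### Response:", add_special_tokens=False)
with_space = tokenizer.encode(" ### Response:", add_special_tokens=False)

# "###" at the start of a string and " ###" preceded by other text map to
# different byte-pair tokens, so the template passed to the collator must be
# written exactly as it appears inside the formatted prompt.
print(without_space)
print(with_space)
```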