Implements SPPO Alignment Algorithm #1735
base: main
Hi @kaykyr, thanks for submitting this technique. I'd love to see this integrated into axolotl, but my main concern is the amount of duplicated code we're going to have to maintain. I'm happy to help refactor the pieces in the trainer_builder, but I think it would be ideal if we could extract the necessary SPPO changes from DPOTrainer so we have a smaller footprint to maintain.
Re: tests, it would be good to have some tests to spot-check the functionality. I'm happy to help with this as well; we could set up some e2e tests that run a small model for about 10-20 steps to verify that the trainer works.
This should probably go in src/axolotl/core/trainers
    def create_optimizer(self):
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
                opt_model,
            )

            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
            if loraplus_lr_ratio:
                print("Using lora+")
            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                opt_model,
                optimizer_cls,
                optimizer_kwargs,
                loraplus_lr_ratio,
                loraplus_lr_embedding,
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer
I wonder if it might be worth extracting this as an AxolotlCreateOptimizerMixin and then including it both here and in the AxolotlTrainer.
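A minimal sketch of what such a mixin could look like. It uses stand-in classes so it runs without transformers installed: `BaseTrainer`, `CreateOptimizerMixin`, `PlainTrainer`, `PreferenceTrainer`, and the dictionary "optimizers" are all hypothetical placeholders, not axolotl or transformers APIs. The point is only that the LoRA+ branch lives in one place and falls through to the parent implementation otherwise.

```python
from types import SimpleNamespace


class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer."""

    def __init__(self, args):
        self.args = args
        self.optimizer = None

    def create_optimizer(self):
        # Default path: a plain optimizer with no LoRA+ handling.
        if self.optimizer is None:
            self.optimizer = {"kind": "default"}
        return self.optimizer


class CreateOptimizerMixin:
    """Holds the shared LoRA+ branch once, instead of once per trainer."""

    def create_optimizer(self):
        ratio = getattr(self.args, "loraplus_lr_ratio", None)
        if ratio is None:
            # Fall through to whatever the next class in the MRO does.
            return super().create_optimizer()
        if self.optimizer is None:
            # Placeholder for the real LoRA+ optimizer construction.
            self.optimizer = {"kind": "loraplus", "lr_ratio": ratio}
        return self.optimizer


class PlainTrainer(CreateOptimizerMixin, BaseTrainer):
    """Stand-in for a general-purpose trainer."""


class PreferenceTrainer(CreateOptimizerMixin, BaseTrainer):
    """Stand-in for a DPO/SPPO-style trainer."""


plain = PlainTrainer(SimpleNamespace(loraplus_lr_ratio=None))
pref = PreferenceTrainer(SimpleNamespace(loraplus_lr_ratio=16.0))
default_opt = plain.create_optimizer()
loraplus_opt = pref.create_optimizer()
```

The mixin has to come before the trainer base class in the list of bases so that its `create_optimizer` is found first and `super()` reaches the base implementation.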
        return self.optimizer

    @wraps(DPOTrainer.push_to_hub)
Is DPOTrainer correct for this?
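For context on this question: `functools.wraps` only copies metadata (name, docstring, `__wrapped__`) from the function it is given. It does not change which parent method `super()` resolves to. A small sketch with hypothetical classes `Parent`, `Other`, and `Child`:

```python
from functools import wraps


class Parent:
    def push_to_hub(self):
        """Parent docstring."""
        return "parent"


class Other:
    def push_to_hub(self):
        """Other docstring."""
        return "other"


class Child(Parent):
    # Metadata is copied from Other, but super() still resolves to Parent.
    @wraps(Other.push_to_hub)
    def push_to_hub(self):
        return super().push_to_hub()


result = Child().push_to_hub()
doc = Child.push_to_hub.__doc__
wrapped = Child.push_to_hub.__wrapped__
```

So wrapping with a class that is not the actual parent would leave behavior unchanged but could make the docstring and introspection point at the wrong method.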
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)

    def tokenize_row(
        self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
    ) -> Dict:
        res = super().tokenize_row(feature, model=model)
        if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
            for key in res.keys():
                res[key] = res[key][1:]
        return res
more duplicated code that makes me think we should be extracting this into a Mixin.
if is_deepspeed_available():
    import deepspeed

class SPPOTrainer(Trainer):
This is a pretty big class that seems to duplicate a lot of code, presumably from the DPOTrainer. Would it make more sense to extend the DPOTrainer and implement only the necessary changes?
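A rough sketch of the "extend and override" shape, using a stand-in base class so it runs on its own. `PairwiseTrainerBase` and `SPPOStyleTrainer` are hypothetical and are not TRL's `DPOTrainer` or this PR's `SPPOTrainer`. The squared-error target in the subclass is an assumption for illustration; the actual objective should be taken from the paper and the uclaml/SPPO code.

```python
import math


class PairwiseTrainerBase:
    """Hypothetical stand-in for a DPO-style trainer; not TRL's DPOTrainer."""

    def __init__(self, beta):
        self.beta = beta

    def pair_loss(self, chosen_logratio, rejected_logratio, **extra):
        # DPO-style sigmoid loss on the difference of policy/reference log-ratios.
        margin = self.beta * (chosen_logratio - rejected_logratio)
        return math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))

    def training_step(self, batch):
        # Shared plumbing that a subclass inherits unchanged.
        losses = [self.pair_loss(**row) for row in batch]
        return sum(losses) / len(losses)


class SPPOStyleTrainer(PairwiseTrainerBase):
    """Overrides only the loss; everything else comes from the base class."""

    def pair_loss(self, chosen_logratio, rejected_logratio,
                  chosen_probs_win=1.0, chosen_probs_lose=0.0, **extra):
        # Assumed squared-error form: pull each log-ratio toward (P - 0.5) / beta.
        target_w = (chosen_probs_win - 0.5) / self.beta
        target_l = (chosen_probs_lose - 0.5) / self.beta
        return ((chosen_logratio - target_w) ** 2
                + (rejected_logratio - target_l) ** 2) / 2


batch = [{"chosen_logratio": 0.5, "rejected_logratio": -0.5}]
dpo_loss = PairwiseTrainerBase(beta=1.0).training_step(batch)
sppo_loss = SPPOStyleTrainer(beta=1.0).training_step(batch)
```

With this layout the subclass carries only the loss and any extra batch fields, which is the smaller maintenance footprint the review is asking about.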
def sppo_argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations
    """

    def transform_fn(sample):
        sample[
            "prompt"
        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
        sample["chosen_probs"] = sample['chosen_probs']
        sample["chosen_probs_lose"] = sample['chosen_probs_lose']
        sample["chosen_probs_win"] = sample['chosen_probs_win']
        return sample

    return transform_fn
Because this is in the dpo path, which iirc is loaded based on the rl: ... setting, I'm not sure that this will load as expected.
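To illustrate the concern with a toy lookup (this is not axolotl's actual loader; `STRATEGIES` and `load_strategy` are hypothetical): if strategies are resolved under a namespace derived from the `rl` setting, a function registered only under `dpo` is not found when `rl` names something else.

```python
# Hypothetical registry keyed by (rl setting, strategy name).
STRATEGIES = {
    ("dpo", "sppo_argilla_chat"): lambda cfg, **kwargs: (lambda sample: sample),
}


def load_strategy(rl, name):
    """Return the strategy factory for this rl setting, or None if absent."""
    return STRATEGIES.get((rl, name))


found_under_dpo = load_strategy("dpo", "sppo_argilla_chat") is not None
found_under_sppo = load_strategy("sppo", "sppo_argilla_chat") is not None
```

Whether the real loader behaves this way depends on how it maps the `rl` value to a module path, which would need to be checked in the codebase.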
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Hey @winglian, I'll do my best to submit a better pull request with a cleaner approach to the SPPO integration.
I am also running 3 iterations and I'll upload the resulting models to Hugging Face for comparison. At the moment I am running iter2 on my homelab.
Implements SPPO Alignment Algorithm
Description
This pull request implements the Self-Play Preference Optimization (SPPO) algorithm for language model alignment. The SPPO algorithm, as described in the paper "Self-Play Preference Optimization for Language Model Alignment" (available at https://arxiv.org/abs/2405.00675), uses a self-play mechanism to optimize language models based on preference probabilities. This implementation leverages the code from the original repository at https://github.com/uclaml/SPPO and integrates it into the Axolotl framework.
Motivation and Context
This change is required to improve the alignment of language models with human preferences, addressing issues of reliability, safety, and ethical considerations in language model outputs. The SPPO algorithm provides a more flexible and accurate method for preference optimization compared to traditional reinforcement learning approaches.
How has this been tested?
The implementation has been tested using a variety of prompts from the UltraFeedback dataset, evaluating the model's performance on AlpacaEval 2.0 and MT-Bench. The tests involved assessing the log-likelihood of chosen responses and comparing the model's win rates against state-of-the-art models, ensuring that the changes do not adversely affect other areas of the codebase.
Social Handles (Optional)
GitHub: @kaykyr
HuggingFace: https://huggingface.co/kaykyramos
Discord: kaykyramos