
Add new merging methods #1364

Merged · 23 commits · Feb 9, 2024

Conversation

@pacman100 (Contributor) commented Jan 16, 2024

What does this PR do?

  1. Add new model merging methods for LoRA, based on the papers TIES-MERGING: Resolving Interference When Merging Models and Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch. The new methods are ties, dare_linear, dare_ties, ties_svd, dare_linear_svd, and dare_ties_svd.
  2. The implementation of these methods is inspired by https://github.com/yule-BUAA/MergeLM/tree/main and https://github.com/cg123/mergekit/tree/main.
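As a rough illustration of the DARE family of methods listed above, here is a minimal sketch of the random drop-and-rescale step at their core (a toy example, not the PEFT implementation; the function name is invented for this sketch):

```python
import torch

def dare_drop_and_rescale(delta: torch.Tensor, density: float) -> torch.Tensor:
    """Toy DARE step: randomly keep a `density` fraction of the delta
    parameters and rescale the survivors by 1/density so the expected
    update stays unchanged."""
    mask = torch.bernoulli(torch.full_like(delta, density))
    return delta * mask / density

torch.manual_seed(0)
delta = torch.randn(4, 4)
pruned = dare_drop_and_rescale(delta, density=0.5)
# each entry is either zero or the original value scaled by 1/density
assert bool(torch.all((pruned == 0) | torch.isclose(pruned, delta / 0.5)))
```

The dare_ties variants apply this pruning first and then resolve sign conflicts as in TIES.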

Example of ties_svd is shown below (https://github.com/pacman100/peft-dreambooth-ui/blob/main/lora_merging.ipynb):
[Screenshot: ties_svd merging example notebook output]

LLM LoRA merging example:
[Screenshot: LLM LoRA merging example]

To do:

  • Add tests
  • Add documentation
  • Add example notebook

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@younesbelkada (Contributor) left a comment

Thanks! I left a few open questions - what do you think?

```diff
 for adapter, weight in zip(adapters, weights):
     if adapter in target.lora_A or adapter in target.lora_embedding_A:
         valid_adapters.append(adapter)
-        valid_weights.append(weight)
+        valid_weights.append(weight * target.scaling[adapter])
```
Contributor

is this change somehow breaking?

Contributor Author

Hello Younes, this is the correct usage; the earlier implementation was incorrect because it was missing the scaling factor.
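The point about the scaling factor can be illustrated with a toy example (hypothetical shapes, not the PEFT internals): a LoRA update contributes delta_W = scaling * (B @ A) with scaling = lora_alpha / r, so a user-supplied mixing weight applied to the raw B @ A factors must be multiplied by the per-adapter scaling.

```python
import torch

# Toy shapes for illustration only.
r, lora_alpha = 4, 8
scaling = lora_alpha / r  # LoRA applies delta_W = scaling * (B @ A)

torch.manual_seed(0)
A = torch.randn(r, 16)
B = torch.randn(16, r)
delta_w = scaling * (B @ A)

# Combining adapters with weight w on the raw factors must fold in
# the scaling factor, i.e. use (w * scaling), not w alone:
w = 0.5
merged = (w * scaling) * (B @ A)
assert torch.allclose(merged, w * delta_w)
```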

src/peft/utils/merge_utils.py: 5 outdated review threads (resolved)
@sayakpaul (Member) left a comment

Left some nits.

I really like how you have broken the core merging methods down into small logical components that are shared across them.

I think it definitely makes sense to also add a detailed guide about these methods on the PEFT doc site. But this can come later. Cc: @MKhalusova @stevhliu

@stevhliu (Member) left a comment

Really nice!

Do you think it'd also be nice to create a separate API reference page for these merging utilities so it's easy for users to find? I can work on this in a separate PR in addition to the guide suggested by @sayakpaul :)

src/peft/tuners/lora/model.py: 3 outdated review threads (resolved)
src/peft/utils/merge_utils.py: 1 outdated review thread (resolved)
@prateeky2806 (Contributor)

Hi @pacman100 @younesbelkada @stevhliu @sayakpaul, I am Prateek Yadav, the author of TIES-Merging. Let me know if there is some way I can contribute to adding these merging methods to the Transformers or PEFT libraries.

Also, I agree with the others here that documentation changes or updates to the README would help users become aware of these merging methods and how to use them.

Let me know if there is some way for me to help!

Thanks,
Prateek

@stevhliu stevhliu mentioned this pull request Jan 31, 2024
@prateeky2806 (Contributor)

Hi @pacman100, I went over the code for TIES and it looks good to me. I'll just mention that handling the embedding layer might be important, at least for merging full models.

@prateeky2806 (Contributor)

> Let's go 🚀
>
> How are we planning on documenting this as best as we can?

Hi @sayakpaul, I guess this is the pull request trying to document these merging methods.

@sayakpaul (Member)

Ah, you're correct. Thanks for the mention.

@prateeky2806 (Contributor)

Hi @sayakpaul @pacman100, is there a timeline on when this is expected to be merged?

@sayakpaul (Member)

I think we can merge now. @younesbelkada WDYT?

@BenjaminBossan (Member) left a comment

Thanks a lot @pacman100 for this excellent PR. Super well done and clean implementation + refactor!

In my review, I didn't check the correctness of the implementations in detail, since the original authors already did that and are much more qualified for this. Instead, I focused on the other parts, please take a look. Most things should be fairly easy adjustments.

Regarding the new functions in merge_utils.py, I think it will be good to add some unit tests for those in a future PR.

src/peft/tuners/lora/model.py: 3 review threads (resolved; 2 outdated)
```diff
     new_rank = adapters_ranks[0]
 elif combination_type == "cat":
     # adapters ranks may be different, new rank is sum of all ranks
     # be careful, because output adapter rank may be really big if mixing a lot of adapters
     new_rank = sum(adapters_ranks)
-elif combination_type == "svd":
+elif "svd" in combination_type:
```
Member

I would prefer:

```diff
-elif "svd" in combination_type:
+elif (combination_type == "svd") or (combination_type == "ties_svd"):
```

This is to prevent potential bugs in the future where we add a new combination type with a strange name that has "svd" as a substring.

Contributor Author

But it is also meant to cover dare_linear_svd and dare_ties_svd.

Member

Oh I see. So maybe let's check all these options explicitly, or use if combination_type.endswith("svd"), which should be less error prone.
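A small illustration of why the suffix check is less error prone than the substring check (the "svd_free_cat" name is invented for this example and is not a real PEFT option):

```python
# All current svd-style combination types end with "svd", so the
# suffix check matches exactly the intended set.
svd_types = ["svd", "ties_svd", "dare_linear_svd", "dare_ties_svd"]
assert all(t.endswith("svd") for t in svd_types)

# A hypothetical future type with "svd" mid-name would wrongly match
# the substring test but not the suffix test.
hypothetical = "svd_free_cat"
assert "svd" in hypothetical              # substring check misfires
assert not hypothetical.endswith("svd")   # suffix check stays correct
```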

src/peft/tuners/lora/model.py: 1 outdated review thread (resolved)
src/peft/utils/merge_utils.py: 2 outdated review threads (resolved)
```python
    return sign == majority_sign


def disjoint_merge(task_tensors: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor:
```
Member

task_tensors is not a list of tensors but a tensor, right? Below, task_tensors is always a list of tensors. I wonder if it would make sense to use a different variable name for this function to avoid confusion.
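A toy sketch of what such a disjoint merge computes, assuming the task tensors are stacked along dim 0 and the mask marks entries whose sign agrees with the elected majority sign (this mirrors the discussed signature, not the exact PEFT code):

```python
import torch

def disjoint_merge_sketch(stacked: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor:
    """Average each parameter over only the task tensors (stacked along
    dim 0) whose sign agrees with the majority sign for that position."""
    kept = stacked * majority_sign_mask
    count = majority_sign_mask.sum(dim=0).clamp(min=1.0)
    return kept.sum(dim=0) / count

stacked = torch.tensor([[1.0, -2.0], [3.0, 1.0], [-1.0, -4.0]])
# majority signs per column are + and -; the mask keeps agreeing entries
mask = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
merged = disjoint_merge_sketch(stacked, mask)
# column 0: (1 + 3) / 2 = 2.0 ; column 1: (-2 + -4) / 2 = -3.0
assert torch.allclose(merged, torch.tensor([2.0, -3.0]))
```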

src/peft/utils/merge_utils.py: 1 outdated review thread (resolved)
tests/testing_common.py: 1 review thread (resolved)
@prateeky2806 (Contributor)

Hi @pacman100 and @BenjaminBossan, I went into detail over the merge_utils file and left some comments. I feel like some parts need to be corrected and might have bugs, unless I am missing something.

I think I accidentally left many individual comments rather than a single review, but I guess that should be fine.

pacman100 and others added 6 commits February 8, 2024 12:28
Co-Authored-By: Prateek Yadav <15224633+prateeky2806@users.noreply.github.com>
Co-Authored-By: Yu Le <55241218+yule-buaa@users.noreply.github.com>
Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@BenjaminBossan (Member) left a comment

Thanks for addressing the issues from my side. This LGTM now. One small comment left, but it's not a must.


@pacman100 pacman100 merged commit c1a83fd into main Feb 9, 2024
14 checks passed
@prateeky2806 (Contributor)

Hi @pacman100, thanks for working on this and merging this pull request. As we discussed earlier, I was wondering whether this will be announced as part of the next PEFT release, or whether I should just announce it on my Twitter right away. I feel it would be much better to announce it with the next PEFT release, but I am not sure when that would be. Moreover, the PR adding the docs is still ongoing.

Thanks,
Prateek

@sayakpaul sayakpaul deleted the smangrul/add-new-merging-methods branch February 12, 2024 02:52
@sayakpaul (Member)

Hey Prateek!

@pacman100 and I are working on a blog post to discuss this, which we plan to release very soon. We will of course give everyone due credit there. I think it also makes sense to do a release for this feature so that users don't have to install peft from main.

Does that work for you?

@prateeky2806 (Contributor) commented Feb 12, 2024

Hi @sayakpaul, yes, that sounds good to me; looking forward to it. I will tweet about this once you do the release. Moreover, if you need someone to proofread the blog post, I would be happy to help.

Prateek

@yule-BUAA (Contributor)

Hi @pacman100 and @sayakpaul,
Thanks for merging this PR into the main branch of peft. I would also be glad to proofread the doc or blog post on merging models if anything is needed.

Le Yu

@sayakpaul (Member)

Appreciate all the support. We will keep you posted.


```python
def ties(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
```
Contributor

Hi @pacman100 @sayakpaul, I have just one last suggestion. It would be ideal to set the default weights for TIES to all ones, because that is a setting we experimented with in our paper, and it works reasonably well if people do not want to try out different values. You can look at the results marked with the red cross in this table; they use default weights of all ones. This makes TIES as simple to use as basic averaging by removing the need to select the weights.

[Table: TIES results with all-ones default weights]

Contributor

So by default, TIES should keep 20% of the parameters and use a mixing weight of 1 for each PEFT module.
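Putting those two defaults together, here is a rough sketch of the TIES procedure with density=0.2 and all-ones weights (a hypothetical helper for illustration, not the PEFT implementation):

```python
import torch

def ties_sketch(task_tensors, density=0.2, weights=None):
    """Toy TIES: trim each task tensor to its top-`density` magnitudes,
    elect a majority sign per parameter, then average the agreeing
    non-zero entries. Defaults follow the suggestion above: density=0.2
    and all-ones mixing weights."""
    trimmed = []
    for t in task_tensors:
        k = max(1, int(density * t.numel()))
        thresh = t.abs().flatten().topk(k).values.min()
        trimmed.append(torch.where(t.abs() >= thresh, t, torch.zeros_like(t)))
    stacked = torch.stack(trimmed)
    if weights is None:
        weights = torch.ones(len(task_tensors))  # all-ones default
    stacked = stacked * weights.view(-1, *([1] * (stacked.dim() - 1)))
    majority_sign = torch.sign(stacked.sum(dim=0))
    agree = (torch.sign(stacked) == majority_sign) & (stacked != 0)
    count = agree.sum(dim=0).clamp(min=1)
    return (stacked * agree).sum(dim=0) / count
```

For example, merging [1, 2, -3, 0.5] and [2, -2, -1, 0.1] with density=0.5 trims each tensor to its two largest magnitudes before the sign election and disjoint mean.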

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
* add code

* update docstring

* quality

* fix test

* fix test

* fix svd embedding layer merging

* fixes

* fixes

* Update model.py

* Add test and example

* quality

* fix tests

* update the example

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* address comments

* address comments and add co-authors

Co-Authored-By: Prateek Yadav <15224633+prateeky2806@users.noreply.github.com>
Co-Authored-By: Yu Le <55241218+yule-buaa@users.noreply.github.com>
Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* quality

* Update merge_utils.py

* revert

* address comments

* address comment

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Prateek Yadav <15224633+prateeky2806@users.noreply.github.com>
Co-authored-by: Yu Le <55241218+yule-buaa@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>