
Support for different layer shapes for VeRA #1817

Merged: 5 commits merged into huggingface:main on Jun 10, 2024

Conversation

@dkopi (Contributor) commented Jun 3, 2024

PR for #1816.


@BenjaminBossan (Member)

Thanks for the PR!

Solution to that would be creating the largest required A & B matrices, and slicing them accordingly for each adapted layer.

I haven't looked into the details yet, but I wondered about the reasoning for this approach. You probably also considered saving one set of weights for each unique shape (perhaps transposition-invariant). What do you think is the advantage of the proposed solution: does it save on the number of parameters on average, does it work better? How would you handle parameters with different dimensions, like Conv2d?

@dkopi (Contributor, Author) commented Jun 4, 2024

Slicing creates a view of the original tensor, so in most cases it should be more memory-efficient than creating separate sets of A & B, and it better fits the current approach, which is one pair of A & B for the whole adapter. For layers like Conv2d, we could slice the large A & B first, and then create a view to give them the proper dimensions.
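
For intuition, here is a minimal, self-contained sketch of the slicing idea (not the PEFT code itself; the shapes and the helper name vera_delta are made up for illustration):

    import torch
    import torch.nn.functional as F

    r = 4
    max_in, max_out = 768, 3072        # dims of the largest adapted layer (made up for illustration)
    vera_A = torch.randn(r, max_in)    # shared, frozen projection A: (r, max_in_features)
    vera_B = torch.randn(max_out, r)   # shared, frozen projection B: (max_out_features, r)

    def vera_delta(x, in_features, out_features, lambda_d, lambda_b):
        # Slicing returns views of vera_A / vera_B, so no per-layer copies are allocated.
        sliced_A = vera_A[:, :in_features]     # (r, in_features)
        sliced_B = vera_B[:out_features, :]    # (out_features, r)
        return lambda_b * F.linear(lambda_d * F.linear(x, sliced_A), sliced_B)

    # The same shared pair serves layers of different shapes:
    out1 = vera_delta(torch.randn(2, 768), 768, 3072, torch.ones(r), torch.ones(3072))
    out2 = vera_delta(torch.randn(2, 512), 512, 768, torch.ones(r), torch.ones(768))
    print(out1.shape, out2.shape)  # torch.Size([2, 3072]) torch.Size([2, 768])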

Orthogonally to that, we could have an option in the config to have separate A & B for every adapted layer; that would be the equivalent of the ablation from the paper (Table 7). I don't see any benefit of having something in between: a separate pair per distinct shape.

@BenjaminBossan (Member)

Yes, my expectation would also be that slicing from a single big matrix should result in fewer parameters overall than having one weight for each unique shape. There could be some exceptions: say we have a shape of 100x10 and a shape of 10x100; the total shape would be 100x100, which is much bigger than 100x10 + 10x100. But I'm not sure if there would be any practical examples of that, since lora_A and lora_B should generally have somewhat similar shapes for different layers (disregarding higher-dimensional weights for now). I was just curious whether the proposed implementation is based on some empirical finding (I haven't seen it described in the paper).

Even if this is better for the overall parameter count, one potential disadvantage of this approach, however, is that it means we need to use the weights directly:

    # Slice views of the shared projections down to this layer's dimensions
    sliced_A = vera_A[:, : self.in_features]
    sliced_B = vera_B[: self.out_features, :]

    # Apply dropout, project down with A, scale by lambda_d, project up with B,
    # then scale by lambda_b and add to the base layer's result
    dropout = self.vera_dropout[active_adapter]
    x = x.to(lambda_d.dtype)
    result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B)

(https://github.com/huggingface/peft/pull/1817/files#diff-681ccefeb9c0abc7bcde2c387ed7be1a5271713483d6871018579c8d0b9a96c9R255-R260)

I.e. we cannot go through the forward call of the adapter layers (something like lambda_b(lora_B(lambda_d(lora_A(dropout(x)))))). I know that the current VeRA code also doesn't do that, but I think it could be changed to make it happen.

Why would we want to go through the forward call? In some circumstances this is important, because the module is wrapped or hooks are added such that some important functionality is executed when module.forward or module.__call__ is called. This is something that I came across recently when working with FSDP. My concern is that if we go with the slicing approach, we could end up in a situation where we cannot possibly make this work. In contrast, if we create one layer per shape, we leave the door open to refactor such that forward is called on these modules.
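
As a small, generic PyTorch illustration of this point (unrelated to the PEFT internals): a forward hook registered on a module only fires when the module itself is called, not when its weight is used directly via the functional API:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    layer = nn.Linear(4, 4)
    layer.register_forward_hook(lambda mod, inp, out: print("hook fired"))

    x = torch.randn(1, 4)
    layer(x)                               # goes through module.__call__ -> prints "hook fired"
    F.linear(x, layer.weight, layer.bias)  # uses the weight directly -> hook is bypassed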

Orthogonally to that, we could have an option in the config to have separate A&B for every adapted layer, that would be the equivalent of the ablation from the paper (table 7).

This to me seems to be an inferior approach compared to having one set of weights per shape (though easier to implement), as I think VeRA is all about minimizing the number of trainable parameters as much as possible.

Anyway, this is all a bit hypothetical at the moment and I just wanted to write down my thoughts and hear your opinion. One step forward would be to try out VeRA in different settings that involve these types of wrapping and hooks, like offloading with accelerate or using DeepSpeed/FSDP, and check if it works.

@dkopi (Contributor, Author) commented Jun 4, 2024

Regarding the edge case of two "long and narrow" tensors of different "orientation", e.g. 100x10 and 10x100: the resulting A & B would have shapes (r, 100) & (100, r), not (100, 100). So it would be 200*r vs. 220*r parameters.
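
For concreteness, a quick back-of-the-envelope check of those numbers (illustration only, not code from the PR):

    r = 4  # any rank; both totals scale linearly with r

    # One shared pair, sized for the largest dims across the 100x10 and 10x100 layers:
    shared = r * 100 + 100 * r                            # A: (r, 100), B: (100, r) -> 200*r
    # One pair per unique shape:
    per_shape = (r * 10 + 100 * r) + (r * 100 + 10 * r)   # 110*r + 110*r -> 220*r

    print(shared, per_shape)  # 800 880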

I was just curious if the proposed implementation is based on some empirical finding (I haven't seen it described in the paper).

I didn't compare the two implementations, but imo both should be equivalent in outcomes, with negligible differences in memory usage.

Even if this is better for overall parameter count, one potential disadvantage of this approach is, however, that it means we need to use the weights directly

Actually, my first version of the implementation temporarily swapped the .data of the tensors, so it could be changed back to that.
Then it would be something like:

    # Temporarily point the parameters' .data at the sliced views
    sliced_A = vera_A.data[:, : self.in_features]
    sliced_B = vera_B.data[: self.out_features, :]
    init_A = vera_A.data
    init_B = vera_B.data
    vera_A.data = sliced_A
    vera_B.data = sliced_B

    dropout = self.vera_dropout[active_adapter]
    x = x.to(lambda_d.dtype)
    result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B)

    # Restore the original full-size projections
    vera_A.data = init_A
    vera_B.data = init_B

But this would be done on the weights of the adapted layers. Still, I'm not sure if it would work with stuff like FSDP.

This to me seems to be an inferior approach compared to having one set of weights per shape (though easier to implement), as I think VeRA is all about minimizing the number of trainable parameters as much as possible.

True, it would only make sense in combination with regenerating the matrices from a seed. Then we would still store only one seed per adapted layer.
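
As a rough sketch of what regenerating from a seed could look like (a generic illustration, not PEFT's actual initialization; the helper name is made up):

    import torch

    def regenerate_projections(seed: int, r: int, in_features: int, out_features: int):
        # Rebuild the frozen A & B deterministically, so only the integer seed
        # needs to be stored per layer (or per shape).
        gen = torch.Generator().manual_seed(seed)
        vera_A = torch.randn(r, in_features, generator=gen)
        vera_B = torch.randn(out_features, r, generator=gen)
        return vera_A, vera_B

    A1, B1 = regenerate_projections(seed=1234, r=4, in_features=768, out_features=768)
    A2, B2 = regenerate_projections(seed=1234, r=4, in_features=768, out_features=768)
    assert torch.equal(A1, A2) and torch.equal(B1, B2)  # identical on every call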

The most flexible and "future proof" approach would be to give an option to have basis A&B matrices per:

  • model
  • adapted layer (the ablation showed a slight advantage of this, so it would be good to have)
  • shape of adapted layers

However, it would require the most work and could be confusing for new users of VeRA.


On a side note (completely unrelated to this PR), A & B could potentially be quantized as they are not trainable, leading to lower memory usage. I haven't tested it yet, though.
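
To make that concrete, a rough sketch of symmetric per-tensor int8 quantization in plain PyTorch (untested in the VeRA context, just to illustrate the potential memory saving):

    import torch

    def quantize_int8(t: torch.Tensor):
        # Symmetric per-tensor scheme: keep int8 values plus a single fp scale.
        scale = t.abs().max() / 127.0
        q = torch.clamp((t / scale).round(), -127, 127).to(torch.int8)
        return q, scale

    def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return q.float() * scale

    vera_A = torch.randn(256, 768)
    q, scale = quantize_int8(vera_A)
    print(q.element_size(), vera_A.element_size())      # 1 vs 4 bytes per element
    print((dequantize(q, scale) - vera_A).abs().max())  # max absolute quantization error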

@BenjaminBossan (Member) left a comment

Thanks for your replies. At the end of the day, I think it's okay to go ahead as is. If we do run into the issue that users want to use VeRA with DS or FSDP and it doesn't work, I'm sure we can figure out a way. This PR shouldn't break anything in that regard which currently works.

The implementation is super sweet, nice that only a few lines needed to be changed. I only have a few small comments, please check.

Actually, my first version of implementation was temporarily swapping .data of tensors, so it could be changed back to that.

That probably would also run into trouble, so let's not do it that way.

the most flexible and "future proof" approach would be to give an option to have basis A&B matrices per: ...
However, it would require the most work and could be confusing for new users of VeRA.

I agree.

On a side note (completely unrelated to this PR), A & B could potentially be quantized as they are not trainable, leading to lower memory usage. I haven't tested it yet, though.

Cool idea, if you have some time to test it and find that it works, that would be great. My intuition is, however, that it won't be quite trivial to implement and of course it will introduce more quantization error.

Resolved review comments on: tests/test_vera.py (2), docs/source/package_reference/vera.md, src/peft/tuners/vera/layer.py
@dkopi requested a review from BenjaminBossan, June 8, 2024 15:45
@dkopi (Contributor, Author) commented Jun 8, 2024

I've addressed the requested changes and fixed merging of the adapter: adding the new test to test_custom_models.py revealed that merging also required an update.

@BenjaminBossan (Member)

Thanks @dkopi. Could you please run make style so that CI will pass?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dkopi (Contributor, Author) commented Jun 10, 2024

@BenjaminBossan Done. I just skipped two files modified by "make style" (lycoris_utils.py and router.py) to avoid potential merge conflicts.

@dkopi (Contributor, Author) commented Jun 10, 2024

@BenjaminBossan One of the tests failed with requests.exceptions.ReadTimeout: HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10).

Is it an issue with HF servers? Could we rerun it?

@BenjaminBossan (Member)

Is it an issue with HF servers? Could we rerun it?

Yes, no worries, I'll get a notification and re-run those tests.

@BenjaminBossan (Member) left a comment

Everything looks fantastic, thanks. Updating the notebook was a nice touch.

@BenjaminBossan merged commit 7b1c08d into huggingface:main on Jun 10, 2024. 14 checks passed.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 22, 2024
Some pre-configured models like mistral used not to work with VeRA
because the weight shapes were not identical. However, since huggingface#1817, this
is no longer a requirement. Therefore, this commented code can now be
uncommented.

I have tested mistral and gemma and they worked. I haven't tested btlm
and mixtral but with the update, I'm pretty sure they will work too.
sayakpaul pushed a commit that referenced this pull request Jul 30, 2024