
QDoRA: Support DoRA with BnB quantization #1518

Merged: 8 commits into huggingface:main on Mar 12, 2024

Conversation

@BenjaminBossan (Member) commented on Feb 29, 2024

Adds support for DoRA on 4bit and 8bit quantized models with bitsandbytes. Merging also works, with the usual caveats for quantized weights (results are not 100% identical), but it's not worse than vanilla LoRA.

I ran some quick tests and saw the expected memory savings with bnb. As with DoRA on non-quantized layers, using DoRA on quantized layers leads to a moderate increase in runtime.
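For context, a minimal sketch of what QDoRA usage looks like with this change; the model name, target modules, and compute dtype below are placeholder choices, not taken from this PR:

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load the base model in 4-bit with bitsandbytes (model name is a placeholder).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=bnb_config
)

# use_dora=True enables DoRA on top of the quantized layers ("QDoRA").
config = LoraConfig(use_dora=True, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, config)

# After training, merging works with the usual caveats for quantized
# weights: results are close but not 100% identical.
merged = model.merge_and_unload()
```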

WIP

Adds support for DoRA on 4bit and 8bit quantized models with BnB. For
now, merging is not implemented. I'll investigate this next.
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan mentioned this pull request on Mar 1, 2024
@BenjaminBossan marked this pull request as ready for review on March 1, 2024 at 15:59
@younesbelkada (Contributor) left a comment:

Looks great to me, thank you @BenjaminBossan for adding quantization support for DoRA!

@pacman100 (Contributor) left a comment:

Thank you @BenjaminBossan for adding support for using DoRA with bnb-quantized layers, and for the thorough tests! 🤗


```py
from peft import LoraConfig

config = LoraConfig(use_dora=True, ...)
```

DoRA should work with weights quantized with bitsandbytes ("QDoRA"). Issues have been reported when using QDoRA with DeepSpeed Zero2.
A Contributor left a comment:

This should be in a Caveats section where all such notes can be collated in one place.

@BenjaminBossan (Member, Author) replied:

I added a caveats section directly below, within the DoRA section. Is this what you had in mind?

@pacman100 (Contributor) left a comment:

Thank you @BenjaminBossan! ✨

@BenjaminBossan merged commit 3eb6bba into huggingface:main on Mar 12, 2024
14 checks passed
@BenjaminBossan deleted the support-dora-with-bnb branch on March 12, 2024 at 11:45
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
Adds support for DoRA on 4bit and 8bit quantized models with BnB.
Merging also works, with the usual caveats for quantized weights
(results are not 100% identical), but it's not worse than vanilla LoRA.
@mallorbc commented:

QDoRA seems to be working for me. However, I am noticing a large slowdown of around 2x when comparing QLoRA and QDoRA. I am not sure if that is expected, but this seems like as good a place as any to share this finding.

@BenjaminBossan (Member, Author) replied:

> I am noticing a large slowdown of around 2x when comparing QLoRA and QDoRA. I am not sure if that is expected.

Yes, QDoRA unfortunately requires an additional dequantization step on the quantized weights to calculate the weight norm. I wouldn't expect this to slow down training by 2x, but a significant slowdown is expected. Maybe you can run a profiler to check further if you think it's worth investigating.
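For instance, a minimal torch.profiler sketch along those lines; `model` and `batch` here are placeholders for a QDoRA model and a tokenized input batch, not code from this PR:

```py
import torch
from torch.profiler import profile, ProfilerActivity

# Profile one training step; `model` and `batch` are placeholders.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = model(**batch).loss
    loss.backward()

# The dequantization ops needed for DoRA's weight norm should show up here.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```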

> This seems like as good a place as any to share this finding.

You can also open a new issue or discussion (if one doesn't already exist) for this type of question.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request May 30, 2024
Don't pass load_in_8bit to AutoModel.from_pretrained; instead, use
BitsAndBytesConfig.

There was already a PR to clean this up (huggingface#1552), but a slightly later
PR (huggingface#1518) re-added this usage.
BenjaminBossan added a commit that referenced this pull request May 30, 2024
Don't pass load_in_8bit to AutoModel.from_pretrained; instead, use
BitsAndBytesConfig.

There was already a PR to clean this up (#1552), but a slightly later
PR (#1518) re-added this usage.

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
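For illustration, a sketch of the pattern that follow-up commit moves to; the model name and the use of the causal-LM auto class are placeholder choices:

```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Deprecated: AutoModelForCausalLM.from_pretrained(..., load_in_8bit=True)
# Preferred: pass a BitsAndBytesConfig via quantization_config instead.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=bnb_config
)
```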