
[core / Quantization ] AWQ integration #27045

Merged (72 commits into huggingface:main on Nov 1, 2023)

Conversation

@younesbelkada (Contributor) commented Oct 24, 2023

What does this PR do?

As per the title, this PR adds AWQ inference support to transformers.


AWQ is a new and popular quantization scheme, already used in various libraries such as TGI and vLLM, and known to be faster than GPTQ according to some benchmarks.

Contrary to GPTQ, this integration targets inference only. Since the ecosystem is quite mature with respect to quantizing a model, we will point users to different routes for that purpose, such as AutoAWQ, the original llm-awq repository, or Optimum Neural Compressor.

For now I have pushed a 'test' model to this repository: https://huggingface.co/ybelkada/test-mistral-7b-v0.1-awq, but we plan to support all AWQ weights from TheBloke. To run experiments with this PR, first `pip install autoawq`, then run:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ybelkada/test-mistral-7b-v0.1-awq"

tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(0)

print(model)

text = ["Hello my name is", "hi"]
# Batched prompts need padding; reuse the EOS token as pad token if none is defined,
# and pad on the left since decoder-only models generate from the end of the prompt.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
tok.padding_side = "left"
inputs = tok(text, return_tensors="pt", padding=True).to(0)

output = model.generate(**inputs, max_new_tokens=40)
print(tok.batch_decode(output, skip_special_tokens=True))

TODO:

  • Benchmarks
  • Documentation
  • Support fused modules
  • Colab inference demo
  • Write tests
  • Support weights that have been quantized with optimum NC
  • Support weights that have been quantized with llm-awq

cc @fxmarty @SunMarc @casper-hansen @TheBloke @IlyasMoutawwakil

@HuggingFaceDocBuilderDev commented Oct 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@casper-hansen

This is super exciting to see! The original repository does not support TheBloke's quants; they were made with AutoAWQ - perhaps an argument to route to AutoAWQ for compatibility.

in_features = module.in_features
out_features = module.out_features

model._modules[name] = WQLinear_GEMM(


If the version in config is GEMV, this will fail to load that version. Would it be appropriate to get the WQLinear based on the version in the config?

younesbelkada (author):

Makes sense yes!

younesbelkada (author):

Let me know what you think of 13abcf2!


This looks good, and should work as intended. If you want to test GEMV without quantizing yourself, I have quantized Vicuna 7B v1.5 in GEMV:
https://huggingface.co/casperhansen/vicuna-7b-v1.5-awq-gemv

quantization_config = AWQConfig.from_dict(config.quantization_config)

model, _ = replace_with_awq_linear(
    model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]


The modules_to_not_convert=["lm_head"] may eventually cause issues if the head is not named exactly "lm_head". The way we avoid this in AutoAWQ is by only looking at the decoder layers of the model, by calling the model's get_model_layers() function (e.g. for llama).

younesbelkada (author):

Nice! For bnb we have the same issue and use this method: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py#L243, which should be quite generic for most transformers models. I plan to use that instead.
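As a rough illustration of that idea (not the final integration code), here is a minimal sketch that derives the modules to skip with that generic helper instead of hard-coding "lm_head"; the model name is just an example and autoawq is not needed for this step:

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import get_keys_to_not_convert

# Build a small model from its config only (no trained weights needed for this illustration).
config = AutoConfig.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_config(config)

# Generic helper from the bitsandbytes integration: returns the module names that
# should stay in full precision (typically the tied output head).
modules_to_not_convert = get_keys_to_not_convert(model)
print(modules_to_not_convert)  # e.g. ["lm_head"] for OPT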

@SunMarc (Member) left a comment:

LGTM! Left a few comments.

Review threads (resolved) on: docs/source/en/main_classes/quantization.md, src/transformers/modeling_utils.py, src/transformers/integrations/awq.py, src/transformers/utils/quantization_config.py, tests/quantization/autoawq/test_awq.py
@amyeroberts (Collaborator) left a comment:

Very nice! 🔥

Main comment is the need for tests to check that layers are replaced as expected. Once those are added it'll be good to go!

Review threads (resolved) on src/transformers/utils/quantization_config.py
Comment on lines +489 to +490
if major < 8:
    raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")
Collaborator:

Nice :)
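For reference, a minimal sketch of how the `major` value in the quoted check is typically obtained with torch; the exact wiring in this PR may differ:

import torch

# (major, minor) compute capability of the current CUDA device.
major, minor = torch.cuda.get_device_capability()
if major < 8:
    raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")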

Collaborator:

Very nice set of tests :)

The only thing that should be added is a test which takes a dummy or real model and checks that the recursive logic of replace_with_awq_linear acts as expected, as this is the most critical part. In particular, we should check that the expected layers are converted and that modules_to_not_convert is respected.

younesbelkada (author):

Makes sense, this is possible yes!

younesbelkada (author):

Added the conversion test here: 79cbbd3

Collaborator:

Thank you! Just one last thing. The test will pass as soon as the first converted layer is found. However, what is important is to properly test the recursion, i.e. are layers within a module converted? And are the ones listed in modules_to_not_convert left unconverted?

younesbelkada (author):

Ah, nice catch! OK, I will elaborate the test a bit more and let you know.

younesbelkada (author):

Added a better test in df9e691 - let me know what you think!
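For illustration only, here is a sketch of the kind of conversion check being discussed, reusing the replace_with_awq_linear call quoted earlier; it is not the literal test added in df9e691, and it assumes autoawq and accelerate are installed:

import torch.nn as nn
from accelerate import init_empty_weights
from transformers import AutoConfig, AwqConfig, OPTForCausalLM
from transformers.integrations import replace_with_awq_linear

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

model, _ = replace_with_awq_linear(
    model, quantization_config=AwqConfig(bits=4), modules_to_not_convert=["lm_head"]
)

# The skipped module must remain a plain nn.Linear ...
assert isinstance(model.lm_head, nn.Linear)
# ... while the linear layers inside every decoder block should have been converted.
for layer in model.model.decoder.layers:
    assert not isinstance(layer.self_attn.q_proj, nn.Linear)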

@@ -3224,6 +3262,17 @@ def from_pretrained(
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.convert_model(model)
model._is_quantized_training_enabled = True
elif quantization_method_from_config == QuantizationMethod.AWQ:
from .integrations import get_keys_to_not_convert, replace_with_awq_linear
Collaborator:

+1 on @younesbelkada's comment here

Review threads (resolved) on src/transformers/integrations/awq.py
Comment on lines 77 to 82
if backend == AwqBackendPackingMethod.AUTOAWQ:
    target_cls = (
        WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
    )
else:
    target_cls = WQLinear
Collaborator:

This should be out of the for-loop in the lines above - we don't need to keep redefining it

younesbelkada (author):

Moved in 597ab7f
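For clarity, a sketch of the shape of that suggestion (not the literal diff of 597ab7f); the autoawq import path and the use of the config's backend/version enums are assumptions here:

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV  # assumed autoawq import path

from transformers.utils.quantization_config import AWQLinearVersion, AwqBackendPackingMethod

def resolve_target_cls(quantization_config, backend):
    """Pick the quantized linear class once, before the module-replacement loop."""
    if backend == AwqBackendPackingMethod.AUTOAWQ:
        if quantization_config.version == AWQLinearVersion.GEMM:
            return WQLinear_GEMM
        return WQLinear_GEMV
    # The llm-awq backend would return its own WQLinear class here instead.
    raise NotImplementedError("this sketch only covers the AutoAWQ backend")

# target_cls = resolve_target_cls(quantization_config, backend)  # computed once
# for name, module in model.named_children():
#     ...  # replace nn.Linear modules with target_cls, reusing the single target_cls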

Comment on lines +2809 to +2812
logger.warning(
    "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
    "your model on a GPU device in order to run your model."
)
Collaborator:

It's OK to not throw an error in this case. However, we should still properly handle the case where the user tries to do inference on CPU. Is this caught with an exception somewhere?

younesbelkada and others added 7 commits October 30, 2023 18:23
@amyeroberts (Collaborator) left a comment:

Nice - thanks for iterating with the test ❤️

There's a final iteration before we can merge to avoid using a large model in the test.

quantization_config = AwqConfig(bits=4)

with init_empty_weights():
    model = OPTForCausalLM(config)
Collaborator:

Do we need to use OPTForCausalLM specifically here? Why not AutoModelForCausalLM?

@younesbelkada (author) commented Oct 31, 2023:

Yes, we have to use OPTForCausalLM because initializing a model from a config by calling AutoModelForCausalLM(config) directly is not supported:

>>> AutoModelForCausalLM(config)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/younes_huggingface_co/code/transformers/src/transformers/models/auto/auto_factory.py", line 411, in __init__
    raise EnvironmentError(
OSError: AutoModelForCausalLM is designed to be instantiated using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or `AutoModelForCausalLM.from_config(config)` methods.
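As a side note, the error message itself points at from_config as the supported route; a minimal sketch, in case it is useful (model name illustrative), which also works under init_empty_weights:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    # Calling the auto class directly raises the OSError above; from_config(config) is supported.
    model = AutoModelForCausalLM.from_config(config)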

config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = AwqConfig(bits=4)

with init_empty_weights():
Collaborator:

This test will need a @require_accelerate decorator if this is used

younesbelkada (author):

The require_accelerate decorator is already set on the AwqTest class.

self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
from transformers.integrations.awq import replace_with_awq_linear

model_id = "facebook/opt-350m"
Collaborator:

Why a checkpoint (we don't need trained weights here), and why this checkpoint when a smaller one exists?

Even though init_empty_weights is used here - this only handles loading into RAM - it still requires the checkpoint to be downloaded, so larger models make things slower.

If you replace this with an hf-internal-testing tiny model OR a very small model defined with a model config, this will be more lightweight. In this case we might even be able to remove the accelerate dependency.


For tiny model testing I use this - a 68M Llama model: https://huggingface.co/JackFram/llama-68m

@younesbelkada (author) commented Oct 31, 2023:

@amyeroberts thanks!
The init_empty_weights context manager creates the model on the meta device, so it costs 0 RAM regardless of the architecture. During the entire test I am only dealing with the model loaded on the meta device, and I am not downloading any checkpoint.


> it still requires the checkpoint to be downloaded --> larger models make things slower.

The only thing I download is the config file; no model weights are fetched here. The accelerate dependency is fine since we use it for all quantization integrations, and it is always installed in our docker images for slow tests.
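A tiny sketch illustrating that claim, assuming accelerate is installed: everything created under init_empty_weights lives on the meta device and allocates no real memory, and only the config file is fetched.

from accelerate import init_empty_weights
from transformers import AutoConfig, OPTForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-350m")  # only the config JSON is downloaded
with init_empty_weights():
    model = OPTForCausalLM(config)

# All parameters are on the meta device, so no RAM is used for weights.
assert all(p.device.type == "meta" for p in model.parameters())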

Collaborator:

My bad - I misread and thought from_pretrained was being used in the model creation i.e. model = XxxForCausalLM.from_pretrained

@younesbelkada (author):

Thanks @amyeroberts for your review! I have replied to your questions above; I believe all the points have already been addressed. Let me know if I missed anything.

@amyeroberts (Collaborator) left a comment:

Thanks for adding and iterating on this!

@younesbelkada (author):

Thanks for all your reviews @amyeroberts @ArthurZucker @SunMarc @casper-hansen !

@younesbelkada younesbelkada merged commit ae093ee into huggingface:main Nov 1, 2023
22 checks passed
@younesbelkada younesbelkada deleted the add-awq branch November 1, 2023 08:06
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* working v1

* oops

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fixup

* oops

* push

* more changes

* add docs

* some fixes

* fix copies

* add v1 doc

* added installation guide

* relax constraints

* revert

* attempt llm-awq

* oops

* oops

* fixup

* raise error when incorrect cuda compute capability

* nit

* add instructions for llm-awq

* fixup

* fix copies

* fixup and docs

* change

* few changes + add demo

* add v1 tests

* add autoawq in dockerfile

* finalize

* Update tests/quantization/autoawq/test_awq.py

* fix test

* fix

* fix issue

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add link to example script

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add more content

* add more details

* add link to quantization docs

* camel case + change backend class name

* change to string

* fixup

* raise errors if libs not installed

* change to `bits` and `group_size`

* nit

* nit

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* disable training

* address some comments and fix nits

* fix

* final nits and fix tests

* adapt to our new runners

* make fix-copies

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* move to top

* add conversion test

* final nit

* add more elaborated test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>