
[core / Quantization] AWQ integration #27045

Merged: 72 commits, Nov 1, 2023

Changes from 5 commits

Commits (72)
efb1de8
working v1
younesbelkada Oct 24, 2023
a0eee34
oops
younesbelkada Oct 24, 2023
a06a604
Update src/transformers/modeling_utils.py
younesbelkada Oct 25, 2023
33f1134
Merge remote-tracking branch 'upstream/main' into add-awq
younesbelkada Oct 25, 2023
4b29b0b
fixup
younesbelkada Oct 25, 2023
ddc3ea2
oops
younesbelkada Oct 25, 2023
13abcf2
push
younesbelkada Oct 25, 2023
1561463
more changes
younesbelkada Oct 25, 2023
433148e
add docs
younesbelkada Oct 25, 2023
ee4d301
some fixes
younesbelkada Oct 25, 2023
d8616fb
fix copies
younesbelkada Oct 25, 2023
42cc2e5
Merge remote-tracking branch 'upstream/main' into add-awq
younesbelkada Oct 26, 2023
fd172bd
add v1 doc
younesbelkada Oct 26, 2023
5e155e1
added installation guide
younesbelkada Oct 26, 2023
4a1413e
relax constraints
younesbelkada Oct 26, 2023
014d901
revert
younesbelkada Oct 26, 2023
39b4e6a
attempt llm-awq
younesbelkada Oct 26, 2023
3dd93dd
oops
younesbelkada Oct 26, 2023
96bac28
oops
younesbelkada Oct 26, 2023
d58c461
fixup
younesbelkada Oct 26, 2023
150a5ec
raise error when incorrect cuda compute capability
younesbelkada Oct 26, 2023
477e17d
nit
younesbelkada Oct 26, 2023
717d044
add instructions for llm-awq
younesbelkada Oct 26, 2023
d73fe54
Merge branch 'main' into add-awq
younesbelkada Oct 26, 2023
21571c5
fixup
younesbelkada Oct 26, 2023
5c146bd
fix copies
younesbelkada Oct 26, 2023
23f23df
fixup and docs
younesbelkada Oct 26, 2023
5b44946
change
younesbelkada Oct 26, 2023
8fc16ab
few changes + add demo
younesbelkada Oct 27, 2023
a33d17c
Merge branch 'main' into add-awq
younesbelkada Oct 27, 2023
b252bd7
add v1 tests
younesbelkada Oct 27, 2023
010de2e
add autoawq in dockerfile
younesbelkada Oct 27, 2023
75a9a55
finalize
younesbelkada Oct 27, 2023
6c28089
Update tests/quantization/autoawq/test_awq.py
younesbelkada Oct 27, 2023
7f30911
fix test
younesbelkada Oct 27, 2023
21b25a1
fix
younesbelkada Oct 27, 2023
576b69d
Merge remote-tracking branch 'upstream/main' into add-awq
younesbelkada Oct 30, 2023
1c19912
fix issue
younesbelkada Oct 30, 2023
02aa242
Update src/transformers/integrations/awq.py
younesbelkada Oct 30, 2023
3dbdd2c
Update docs/source/en/main_classes/quantization.md
younesbelkada Oct 30, 2023
90f0ea5
Update docs/source/en/main_classes/quantization.md
younesbelkada Oct 30, 2023
f271173
Update src/transformers/integrations/awq.py
younesbelkada Oct 30, 2023
027f76a
Update src/transformers/integrations/awq.py
younesbelkada Oct 30, 2023
2940188
add link to example script
younesbelkada Oct 30, 2023
7bc3549
Update docs/source/en/main_classes/quantization.md
younesbelkada Oct 30, 2023
1008152
add more content
younesbelkada Oct 30, 2023
d585182
add more details
younesbelkada Oct 30, 2023
b19ba13
add link to quantization docs
younesbelkada Oct 30, 2023
963676e
camel case + change backend class name
younesbelkada Oct 30, 2023
a50639b
change to string
younesbelkada Oct 30, 2023
249492d
fixup
younesbelkada Oct 30, 2023
d870195
raise errors if libs not installed
younesbelkada Oct 30, 2023
bd6a90f
change to `bits` and `group_size`
younesbelkada Oct 30, 2023
702f990
nit
younesbelkada Oct 30, 2023
1e03581
nit
younesbelkada Oct 30, 2023
e97026c
Merge remote-tracking branch 'upstream/main' into add-awq
younesbelkada Oct 30, 2023
e4dcaca
Apply suggestions from code review
younesbelkada Oct 30, 2023
bd3c37a
disable training
younesbelkada Oct 30, 2023
c353ace
address some comments and fix nits
younesbelkada Oct 30, 2023
fa38b26
fix
younesbelkada Oct 30, 2023
53957f0
final nits and fix tests
younesbelkada Oct 30, 2023
790b2fe
adapt to our new runners
younesbelkada Oct 30, 2023
6bdf5de
make fix-copies
younesbelkada Oct 30, 2023
64cbf01
Update src/transformers/utils/quantization_config.py
younesbelkada Oct 30, 2023
88aec66
Update src/transformers/utils/quantization_config.py
younesbelkada Oct 30, 2023
97d335a
Update src/transformers/integrations/awq.py
younesbelkada Oct 30, 2023
b0e2868
Update src/transformers/integrations/awq.py
younesbelkada Oct 30, 2023
597ab7f
move to top
younesbelkada Oct 30, 2023
79cbbd3
add conversion test
younesbelkada Oct 30, 2023
ebf58e1
final nit
younesbelkada Oct 30, 2023
df9e691
add more elaborated test
younesbelkada Oct 31, 2023
3f4b1a1
Merge remote-tracking branch 'upstream/main' into add-awq
younesbelkada Oct 31, 2023
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -772,7 +772,7 @@
"is_vision_available",
"logging",
],
"utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AWQConfig", "BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects
@@ -4923,7 +4923,7 @@
)

# bitsandbytes config
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig
from .utils.quantization_config import AWQConfig, BitsAndBytesConfig, GPTQConfig

try:
if not is_sentencepiece_available():
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -17,6 +17,7 @@


_import_structure = {
"awq": ["replace_with_awq_linear"],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -77,6 +78,7 @@
}

if TYPE_CHECKING:
from .awq import replace_with_awq_linear
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,
72 changes: 72 additions & 0 deletions src/transformers/integrations/awq.py
@@ -0,0 +1,72 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_auto_awq_available


if is_auto_awq_available():
import torch.nn as nn
from awq.modules.linear import WQLinear_GEMM

if is_accelerate_available():
from accelerate import init_empty_weights


def replace_with_awq_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
"""
Private method that wraps the recursion for module replacement.

Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features

model._modules[name] = WQLinear_GEMM(

Reviewer: If the version in the config is GEMV, this will fail to load that version. Would it be appropriate to get the WQLinear class based on the version in the config?

younesbelkada (Contributor Author): Makes sense yes!

younesbelkada (Contributor Author): let me know WDYT of 13abcf2 !

Reviewer: This looks good, and should work as intended. If you want to test GEMV without quantizing yourself, I have quantized Vicuna 7B v1.5 in GEMV:
https://huggingface.co/casperhansen/vicuna-7b-v1.5-awq-gemv

(See the version-selection sketch after this file's diff.)

w_bit=quantization_config.w_bit,
group_size=quantization_config.q_group_size,
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
dev=module.weight.device,
)
has_been_replaced = True

# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_awq_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
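
Following the review thread above about GEMM vs. GEMV, here is a minimal sketch of how the backend linear class could be picked from the config's `version` field. It assumes `WQLinear_GEMV` is exposed alongside `WQLinear_GEMM` in autoawq, and the helper name `get_awq_linear_cls` is hypothetical, not part of this diff:

```python
# Hedged sketch: choose the AutoAWQ linear class from AWQConfig.version, as
# suggested in the review thread above. Assumes WQLinear_GEMV is available next
# to WQLinear_GEMM in autoawq; get_awq_linear_cls is a hypothetical helper.
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV


def get_awq_linear_cls(version: str):
    """Map the `version` field of the quantization config to a backend class."""
    version = version.upper()
    if version == "GEMM":
        return WQLinear_GEMM
    if version == "GEMV":
        return WQLinear_GEMV
    raise ValueError(f"Unsupported AWQ version: {version!r}, expected 'GEMM' or 'GEMV'")


# Illustrative use inside the replacement loop shown above:
# linear_cls = get_awq_linear_cls(quantization_config.version)
# model._modules[name] = linear_cls(
#     w_bit=quantization_config.w_bit,
#     group_size=quantization_config.q_group_size,
#     in_features=in_features,
#     out_features=out_features,
#     bias=module.bias is not None,
#     dev=module.weight.device,
# )
```
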
18 changes: 17 additions & 1 deletion src/transformers/modeling_utils.py
@@ -69,6 +69,7 @@
extract_commit_hash,
has_file,
is_accelerate_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
@@ -89,7 +90,7 @@
is_torch_fx_proxy,
is_torchdynamo_compiling,
)
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.quantization_config import AWQConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.versions import require_version_core


@@ -2787,6 +2788,12 @@ def from_pretrained(
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")

quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
elif quantization_method_from_config == QuantizationMethod.AWQ:
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run AWQ quantized model.")

if not is_auto_awq_available():
raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")

if (
is_8bit_serializable
@@ -3224,6 +3231,15 @@ def from_pretrained(
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.convert_model(model)
model._is_quantized_training_enabled = True
elif quantization_method_from_config == QuantizationMethod.AWQ:
from .integrations import replace_with_awq_linear

if quantization_config is None:
quantization_config = AWQConfig.from_dict(config.quantization_config)

model, _ = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]

Reviewer: The modules_to_not_convert=["lm_head"] may end up causing issues eventually if the head is not named "lm_head" exactly. The way we avoid this in AutoAWQ is by only looking at the decoder layers of the model, by calling the model's get_model_layers() function (e.g. llama).

younesbelkada (Contributor Author): Nice! For bnb we have the same issue and use this method: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py#L243 which should be quite generic for most transformers models. I planned to use that instead.

)

if quantization_method_from_config is not None:
model.quantization_method = quantization_method_from_config
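
For context, a minimal sketch of the end-user flow this from_pretrained branch enables, under the assumption that the autoawq package and a CUDA GPU are available (as the checks above require). The model ID is a placeholder, not part of this PR:

```python
# Hedged sketch of loading a checkpoint whose config.json carries an AWQ
# quantization_config. The model ID below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/Mistral-7B-v0.1-AWQ"  # any AWQ-quantized checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The AWQ branch above swaps nn.Linear modules for WQLinear layers (skipping the
# head) before the quantized weights are loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda:0",  # a GPU is required, as checked earlier in from_pretrained
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

On the lm_head discussion above, the generic helper the author references could plausibly be used as `modules_to_not_convert = get_keys_to_not_convert(model)`, though that is not part of this diff.
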
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -104,6 +104,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -107,6 +107,8 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -675,6 +677,10 @@ def is_optimum_available():
return _optimum_available


def is_auto_awq_available():
return _auto_awq_available


def is_auto_gptq_available():
return _auto_gptq_available

34 changes: 34 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -37,6 +37,7 @@
class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"


@dataclass
@@ -418,3 +419,36 @@ def post_init(self):
f"""dataset needs to be either a list of string or a value in
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
)


@dataclass
class AWQConfig(QuantizationConfigMixin):
"""
This is a wrapper class for all the attributes and features of a model that has been loaded with AWQ
quantization using the `auto-awq` library as the backend.

Args:
w_bit (`int`):
The number of bits to quantize to.
zero_point (`bool`, *optional*, defaults to `True`):
Whether to use zero point quantization.
q_group_size (`int`, *optional*, defaults to 128):
The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
version (`str`, *optional*, defaults to `GEMM`):
The version of the quantization algorithm to use.
"""

def __init__(
self,
w_bit: int,
q_group_size: int = 128,
zero_point: bool = True,
version: str = "GEMM",
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ

self.w_bit = w_bit
self.q_group_size = q_group_size
self.zero_point = zero_point
self.version = version
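
A short sketch of instantiating the AWQConfig defined at this point in the PR; the values are illustrative, and a later commit (963676e, "camel case + change backend class name") changes the naming, so the final public class name may differ:

```python
# Hedged sketch: constructing the AWQConfig introduced in this diff.
from transformers import AWQConfig

quantization_config = AWQConfig(
    w_bit=4,            # 4-bit weight quantization
    q_group_size=128,   # recommended group size; -1 means per-column quantization
    zero_point=True,    # use zero-point quantization
    version="GEMM",     # or "GEMV", matching how the checkpoint was packed
)
print(quantization_config.to_dict())
```
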