support quant_lm_head arg in all WOQ configs (#1881)
Signed-off-by: xin3he <xin3.he@intel.com>
xin3he authored Jun 27, 2024
1 parent cc763f5 commit 4ae2e87
Showing 19 changed files with 379 additions and 181 deletions.
32 changes: 26 additions & 6 deletions docs/3x/PT_WeightOnlyQuant.md
@@ -1,6 +1,7 @@

PyTorch Weight Only Quantization
===============

- [Introduction](#introduction)
- [Supported Matrix](#supported-matrix)
- [Usage](#usage)
@@ -28,7 +29,6 @@ Besides, as mentioned in many papers[1][2], activation quantization is the main

Theoretically, round-to-nearest (RTN) is the most straightforward way to quantize weight using scale maps. However, when the number of bits is small (e.g. 3), the MSE loss is larger than expected. A group size is introduced to reduce elements using the same scale to improve accuracy.


## Supported Matrix

| Algorithms/Backend | PyTorch eager mode |
@@ -58,25 +58,29 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz
WeightOnlyQuant for PyTorch uses the prepare and convert [APIs](./PyTorch.md#quantization-apis).

#### Common arguments

| Config | Capability |
|---|---|
| dtype (str)| ['int', 'nf4', 'fp4'] |
| bits (int)| [1, ..., 8] |
| group_size (int)| [-1, 1, ..., $C_{in}$] |
| use_sym (bool)| [True, False] |
| quant_lm_head (bool)| [False, True] |
| use_double_quant (bool) | [True, False] |
| double_quant_dtype (str) | ['int'] |
| double_quant_bits (int) | [1, ..., bits] |
| double_quant_use_sym (bool) | [True, False] |
| double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] |

Notes:

- *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) for instance, when *group_size = -1*, quantization calculates $C_{out}$ quantization parameters in total. Otherwise, when *group_size = gs*, quantization parameters are calculated for every $gs$ elements along the input channel, leading to $C_{out} \times (C_{in} / gs)$ quantization parameters in total (see the sketch after these notes for the arithmetic).
- 4-bit NormalFloat(NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.

- *quant_lm_head* defaults to False, which means that the last layer of transformer models, located outside the transformer blocks and usually named "lm_head", "output_layer" or "embed_out", is not quantized by default.
- Only RTN and GPTQ support double quant.
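
As a quick illustration of the parameter counting above and of the new *quant_lm_head* flag, here is a minimal sketch; the layer shape, group size, and config values below are assumptions for demonstration only.

``` python
# Illustrative sketch: how many scales a 4-bit quantized linear layer needs,
# and how quant_lm_head is switched on. Shapes and values are hypothetical.
from neural_compressor.torch.quantization import RTNConfig

C_in, C_out, gs = 4096, 11008, 128  # assumed linear layer shape and group size

per_channel_scales = C_out             # group_size = -1: one scale per output channel
grouped_scales = C_out * (C_in // gs)  # group_size = gs: C_out * (C_in / gs) scales

# quant_lm_head=True also quantizes the model's last layer (e.g. "lm_head").
quant_config = RTNConfig(bits=4, group_size=gs, use_sym=True, quant_lm_head=True)
```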

#### RTN

| rtn_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -86,6 +90,7 @@ Notes:
| model_path (str) | Model path that is used to load state_dict per layer | |

> **Notes:** `model_path` is only used when `use_layer_wise=True`. Layer-wise quantization support is still in progress (stay tuned).
``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -96,6 +101,7 @@ model = convert(model)
```

#### GPTQ

| gptq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| use_mse_search (bool) | Enables mean squared error (MSE) search | False |
@@ -107,6 +113,7 @@ model = convert(model)
| block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 |
| static_groups (bool) | Whether to calculate group-wise quantization parameters in advance. This option mitigates act_order's extra computational requirements. | False |
> **Note:** `model_path` is only used when `use_layer_wise=True`. Layer-wise quantization support is still in progress (stay tuned).
``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, GPTQConfig
@@ -118,6 +125,7 @@ model = convert(model)
```
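
Because the block above is collapsed in this diff, here is a hedged sketch of a fuller GPTQ setup that also exercises the new `quant_lm_head` argument; the model, calibration dataloader, and argument values are assumptions for illustration.

``` python
# Sketch under assumptions: GPTQConfig built from the arguments in the table above
# plus quant_lm_head; `model` and `calib_dataloader` are assumed user-provided.
from neural_compressor.torch.quantization import prepare, convert, GPTQConfig

quant_config = GPTQConfig(
    bits=4,
    group_size=128,
    use_mse_search=False,
    block_size=128,
    static_groups=False,
    quant_lm_head=True,  # also quantize the lm_head layer
)

model = prepare(model, quant_config)

# GPTQ is calibration-based: run a few batches through the prepared model.
for batch in calib_dataloader:
    model(**batch)

model = convert(model)
```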

#### AutoRound

| autoround_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| enable_full_range (bool) | Whether to enable full range quantization | False |
@@ -138,6 +146,7 @@ model = convert(model)
| not_use_best_mse (bool) | Whether to use mean squared error | False |
| dynamic_max_gap (int) | The dynamic maximum gap | -1 |
| scale_dtype (str) | The data type of quantization scale to be used, different kernels have different choices | "float16" |

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig
@@ -149,6 +158,7 @@ model = convert(model)
```

#### AWQ

| awq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -159,6 +169,7 @@ model = convert(model)
| use_auto_clip (bool) | Enables clip range search | True |
| folding (bool) | Allow inserting a mul op before linear layers when the scale cannot be absorbed by the last layer | False |
> **Notes:** Layer-wise quantization support is still in progress (stay tuned).
``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, AWQConfig
@@ -170,6 +181,7 @@ model = convert(model)
```

#### TEQ

| teq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -179,6 +191,7 @@ model = convert(model)
| use_double_quant (bool) | Enables double quantization | False |
| folding (bool) | Allow inserting a mul op before linear layers when the scale cannot be absorbed by the last layer | False |
> **Notes:** Layer-wise quantization support is still in progress (stay tuned).
``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, TEQConfig
@@ -190,12 +203,13 @@ model = convert(model)
```

#### HQQ

| hqq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| quant_zero (bool) | Whether to quantize zero point | True |
| quant_scale (bool) | Whether to quantize the scale | False |
| scale_quant_group_size (int) | The group size for quantizing the scale | 128 |
| skip_lm_head (bool) | Whether to skip quantizing lm_head | True |

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, HQQConfig
@@ -205,10 +219,13 @@ model = prepare(model, quant_config)
run_fn(model) # calibration
model = convert(model)
```

### Specify Quantization Rules

Intel(R) Neural Compressor supports specifying quantization rules by operator name or operator type. Users can set `local` in a dict or use the `set_local` method of a config class for this purpose.

1. Example of setting `local` from a dict

```python
quant_config = {
"rtn": {
@@ -226,15 +243,19 @@ quant_config = {
}
}
```

2. Example of using `set_local`

```python
quant_config = RTNConfig()
lm_head_config = RTNConfig(dtype="fp32")
quant_config.set_local("lm_head", lm_head_config)
```

### Saving and Loading

The saved_results folder contains two files: quantized_model.pt and qconfig.json; the generated model is a quantized model that contains WeightOnlyLinear modules. To support low-memory inference, Intel(R) Neural Compressor implements WeightOnlyLinear, a torch.nn.Module, to compress the fake-quantized fp32 model. Since torch does not provide flexible data type storage, WeightOnlyLinear packs low-bit data (weights and zero points) into a larger data type, such as torch.int8 or torch.int32. When WeightOnlyLinear is used for inference, it restores the compressed data to float32 and runs the torch linear function.

```python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -255,7 +276,6 @@ loaded_model = load(
) # Please note that the original_model parameter passes the original model.
```
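
The bit-packing idea described above can be seen in a standalone sketch; this is not the actual WeightOnlyLinear implementation, just an assumed demonstration of how eight 4-bit values fit into one 32-bit word.

```python
# Standalone illustration (not INC's WeightOnlyLinear code): packing eight 4-bit
# values into a single 32-bit word, which is how low-bit weights and zero points
# can be stored in a torch.int32 buffer.
qweight = [5, 12, 0, 7, 15, 3, 9, 1]  # eight 4-bit values (0..15)

packed = 0
for i, v in enumerate(qweight):
    packed |= (v & 0xF) << (4 * i)  # each value occupies its own nibble

# Unpacking restores the low-bit values before dequantizing back to float32.
unpacked = [(packed >> (4 * i)) & 0xF for i in range(8)]
assert unpacked == qweight
```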


## Examples

Users can also refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only) on how to quantize a model with WeightOnlyQuant.
@@ -272,6 +292,6 @@ Users can also refer to [examples](https://github.com/intel/neural-compressor/bl

[5]. Cheng, Wenhua, et al. "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs" arXiv preprint arXiv:2309.05516 (2023).

[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: https://mobiusml.github.io/hqq_blog/ (2023).
[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: <https://mobiusml.github.io/hqq_blog/> (2023).

[7]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
21 changes: 21 additions & 0 deletions docs/3x/PyTorch.md
@@ -223,3 +223,24 @@ def load(output_dir="./saved_results", model=None):
</tr>
</tbody>
</table>
2. How to set a different configuration for a specific op_name or op_type?
> INC provides the `set_local` method on the global configuration object to set a custom configuration for specific operators.
```python
def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
"""Set custom configuration based on the global configuration object.
Args:
operator_name_or_list (Union[List, str, Callable]): specific operator
config (BaseConfig): specific configuration
"""
```

> Demo:

```python
quant_config = RTNConfig() # Initialize global configuration with default bits=4
quant_config.set_local(".*mlp.*", RTNConfig(bits=8)) # For layers with "mlp" in their names, set bits=8
quant_config.set_local("Conv1d", RTNConfig(dtype="fp32")) # For Conv1d layers, do not quantize them.
```
@@ -272,15 +272,12 @@ def get_user_model():
def run_fn_for_gptq(model, dataloader_for_calibration, *args):
for batch in tqdm(dataloader_for_calibration):
batch = move_input_to_device(batch, device=None)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
model(batch[0])
elif isinstance(batch, dict):
model(**batch)
else:
model(batch)
except ValueError:
pass
if isinstance(batch, tuple) or isinstance(batch, list):
model(batch[0])
elif isinstance(batch, dict):
model(**batch)
else:
model(batch)
return
if args.double_quant_type is not None:
double_quant_config_dict.update(
23 changes: 19 additions & 4 deletions neural_compressor/common/base_config.py
@@ -198,10 +198,25 @@ def local_config(self):
def local_config(self, config):
self._local_config = config

def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
if operator_name in self.local_config:
logger.warning("The configuration for %s has already been set, update it.", operator_name)
self.local_config[operator_name] = config
def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
"""Set custom configuration based on the global configuration object.
Args:
operator_name_or_list (Union[List, str, Callable]): specific operator
config (BaseConfig): specific configuration
Returns:
Updated Config
"""
if isinstance(operator_name_or_list, list):
for operator_name in operator_name_or_list:
if operator_name in self.local_config:
logger.warning("The configuration for %s has already been set, update it.", operator_name)
self.local_config[operator_name] = config
else:
if operator_name_or_list in self.local_config:
logger.warning("The configuration for %s has already been set, update it.", operator_name)
self.local_config[operator_name_or_list] = config
return self

def to_dict(self):
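
As a usage note for the updated `set_local` above: the list form applies one configuration to several operators in a single call. A minimal sketch, assuming `RTNConfig` and hypothetical layer names:

```python
# Minimal sketch of the list-accepting set_local; layer names are hypothetical.
from neural_compressor.torch.quantization import RTNConfig

quant_config = RTNConfig(bits=4)
fp32_config = RTNConfig(dtype="fp32")

# One call now covers several operators instead of looping over set_local.
quant_config.set_local(["lm_head", "embed_out"], fp32_config)
```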
13 changes: 13 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -345,6 +345,18 @@ def forward(layer, *args, **kwargs):
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
)
# Step 3: replace model_forward to avoid ValueError
self.orig_model_forward_cache = self.model.forward
model_forward_cache = self.model.forward

def model_forward(model, *args, **kwargs):
nonlocal model_forward_cache
try:
model_forward_cache(*args, **kwargs)
except ValueError:
pass

self.model.forward = partial(model_forward, self.model)

@torch.no_grad()
def remove_prepare_for_calibration(self):
@@ -359,6 +371,7 @@ def remove_prepare_for_calibration(self):
logger.info("Done.")

# Step 4: restore original forward function, relocate layers back to cpu.
self.model.forward = self.orig_model_forward_cache
self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
if not self.use_layer_wise: # pragma: no cover
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
@@ -119,7 +119,8 @@ def save(self, model, path):
pass

def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
# * 3.x API use `bits` for woq while HQQ internal API use `nbits`
# TODO: (Yi) Please note that the configuration defined by INC should be separated from the algorithm.
# * 3.x API use `bits` for woq while HQQ internal API use `nbits`, we should change it in algorithm_entry.py
nbits = config.bits
group_size = config.group_size
quant_zero = config.quant_zero
@@ -146,9 +147,6 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
def _parse_hqq_configs_mapping(self, configs_mapping):
qconfig_mapping = {}
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.skip_lm_head and "lm_head" in op_name:
logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name)
continue
if quant_config is not None and quant_config.dtype == "fp32":
logger.warning("Fallback %s.", op_name)
continue
21 changes: 19 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -19,12 +19,20 @@
# limitations under the License.


import copy
from collections import OrderedDict

import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils import (
get_accelerator,
get_attr,
is_transformers_imported,
logger,
set_attr,
set_module,
)

from .utility import cast_fp8, quant_tensor, search_clip

@@ -64,6 +72,7 @@ def convert(
quantile=1.0,
use_full_range=False,
use_mse_search=False,
quant_lm_head=False,
*args,
**kwargs,
):
@@ -80,8 +89,10 @@ def convert(
quantile (float, optional): percentile of clip. Defaults to 1.0.
use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
Defaults to False.
use_mse_search (bool, optional): Whether search clip range.
use_mse_search (bool, optional): Whether to search clip range.
Defaults to True.
quant_lm_head (bool, optional): Whether to quantize the lm_head layer.
Defaults to False.
Returns:
model: fake quantized torch module
@@ -93,6 +104,12 @@ def convert(
# TODO: refine it later, Put module on device one by one instead of the whole model
model.to(device)

# for transformers model. If lm_head is tied from embedding, we deepcopy it.
if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
for key in model._tied_weights_keys:
weight = get_attr(model, key)
set_attr(model, key, copy.deepcopy(weight))

assert isinstance(model, torch.nn.Module), "only support torch module"
if is_transformers_imported():
supported_layers = (torch.nn.Linear, transformers.Conv1D)
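
The tied-embedding handling added to rtn.py above deep-copies the shared weight before quantization. The sketch below shows the same untying idea in isolation, using the standard transformers attributes (`config.tie_word_embeddings`, `_tied_weights_keys`) with plain getattr/setattr in place of INC's get_attr/set_attr helpers; the model name is an assumed example.

```python
# Illustration of the untying step: when lm_head shares its weight with the input
# embedding, quantizing it in place would also alter the embedding, so the shared
# tensor is deep-copied first. Everything here is illustrative, not INC code.
import copy
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")  # assumed example

if getattr(model.config, "tie_word_embeddings", False):
    for key in model._tied_weights_keys:  # e.g. ["lm_head.weight"]
        parent = model
        *path, leaf = key.split(".")
        for name in path:
            parent = getattr(parent, name)
        setattr(parent, leaf, copy.deepcopy(getattr(parent, leaf)))  # break the tie
```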
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -92,7 +92,7 @@ def rtn_entry(
}

quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
model = quantizer.execute(model, mode=mode)
model = quantizer.execute(model, mode=mode, quant_lm_head=quant_config.quant_lm_head)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
postprocess_model(model, mode, quantizer)