Add save/load support for HQQ #1913

Merged · 7 commits · Jul 15, 2024
60 changes: 59 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/hqq/core.py
@@ -19,7 +19,7 @@
# NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle`
# and `QTensor` to decouple the data structure and the quantization logic.

from typing import Any, Dict, Tuple
from typing import Any, Dict, Mapping, Tuple

import torch

@@ -278,3 +278,61 @@ def from_float(
# !!! Delete the float explicitly to save memory
del float_module
return new_mod

def state_dict(self, *args, **kwargs): # nn.Module override compatible
state_dict = self.q_weight.to_state_dict()
if self.bias is not None:
state_dict["bias"] = self.bias
if "destination" in kwargs and "prefix" in kwargs:
for key, value in state_dict.items():
kwargs["destination"][kwargs["prefix"] + key] = value
return state_dict

def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
all_expected_keys = ["val", "scale_quantized", "zero_quantized", "meta_info"]
if self.bias is not None:
all_expected_keys.append("bias")

for key in all_expected_keys:
if prefix + key not in state_dict:
missing_keys.append(key)
if missing_keys:
return # Can't load weights if either weight or meta is missing

cur_state_dict = {}
for key in all_expected_keys:
cur_state_dict[key] = state_dict.pop(prefix + key)

unexpected_keys += state_dict.keys()
self._assign_state_dict(cur_state_dict, strict)

def _assign_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
_scale_quantized = state_dict["scale_quantized"]
_zero_quantized = state_dict["zero_quantized"]
scale_state = state_dict["meta_info"]["scale"]
zero_state = state_dict["meta_info"]["zero"]
if _scale_quantized:
scale = HQQTensorHandle._create_q_tensor(scale_state["val"], scale_state["meta_info"])
else:
scale = state_dict["meta_info"]["scale"]
if _zero_quantized:
zero = HQQTensorHandle._create_q_tensor(zero_state["val"], zero_state["meta_info"])
else:
zero = state_dict["meta_info"]["zero"]
meta = state_dict["meta_info"]
meta["scale"] = scale
meta["zero"] = zero
self.q_weight = HQQTensorHandle._create_q_tensor(state_dict["val"], meta)
if self.bias is not None:
self.bias = state_dict["bias"]
self.quantized = True
return self
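
Note: with `state_dict` and `_load_from_state_dict` overridden above, a quantized `HQQLinear` round-trips through PyTorch's standard serialization. A minimal sketch modeled on `test_hqq_linear_save_and_load` further down; the config values and file path are illustrative:

from copy import deepcopy

import torch

from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear

hqq_global_option.use_half = False

# Quantize a float Linear (scale/zero left un-quantized), then round-trip its state dict.
quant_config = HQQModuleConfig(
    weight=QTensorConfig(nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True),
    scale=None,
    zero=None,
)
float_linear = torch.nn.Linear(in_features=64, out_features=128)
hqq_linear = HQQLinear.from_float(deepcopy(float_linear), quant_config=quant_config)

torch.save(hqq_linear.state_dict(), "hqq_linear.pth")  # illustrative path

restored = HQQLinear.from_float(torch.nn.Linear(64, 128), quant_config=quant_config)
restored.load_state_dict(torch.load("hqq_linear.pth"))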
16 changes: 16 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py
@@ -115,3 +115,19 @@ def half(self):
if self.zero is not None:
self.zero = self.zero.half()
return self

def to_state_dict(self):
state = {}
state["val"] = self.val
state["meta_info"] = self.meta_info.to_dict()
state["scale_quantized"] = self.is_scale_quantized()
state["zero_quantized"] = self.is_zero_quantized()
if self.is_scale_quantized():
state["meta_info"]["scale"] = self.scale.to_state_dict()
else:
state["meta_info"]["scale"] = self.scale
if self.is_zero_quantized():
state["meta_info"]["zero"] = self.zero.to_state_dict()
else:
state["meta_info"]["zero"] = self.zero
return state
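
Note: `to_state_dict` returns a plain nested dict rather than registered buffers; when the scale (or zero) is itself quantized, its entry is the nested state dict produced by that QTensor's own `to_state_dict`. A schematic of the layout only, with tensor payloads replaced by `None` and the remaining `meta_info` fields elided:

# Schematic only: payloads replaced by None, other QTensorMetaInfo fields omitted.
example_state = {
    "val": None,              # packed quantized weight tensor
    "scale_quantized": True,  # True -> "scale" below is a nested QTensor state dict
    "zero_quantized": False,  # False -> "zero" below is a plain tensor
    "meta_info": {
        "scale": {"val": None, "meta_info": {}},  # nested state dict for the quantized scale
        "zero": None,                             # plain tensor for the un-quantized zero
    },
}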
29 changes: 28 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -124,7 +124,6 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path):

with open(qconfig_file_path, "r") as file:
self.quantization_config = json.load(file)

model = self._build_woq_model()
model.load_state_dict(qweights, assign=True)
model.eval()
@@ -157,8 +156,19 @@ def load_hf_format_woq_model(self):

return model

def _is_hqq_model(self):
for name, module in self.original_model.named_modules():
pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))"
for q_config_key, q_config_value in self.quantization_config.items():
if re.search(pattern, q_config_key):
if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "hqq":
return True

def _build_woq_model(self):
"""Build weight-only quantization model."""
if self._is_hqq_model():
return self._build_hqq_model()

from neural_compressor.torch.utils import set_module

from .modules import MulLinear
@@ -228,6 +238,23 @@ def _build_woq_model(self):
woq_model = self.original_model
return woq_model

def _build_hqq_model(self):
"""Replace quantized Linear with HQQLinear."""
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
from neural_compressor.torch.utils import set_module

for name, module in self.original_model.named_modules():
if isinstance(module, torch.nn.Linear):
loaded_state_dict_keys_set = set(self.loaded_state_dict_keys)
if name + ".val" not in loaded_state_dict_keys_set:
continue
new_module = HQQLinear(
in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None
)
set_module(self.original_model, name, new_module)
woq_model = self.original_model
return woq_model

def _get_model_class_and_config(self):
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
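
Note: the loader decides per model — `_is_hqq_model` matches each module's `(name, class)` pair against the keys of the saved quantization config, and `_build_hqq_model` only swaps in an `HQQLinear` when the checkpoint actually contains `<name>.val`. A small sketch of the regex match; the config-key string is a made-up example, since the exact key format in the qconfig JSON is not shown in this diff:

import re

name, cls_name = "model.decoder.layers.0.fc1", "Linear"
q_config_key = "('model.decoder.layers.0.fc1', 'Linear')"  # hypothetical key format

pattern = rf"(\(.*{re.escape(name)}.*{re.escape(cls_name)}.*\))"
print(bool(re.search(pattern, q_config_key)))  # True: the key contains "(<name> ... <class>)"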
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -517,11 +517,14 @@ def hqq_entry(
**kwargs,
) -> torch.nn.Module:
from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
from neural_compressor.torch.algorithms.weight_only.save_load import save

logger.info("Quantize model with the HQQ algorithm.")

quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
model = quantizer.execute(model, mode=mode)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
postprocess_model(model, mode, quantizer)
dump_model_op_stats(mode, configs_mapping)

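
Note: after this change the quantized model carries a bound `save` method, so HQQ results can be persisted the same way as the other weight-only algorithms. A short usage sketch mirroring `test_hqq_load_save` below; `fp32_model` and the workspace path are placeholders:

from neural_compressor.common import options
from neural_compressor.torch.quantization import convert, get_default_hqq_config, prepare

quant_config = get_default_hqq_config()
q_model = convert(prepare(fp32_model, quant_config))  # fp32_model: any float torch.nn.Module

save_path = options.workspace + "/hqq_model.pth"      # path pattern follows the test below
q_model.save(save_path)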
5 changes: 4 additions & 1 deletion neural_compressor/torch/quantization/load_entry.py
@@ -22,6 +22,7 @@
AWQConfig,
FP8Config,
GPTQConfig,
HQQConfig,
RTNConfig,
TEQConfig,
)
@@ -89,7 +90,9 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
# select load function
config_object = config_mapping[next(iter(config_mapping))]

if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ
if isinstance(
config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig, HQQConfig)
): # WOQ
from neural_compressor.torch.algorithms import weight_only

return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT)
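
Note: with `HQQConfig` added to the WOQ branch, `load` dispatches an HQQ checkpoint to the weight-only loader, which rebuilds the `HQQLinear` modules and fills them via `load_state_dict(..., assign=True)`. A usage sketch mirroring the test; `save_path`, `fp32_model`, and `example_inputs` are placeholders from the save step above:

from copy import deepcopy

from neural_compressor.torch.quantization import load

loaded_model = load(save_path, deepcopy(fp32_model))  # the original float model supplies the architecture
out = loaded_model(example_inputs)[0]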
78 changes: 78 additions & 0 deletions test/3x/torch/quantization/weight_only/test_hqq.py
@@ -1,11 +1,14 @@
import copy
import os
import time
from copy import deepcopy

import pytest
import torch
import transformers
from transformers import AutoModelForCausalLM

from neural_compressor.common import options
from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
@@ -93,6 +96,27 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
q_label_1.eq(q_label_2)
), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."

def test_hqq_load_save(self, force_use_cpu, force_not_half):

hqq_global_option.use_half = False
fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
# test_default_config
quant_config = get_default_hqq_config()

# prepare + convert API
model = prepare(deepcopy(fp32_model), quant_config)
qmodel = convert(model)
qmodel_out_ref = model(example_inputs)[0]
save_path = options.workspace + f"/_hqq_model_{time.time()}.pth"
qmodel.save(save_path)
from neural_compressor.torch.quantization import load

# loading compressed model
loaded_model = load(save_path, copy.deepcopy(fp32_model))
loaded_model_out = loaded_model(example_inputs)[0]
assert torch.allclose(qmodel_out_ref, loaded_model_out), "Unexpected result. Please double check."

def test_hqq_fallback(self, force_use_cpu, force_not_half):

class ToyModel(torch.nn.Module):
@@ -181,3 +205,57 @@ def test_hqq_module(
scale_quant_group_size=scale_quant_group_size,
device=torch.device(device_name),
)

@pytest.mark.parametrize(
"nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
[
(4, 64, True, False, 128),
(4, 64, False, False, 128),
(4, 64, True, True, 128),
(4, 64, False, True, 128),
(8, 64, True, False, 128),
],
)
def test_hqq_linear_save_and_load(
self,
nbits,
group_size,
quant_zero,
quant_scale,
scale_quant_group_size,
):
hqq_global_option.use_half = False
# Parse config
weight_qconfig = QTensorConfig(
nbits=nbits,
channel_wise=True,
group_size=group_size,
optimize=True,
round_zero=True if nbits == 4 else False,
)
zero_qconfig = None
if quant_zero:
zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False)
scale_qconfig = None
if quant_scale:
scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False)
hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
# Create HQQ Linear
bs = 4
in_features = 64
out_features = 128
float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features)
float_linear.to(device)
float_linear_copy = deepcopy(float_linear)
input = torch.randn(bs, in_features, device=device)
hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config)
out_ref = hqq_linear(input)
state_dict = hqq_linear.state_dict()
hqq_module_path = options.workspace + f"/_hqq_linear_{time.time()}.pth"
torch.save(state_dict, hqq_module_path)
reload_state_dict = torch.load(hqq_module_path)
new_float = torch.nn.Linear(in_features=in_features, out_features=out_features)
new_hqq_linear = HQQLinear.from_float(new_float, quant_config=hqq_quant_config)
new_hqq_linear.load_state_dict(reload_state_dict)
out = new_hqq_linear(input)
assert torch.equal(out_ref, out), f"out_ref: {out_ref}, out: {out}"