Support absorb dict for awq (#1920)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel authored Jul 16, 2024
1 parent e976595 commit de43d85
Showing 4 changed files with 113 additions and 30 deletions.
54 changes: 46 additions & 8 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -89,6 +89,36 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}
return block_absorb_dict, absorb_layer_dict


def _get_absorb_dict(model, absorb_layer_dict):
    """Get absorbed layers per block from the absorbed layer dict.

    Args:
        model (torch.nn.Module): input model
        absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.

    Returns:
        block_absorb_dict: dict of absorbed layers per block, e.g. {0: [(absorbed_1, xx), (xx,)], ...}
        new_absorb_layer_dict: the same mapping keyed by the per-block prefixed layer names.
    """
    block_absorb_dict = {}
    block_prefix, block_num = get_block_prefix(model)
    new_absorb_layer_dict = {}
    for i in range(block_num):
        block_absorb_dict[i] = []
        block_name = block_prefix + "." + str(i) + "."
        for k, v in absorb_layer_dict.items():
            if isinstance(v, str):
                name_list = (block_name + v,)
            else:
                name_list = tuple(block_name + vv for vv in v)
            block_absorb_dict[i].append(name_list)
            new_absorb_layer_dict[name_list] = block_name + k
    logger.debug(f"The absorbed layers per block: {block_absorb_dict}")
    logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}")
    return block_absorb_dict, new_absorb_layer_dict
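As a minimal illustration (not part of this commit), here is what _get_absorb_dict would produce for one block, assuming a hypothetical GPT-J-style model for which get_block_prefix(model) returns ("transformer.h", 28):

absorb_layer_dict = {
    "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in"),
    "attn.out_proj": "attn.out_proj",  # self-absorption
}

# For block 0 the function appends one tuple of prefixed names per entry:
# block_absorb_dict[0] == [
#     ("transformer.h.0.attn.q_proj", "transformer.h.0.attn.k_proj",
#      "transformer.h.0.attn.v_proj", "transformer.h.0.mlp.fc_in"),
#     ("transformer.h.0.attn.out_proj",),
# ]
# and the inverted mapping points each tuple back at its absorbing layer:
# new_absorb_layer_dict[("transformer.h.0.attn.out_proj",)] == "transformer.h.0.attn.out_proj"
# new_absorb_layer_dict[(..., "transformer.h.0.mlp.fc_in")] == "transformer.h.0.ln_1"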


@torch.no_grad()
def _get_weight_scale(weight, q_group_size=-1):
org_shape = weight.shape
@@ -123,6 +153,7 @@ def __init__(
total_block_args=[],
total_block_kwargs=[],
device="auto",
absorb_layer_dict={},
):

self.example_inputs = example_inputs
@@ -140,6 +171,7 @@ def __init__(
self.scheme = scheme
self.use_full_range = use_full_range
self.weight_config = weight_config
self.absorb_layer_dict = absorb_layer_dict

def _move_model_and_data_to_device(self):
# Put the model and example_inputs into target device
@@ -164,13 +196,16 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu
# Step 1: get absorbed module list per block, includes self-absorption
# block_absorb_dict is split per block, includes all absorb relationship.
# absorb_layer_dict is the inverse of block_absorb_dict for all blocks
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
self.model,
self.example_inputs,
# for only use_mse_search, folding is useless.
folding=folding if use_auto_scale else False,
weight_config=self.weight_config,
)
if not self.absorb_layer_dict:
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
self.model,
self.example_inputs,
# for only use_mse_search, folding is useless.
folding=folding if use_auto_scale else False,
weight_config=self.weight_config,
)
else:
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_dict(self.model, self.absorb_layer_dict)
# process per block
for i, module_list in self.block_absorb_dict.items():
logger.info(f"Processing block: {i+1}/{self.block_num}")
@@ -491,13 +526,15 @@ def module_inference(self, model, inputs):


class AWQQuantizer(Quantizer):
def __init__(self, quant_config: OrderedDict = {}):
def __init__(self, quant_config: OrderedDict = {}, absorb_layer_dict: dict = {}):
"""Init an AWQQuantizer object.
Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
"""
super().__init__(quant_config)
self.absorb_layer_dict = absorb_layer_dict

@torch.no_grad()
def prepare(self, model, *args, **kwargs):
@@ -566,6 +603,7 @@ def convert(
weight_config=self.quant_config,
total_block_args=total_block_args,
total_block_kwargs=total_block_kwargs,
absorb_layer_dict=self.absorb_layer_dict,
)
qdq_model = awq.quantize(
use_auto_scale=use_auto_scale,
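A rough usage sketch (not from this diff) of driving the quantizer with a pre-computed absorb mapping instead of letting it trace the graph; weight_config and absorb_layer_dict below are illustrative placeholders:

from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer

weight_config = {}  # fill with per-op settings as built in awq_quantize_entry (dtype, bits, group_size, ...)
absorb_layer_dict = {"ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in")}

quantizer = AWQQuantizer(quant_config=weight_config, absorb_layer_dict=absorb_layer_dict)
# convert() forwards absorb_layer_dict into the block-wise AWQ object, so quantize()
# takes the _get_absorb_dict branch and skips _get_absorb_per_block graph tracing.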
45 changes: 24 additions & 21 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -323,10 +323,10 @@ def awq_quantize_entry(
from neural_compressor.torch.algorithms.weight_only.save_load import save

weight_config = {}
for (op_name, op_type), op_config in configs_mapping.items():
if op_config.name != AWQ:
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.name != AWQ:
continue
if op_config.dtype == "fp32":
if quant_config.dtype == "fp32":
weight_config[op_name] = {
"bits": -1,
"dtype": "fp32", # skip quantization
@@ -335,31 +335,34 @@
}
else:
weight_config[op_name] = {
"dtype": op_config.dtype,
"bits": op_config.bits,
"group_size": op_config.group_size,
"group_dim": op_config.group_dim,
"scheme": "sym" if op_config.use_sym else "asym",
"use_full_range": op_config.use_full_range,
"use_mse_search": op_config.use_mse_search,
"use_layer_wise": op_config.use_layer_wise,
"use_double_quant": op_config.use_double_quant,
"double_quant_dtype": op_config.double_quant_dtype,
"double_quant_bits": op_config.double_quant_bits,
"double_quant_scheme": op_config.double_quant_use_sym,
"double_quant_group_size": op_config.double_quant_group_size,
"dtype": quant_config.dtype,
"bits": quant_config.bits,
"group_size": quant_config.group_size,
"group_dim": quant_config.group_dim,
"scheme": "sym" if quant_config.use_sym else "asym",
"use_full_range": quant_config.use_full_range,
"use_mse_search": quant_config.use_mse_search,
"use_layer_wise": quant_config.use_layer_wise,
"use_double_quant": quant_config.use_double_quant,
"double_quant_dtype": quant_config.double_quant_dtype,
"double_quant_bits": quant_config.double_quant_bits,
"double_quant_scheme": quant_config.double_quant_use_sym,
"double_quant_group_size": quant_config.double_quant_group_size,
}
use_auto_scale = op_config.use_auto_scale
use_mse_search = op_config.use_auto_clip # for awq clip
folding = op_config.folding
use_full_range = op_config.use_full_range
use_auto_scale = quant_config.use_auto_scale
use_mse_search = quant_config.use_auto_clip # for awq clip
folding = quant_config.folding
use_full_range = quant_config.use_full_range
absorb_layer_dict = quant_config.absorb_layer_dict

run_fn = kwargs.get("run_fn", None)
run_args = kwargs.get("run_args", None)
example_inputs = kwargs.get("example_inputs", None)
assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
quantizer = get_quantizer(
model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_layer_dict=absorb_layer_dict
)
model = quantizer.execute(
model,
mode=mode,
6 changes: 5 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -442,6 +442,7 @@ class AWQConfig(TorchBaseConfig):
"use_auto_scale",
"use_auto_clip",
"folding",
"absorb_layer_dict",
]
name = AWQ

@@ -468,6 +469,7 @@ def __init__(
use_auto_clip: bool = True,
folding: bool = False,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
absorb_layer_dict: dict = {},
):
"""Init AWQ weight-only quantization config.
@@ -490,6 +492,7 @@ def __init__(
use_auto_clip (bool): Enables clip range search. Defaults to True.
folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
default is False.
absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
"""
super().__init__(white_list=white_list)
self.dtype = dtype
@@ -510,6 +513,7 @@ def __init__(
self.use_auto_scale = use_auto_scale
self.use_auto_clip = use_auto_clip
self.folding = folding
self.absorb_layer_dict = absorb_layer_dict
self._post_init()

@classmethod
@@ -626,7 +630,7 @@ def __init__(
double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
double_quant_group_size (int): Size of double_quant groups, default is 32.
quant_lm_head (bool): Indicates whether to quantize the lm_head layer in transformers. Default is False.
absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}.
absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}.
folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
default is False.
"""
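From the user's side, the new field just rides on AWQConfig: the dict maps an absorbing layer (a LayerNorm, or a layer mapped to itself for self-absorption) to the layer(s) whose input scales get folded into it, and _get_absorb_dict adds the per-block prefixes automatically. A minimal sketch, with import paths assumed from the 3.x PyTorch API and module names matching the GPT-J test added below; fp32_model and example_inputs are placeholders:

from neural_compressor.torch.quantization import AWQConfig, convert, prepare  # assumed import path

quant_config = AWQConfig(
    absorb_layer_dict={
        "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in"),
        "attn.out_proj": "attn.out_proj",
    }
)
model = prepare(model=fp32_model, quant_config=quant_config, example_inputs=example_inputs)
# ... run calibration on the prepared model ...
model = convert(model)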
38 changes: 38 additions & 0 deletions test/3x/torch/quantization/weight_only/test_awq.py
@@ -157,3 +157,41 @@ def test_quant_lm_head(self):
        assert (
            id(model.model.decoder.embed_tokens.weight) == lm_head_id
        ), "The tied lm_head weight is not deep copied, please check!"

    def test_awq_absorb_to_layer(self):
        absorb_layer_dict = {
            "ln_1": (
                "attn.q_proj",
                "attn.k_proj",
                "attn.v_proj",
                "mlp.fc_in",
            ),
            "attn.out_proj": "attn.out_proj",
            "mlp.fc_out": ("mlp.fc_out"),
        }

        quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
        logger.info(f"Test AWQ with config {quant_config}")
        # prepare + convert API
        model = prepare(
            model=copy.deepcopy(self.tiny_gptj),
            quant_config=quant_config,
            example_inputs=self.example_inputs,
        )
        calib_func(model)
        model = convert(model)
        out1 = model(self.example_inputs)
        quant_config = AWQConfig()
        logger.info(f"Test AWQ with config {quant_config}")

        # prepare + convert API
        model = prepare(
            model=copy.deepcopy(self.tiny_gptj),
            quant_config=quant_config,
            example_inputs=self.example_inputs,
        )
        calib_func(model)
        model = convert(model)
        out2 = model(self.example_inputs)

        assert torch.all(out1[0].eq(out2[0])), "The results should be equal."
