
Commit

Integrate AutoRound v0.3 (#1925)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel authored Jul 17, 2024
1 parent 5767aed commit bfa27e4
Showing 5 changed files with 130 additions and 68 deletions.
134 changes: 86 additions & 48 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -31,69 +31,95 @@ class AutoRoundQuantizer(Quantizer):
def __init__(
self,
quant_config: dict = {},
enable_full_range: bool = False,
enable_full_range: bool = False, ##for symmetric, TODO support later
batch_size: int = 8,
amp: bool = True,
device=None,
device: str = None,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
low_gpu_mem_usage: bool = False,
iters: int = 200,
seqlen: int = 2048,
n_samples: int = 512,
nsamples: int = 128,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
nblocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
data_type: str = "int",
scale_dtype: str = "fp16",
multimodal: bool = False,
act_bits: int = 32,
act_group_size: int = None,
act_sym: bool = None,
act_dynamic: bool = True,
low_cpu_mem_usage: bool = False,
**kwargs,
):
"""Init a AutQRoundQuantizer object.
Args:
quant_config (dict): Configuration for weight quantization (default is None).
quant_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
quant_config (dict): Configuration for weight quantization (default is None).
quant_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
'act_data_type': None,
'act_bits': 32,
'act_sym': None,
'act_dynamic': True,
}
...,
}
...
}
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
device: The device to be used for tuning (default is None). Automatically detect and set.
lr_scheduler: The learning rate scheduler to be used.
use_quant_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether symmetric quantization is to be used (default is False).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for tuning (default is "auto").
lr_scheduler: The learning rate scheduler to be used.
dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
enable_quanted_input (bool): Whether to use the output of the previous quantized block as
the input for the current block (default is True).
enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
lr (float): The learning rate (default is None, will be set to 1.0/iters).
minmax_lr (float): The learning rate for min-max tuning
(default is None, it will be set to lr automatically).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Data length of the sequence for tuning (default is 2048).
nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
multimodal (bool): Whether to enable multimodal model quantization (default is False).
act_bits (int): Number of bits for activation quantization. Default is 32.
act_group_size (int): Group size for activation quantization. Default is None.
act_sym (bool): Whether to use symmetric activation quantization. Default is None.
act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
Returns:
The quantized model.
"""
super().__init__(quant_config)
self.tokenizer = None
@@ -109,15 +135,21 @@ def __init__(
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
self.n_samples = n_samples
self.nsamples = nsamples
self.sampler = sampler
self.seed = seed
self.n_blocks = n_blocks
self.nblocks = nblocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = data_type
self.scale_dtype = scale_dtype
self.multimodal = multimodal
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.low_cpu_mem_usage = low_cpu_mem_usage

def prepare(self, model: torch.nn.Module, *args, **kwargs):
"""Prepares a given model for quantization.
@@ -137,7 +169,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
model=model,
tokenizer=None,
dataset=dataloader,
weight_config=self.quant_config or {},
layer_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
@@ -150,23 +182,29 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
low_gpu_mem_usage=self.low_gpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
n_samples=self.n_samples,
nsamples=self.nsamples,
sampler=self.sampler,
seed=self.seed,
n_blocks=self.n_blocks,
nblocks=self.nblocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
data_type=self.data_type,
scale_dtype=self.scale_dtype,
multimodal=self.multimodal,
act_bits=self.act_bits,
act_group_size=self.act_group_size,
act_sym=self.act_sym,
act_dynamic=self.act_dynamic,
low_cpu_mem_usage=self.low_cpu_mem_usage,
)
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
return model


def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128):
"""Generate a DataLoader for calibration using specified parameters.
Args:
@@ -186,6 +224,6 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

dataloader = get_dataloader(
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
)
return dataloader
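
For callers updating to this signature, the sketch below shows how the renamed keyword arguments (nsamples, nblocks) and the new activation-quantization options might be passed. It is a minimal sketch, not part of this commit: the model, tokenizer, and layer name are placeholder assumptions.

# Illustrative sketch only; assumes `transformers` and `auto_round` are installed
# and uses "facebook/opt-125m" purely as a placeholder model.
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import (
    AutoRoundQuantizer,
    get_dataloader,
)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Calibration dataloader: the keyword is now `nsamples` (was `n_samples`).
dataloader = get_dataloader(tokenizer, seqlen=2048, seed=42, bs=8, nsamples=128)

# Per-layer weight settings; the layer name below is a placeholder.
quant_config = {
    "model.decoder.layers.0.self_attn.q_proj": {
        "data_type": "int",
        "bits": 4,
        "group_size": 32,
        "sym": False,
    }
}

quantizer = AutoRoundQuantizer(
    quant_config=quant_config,
    nsamples=128,     # renamed from n_samples
    nblocks=1,        # renamed from n_blocks
    seqlen=2048,
    iters=200,
    act_bits=32,      # new activation-quantization arguments
    act_dynamic=True,
)
# The quantizer is then driven through the usual prepare()/convert() flow,
# as wired up in algorithm_entry.py below.
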
16 changes: 12 additions & 4 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -572,6 +572,10 @@ def autoround_quantize_entry(
"bits": quant_config.bits,
"sym": quant_config.use_sym,
"group_size": quant_config.group_size,
"act_bits": quant_config.act_bits,
"act_group_size": quant_config.act_group_size,
"act_sym": quant_config.act_sym,
"act_dynamic": quant_config.act_dynamic,
}
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
@@ -583,14 +587,16 @@
low_gpu_mem_usage = quant_config.low_gpu_mem_usage
iters = quant_config.iters
seqlen = quant_config.seqlen
n_samples = quant_config.n_samples
nsamples = quant_config.nsamples
sampler = quant_config.sampler
seed = quant_config.seed
n_blocks = quant_config.n_blocks
nblocks = quant_config.nblocks
gradient_accumulate_steps = quant_config.gradient_accumulate_steps
not_use_best_mse = quant_config.not_use_best_mse
dynamic_max_gap = quant_config.dynamic_max_gap
scale_dtype = quant_config.scale_dtype
multimodal = quant_config.multimodal
low_cpu_mem_usage = quant_config.use_layer_wise

kwargs.pop("example_inputs")

@@ -608,14 +614,16 @@
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
n_samples=n_samples,
nsamples=nsamples,
sampler=sampler,
seed=seed,
n_blocks=n_blocks,
nblocks=nblocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
scale_dtype=scale_dtype,
multimodal=multimodal,
low_cpu_mem_usage=low_cpu_mem_usage,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
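
To make the new mapping concrete, here is a hedged sketch of the per-op settings the entry now collects from an AutoRoundConfig; only the keys visible in the hunk above are shown, and the example values simply mirror the config defaults.

# Sketch of the settings forwarded per op (keys taken from the hunk above;
# the surrounding context of the dict is not shown in this diff).
op_settings = {
    "bits": 4,             # quant_config.bits
    "sym": False,          # quant_config.use_sym
    "group_size": 128,     # quant_config.group_size
    # New in this commit: activation-quantization settings.
    "act_bits": 32,        # quant_config.act_bits
    "act_group_size": None,
    "act_sym": None,
    "act_dynamic": True,
}
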
36 changes: 26 additions & 10 deletions neural_compressor/torch/quantization/config.py
@@ -735,8 +735,8 @@ class AutoRoundConfig(TorchBaseConfig):
"minmax_lr",
"iters",
"seqlen",
"n_samples",
"n_blocks",
"nsamples",
"nblocks",
"gradient_accumulate_steps",
"not_use_best_mse",
"dynamic_max_gap",
@@ -750,6 +750,10 @@ def __init__(
use_sym: bool = False,
group_size: int = 128,
# AUTOROUND
act_bits: int = 32,
act_group_size: int = None,
act_sym: bool = None,
act_dynamic: bool = True,
enable_full_range: bool = False,
batch_size: int = 8,
lr_scheduler=None,
@@ -759,16 +763,17 @@
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
iters: int = 200,
seqlen: int = 512,
n_samples: int = 512,
seqlen: int = 2048,
nsamples: int = 128,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
nblocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype: str = "fp16",
use_layer_wise: bool = False,
multimodal: bool = False,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init AUTOROUND weight-only quantization config.
@@ -778,6 +783,10 @@
bits (int): Number of bits used to represent weights, default is 4.
use_sym (bool): Indicates whether weights are symmetric, default is False.
group_size (int): Size of weight groups, default is 128.
act_bits (int): Number of bits for activation quantization. Default is 32.
act_group_size (int): Group size for activation quantization. Default is None.
act_sym (bool): Whether to use symmetric activation quantization. Default is None.
act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
lr_scheduler: The learning rate scheduler to be used.
@@ -788,21 +797,27 @@
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Data length of the sequence for tuning (default is 2048).
n_samples (int): Number of samples (default is 512).
nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
multimodal (bool): Whether to enable multimodal model quantization (default is False).
"""
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
self.use_sym = use_sym
self.group_size = group_size
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.lr_scheduler = lr_scheduler
@@ -813,15 +828,16 @@ def __init__(
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
self.n_samples = n_samples
self.nsamples = nsamples
self.sampler = sampler
self.seed = seed
self.n_blocks = n_blocks
self.nblocks = nblocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.scale_dtype = scale_dtype
self.use_layer_wise = use_layer_wise
self.multimodal = multimodal
self._post_init()

@classmethod
@@ -1526,7 +1542,7 @@ def get_woq_tuning_config() -> list:
the list of WOQ quant config.
"""
RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32)
AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32, seqlen=512)
GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM]
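
Putting the config changes together, a minimal end-to-end sketch of the user-facing path is shown below. It assumes the standard neural_compressor.torch prepare/convert entry points and a placeholder model, so treat the calibration wiring as an assumption rather than part of this commit.

# Hedged sketch, not part of this commit; assumes prepare/convert are exported
# from neural_compressor.torch.quantization and uses a placeholder model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# New defaults in this commit: seqlen=2048 and nsamples=128 (renamed from n_samples).
config = AutoRoundConfig(
    bits=4,
    use_sym=False,
    group_size=128,
    seqlen=2048,
    nsamples=128,
    nblocks=1,             # renamed from n_blocks
    act_bits=32,           # new activation-quantization options
    act_dynamic=True,
    use_layer_wise=False,  # forwarded as low_cpu_mem_usage in the entry above
)

dataloader = get_dataloader(tokenizer, seqlen=2048, seed=42, bs=8, nsamples=128)

model = prepare(model, quant_config=config)
# Calibration: push the calibration batches through the prepared model; the exact
# batch format depends on the calibration dataset, so this loop stays defensive.
with torch.no_grad():
    for batch in dataloader:
        if isinstance(batch, dict):
            model(**batch)
        else:
            model(batch)
model = convert(model)
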