-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add new post-quant methods #32208
add new post-quant methods #32208
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,7 +55,7 @@ def _set_variable_data(scope, place, var_name, np_value): | |
Set the value of var node by name, if the node exists, | ||
''' | ||
assert isinstance(np_value, np.ndarray), \ | ||
'The type of value should be numpy array.' | ||
'The type of value should be numpy array.' | ||
var_node = scope.find_var(var_name) | ||
if var_node != None: | ||
tensor = var_node.get_tensor() | ||
|
@@ -138,8 +138,10 @@ def __init__(self, | |
batch_size=10, | ||
batch_nums=None, | ||
algo="KL", | ||
hist_perc=0.99999, | ||
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], | ||
is_full_quantize=False, | ||
bias_correct=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bias_correction? |
||
activation_bits=8, | ||
weight_bits=8, | ||
activation_quantize_type='range_abs_max', | ||
|
@@ -180,14 +182,22 @@ def __init__(self, | |
get the KL threshold for quantized activations and get the abs_max | ||
value for quantized weights. If algo='abs_max', get the abs max | ||
value for activations and weights. If algo= 'min_max', get the min | ||
and max value for quantized activations and weights. Default is KL. | ||
and max value for quantized activations and weights. If algo='avg', | ||
get the average value among the max values for activations. If | ||
algo= 'hist', get the value of 'hist_perc' quantile as the threshold. | ||
If algo='mse', get the value which makes the quantization mse loss | ||
minimal. Default is KL. | ||
hist_perc(float, optional): The threshold of algo 'hist' for activations. | ||
Default is 0.99999. | ||
quantizable_op_type(list[str], optional): List the type of ops | ||
that will be quantized. Default is ["conv2d", "depthwise_conv2d", | ||
"mul"]. | ||
is_full_quantize(bool, optional): If set is_full_quantize as True, | ||
apply quantization to all supported quantizable op type. If set | ||
is_full_quantize as False, only apply quantization to the op type | ||
according to the input quantizable_op_type. | ||
bias_correct(bool, optional): If set as True, use the bias correction | ||
method of https://arxiv.org/abs/1810.05723. Default is False. | ||
XGZhang11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
activation_bits(int): quantization bit number for activation. | ||
weight_bits(int, optional): quantization bit number for weights. | ||
activation_quantize_type(str): quantization type for activation, | ||
|
@@ -255,7 +265,7 @@ def __init__(self, | |
'range_abs_max', 'moving_average_abs_max', 'abs_max' | ||
] | ||
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max'] | ||
self._support_algo_type = ['KL', 'abs_max', 'min_max'] | ||
self._support_algo_type = ['KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'] | ||
self._dynamic_quantize_op_type = ['lstm'] | ||
self._support_quantize_op_type = \ | ||
list(set(QuantizationTransformPass._supported_quantizable_op_type + | ||
|
@@ -270,7 +280,7 @@ def __init__(self, | |
"cannot be None in the same time." | ||
assert batch_size > 0, "The batch_size should be greater than 0." | ||
assert algo in self._support_algo_type, \ | ||
"The algo should be KL, abs_max or min_max." | ||
"The algo should be KL, hist, mse, avg, abs_max or min_max." | ||
assert activation_quantize_type in self._support_activation_quantize_type, \ | ||
"The activation_quantize_type ({}) should in ({}).".format( | ||
activation_quantize_type, self._support_activation_quantize_type) | ||
|
@@ -279,6 +289,7 @@ def __init__(self, | |
weight_quantize_type, self._support_weight_quantize_type) | ||
|
||
# Save input params | ||
self._bias_correct = bias_correct | ||
self._executor = executor | ||
self._scope = global_scope() if scope == None else scope | ||
self._model_dir = model_dir | ||
|
@@ -289,6 +300,7 @@ def __init__(self, | |
self._batch_size = batch_size | ||
self._batch_nums = batch_nums | ||
self._algo = algo | ||
self._hist_perc = hist_perc | ||
self._activation_bits = activation_bits | ||
self._weight_bits = weight_bits | ||
self._activation_quantize_type = activation_quantize_type | ||
|
@@ -318,13 +330,17 @@ def __init__(self, | |
self._sampling_act_abs_min_max = {} | ||
self._sampling_act_histogram = {} | ||
self._sampling_data = {} | ||
self._quantized_var_kl_threshold = {} | ||
self._quantized_var_threshold = {} | ||
self._histogram_bins = 2048 | ||
# The vars for algo = min_max | ||
self._quantized_var_min = {} | ||
self._quantized_var_max = {} | ||
# The vars for algo = abs_max | ||
self._quantized_var_abs_max = {} | ||
# The vars for algo = avg | ||
self._quantized_var_avg = {} | ||
# The best loss of algo = mse | ||
self._best_mse_loss = {} | ||
# The threshold for algo = abs_max, mse or avg | ||
self._quantized_threshold = {} | ||
|
||
def quantize(self): | ||
''' | ||
|
@@ -341,7 +357,7 @@ def quantize(self): | |
self._collect_target_varnames() | ||
self._set_activation_persistable() | ||
|
||
if self._algo == "KL": | ||
if self._algo in ["KL", "hist"]: | ||
_logger.info("Preparation stage ...") | ||
batch_id = 0 | ||
for data in self._data_loader(): | ||
|
@@ -373,14 +389,17 @@ def quantize(self): | |
batch_id += 1 | ||
if self._batch_nums and batch_id >= self._batch_nums: | ||
break | ||
|
||
if self._algo == 'avg': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是获取阈值的逻辑,后面计算阈值的部分。 |
||
for var_name in self._quantized_act_var_name: | ||
self._quantized_threshold[var_name] = np.array(self._quantized_var_avg[var_name]).mean() | ||
_logger.info("Finish sampling stage, all batch: " + str(batch_id)) | ||
|
||
self._reset_activation_persistable() | ||
|
||
if self._algo == "KL": | ||
self._calculate_kl_threshold() | ||
|
||
if self._algo in ["KL", "abs_max"]: | ||
if self._algo in ["KL", "hist"]: | ||
self._calculate_kl_hist_threshold() | ||
if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]: | ||
self._update_program() | ||
else: | ||
self._save_input_threhold() | ||
|
@@ -524,16 +543,18 @@ def _sampling(self): | |
''' | ||
Sample the min/max, abs_max or histogram in every iterations. | ||
''' | ||
if self._algo == "abs_max": | ||
self._sample_abs_max() | ||
if self._algo in ["avg", "abs_max"]: | ||
self._sample_abs_max_avg() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 两个不相同的采样方式,分开成两个函数。 |
||
elif self._algo == "min_max": | ||
self._sample_min_max() | ||
elif self._algo == "KL": | ||
elif self._algo == "mse": | ||
self._sample_mse() | ||
elif self._algo in ["KL", "hist"]: | ||
self._sample_histogram() | ||
|
||
def _sample_abs_max(self): | ||
def _sample_mse(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 注释使用”“”xxx “”“ |
||
# Only calculate abs_max value for weight for once | ||
if self._quantized_var_abs_max == {}: | ||
if self._quantized_threshold == {}: | ||
for var_name in self._quantized_weight_var_name: | ||
var_tensor = _load_variable_data(self._scope, var_name) | ||
if self._weight_quantize_type == "abs_max": | ||
|
@@ -549,14 +570,60 @@ def _sample_abs_max(self): | |
for i in range(var_tensor.shape[0]): | ||
abs_max_value.append( | ||
float(np.max(np.abs(var_tensor[i])))) | ||
self._quantized_var_abs_max[var_name] = abs_max_value | ||
self._quantized_threshold[var_name] = abs_max_value | ||
_logger.info("MSE searching stage ...") | ||
for var_name in self._quantized_act_var_name: | ||
var_tensor = _load_variable_data(self._scope, var_name) | ||
var_tensor = var_tensor.flatten() | ||
abs_max_value = float(np.max(np.abs(var_tensor))) | ||
s = 0.3 | ||
best_scale = 0.0 | ||
if var_name not in self._best_mse_loss: | ||
self._best_mse_loss[var_name] = 100000.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不要使用固定数值,使用 = float("inf") |
||
while s <= 1.0: | ||
scale = s * abs_max_value | ||
s += 0.02 | ||
bins = 2 ** (self._activation_bits -1) - 1 | ||
quant_dequant_var = np.round(np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale | ||
mse_loss = ((var_tensor - quant_dequant_var) ** 2).mean() | ||
if mse_loss <= self._best_mse_loss[var_name]: | ||
self._best_mse_loss[var_name] = mse_loss | ||
best_scale = scale | ||
if best_scale > 0.0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个判断没有必要,self._quantized_threshold[var_name] = best_scale 可以放到if mse_loss <= self._best_mse_loss[var_name]:中 |
||
self._quantized_threshold[var_name] = best_scale | ||
|
||
def _sample_abs_max_avg(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avg和abs_max分为两个函数实现 |
||
# Only calculate abs_max value for weight for once | ||
if self._quantized_threshold == {}: | ||
for var_name in self._quantized_weight_var_name: | ||
var_tensor = _load_variable_data(self._scope, var_name) | ||
if self._weight_quantize_type == "abs_max": | ||
abs_max_value = float(np.max(np.abs(var_tensor))) | ||
elif self._weight_quantize_type == "channel_wise_abs_max": | ||
abs_max_value = [] | ||
if self._weight_op_pairs[ | ||
var_name] in _channelwise_quant_axis1_ops: | ||
for i in range(var_tensor.shape[1]): | ||
abs_max_value.append( | ||
float(np.max(np.abs(var_tensor[:, i])))) | ||
else: | ||
for i in range(var_tensor.shape[0]): | ||
abs_max_value.append( | ||
float(np.max(np.abs(var_tensor[i])))) | ||
self._quantized_threshold[var_name] = abs_max_value | ||
|
||
for var_name in self._quantized_act_var_name: | ||
var_tensor = _load_variable_data(self._scope, var_name) | ||
abs_max_value = float(np.max(np.abs(var_tensor))) | ||
if (var_name not in self._quantized_var_abs_max) or \ | ||
(abs_max_value > self._quantized_var_abs_max[var_name]): | ||
self._quantized_var_abs_max[var_name] = abs_max_value | ||
if self._algo == 'avg': | ||
if (var_name not in self._quantized_var_avg): | ||
self._quantized_var_avg[var_name] = [] | ||
abs_avg_value = float(np.mean(np.max(np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=(1)))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 注意代码每行的长度 |
||
self._quantized_var_avg[var_name].append(abs_avg_value) | ||
continue | ||
if (var_name not in self._quantized_threshold) or \ | ||
(abs_max_value > self._quantized_threshold[var_name]): | ||
self._quantized_threshold[var_name] = abs_max_value | ||
|
||
def _sample_min_max(self): | ||
if self._quantized_var_min == {} and self._quantized_var_max == {}: | ||
|
@@ -646,12 +713,12 @@ def _init_sampling_act_histogram(self): | |
[], bins=self._histogram_bins, range=(min_val, max_val)) | ||
self._sampling_act_histogram[var_name] = [hist, hist_edeges] | ||
|
||
def _calculate_kl_threshold(self): | ||
def _calculate_kl_hist_threshold(self): | ||
''' | ||
Calculate the KL threshold of quantized variables. | ||
Calculate the KL or hist threshold of quantized variables. | ||
''' | ||
_logger.info("Calculate KL threshold ...") | ||
assert self._algo == "KL", "The algo should be KL to calculate kl threshold." | ||
_logger.info("Calculate {} threshold ...".format(self._algo)) | ||
assert self._algo in ["KL", "hist"], "The algo should be KL or hist to calculate the threshold." | ||
|
||
# Abs_max threshold for weights | ||
for var_name in self._quantized_weight_var_name: | ||
|
@@ -669,12 +736,16 @@ def _calculate_kl_threshold(self): | |
for i in range(weight_data.shape[0]): | ||
weight_threshold.append( | ||
float(np.max(np.abs(weight_data[i])))) | ||
self._quantized_var_kl_threshold[var_name] = weight_threshold | ||
self._quantized_var_threshold[var_name] = weight_threshold | ||
|
||
for var_name in self._quantized_act_var_name: | ||
hist, hist_edeges = self._sampling_act_histogram[var_name] | ||
self._quantized_var_kl_threshold[var_name] = \ | ||
self._get_kl_scaling_factor(hist, hist_edeges) | ||
if self._algo == "KL": | ||
self._quantized_var_threshold[var_name] = \ | ||
self._get_kl_scaling_factor(hist, hist_edeges) | ||
elif self._algo == "hist": | ||
self._quantized_var_threshold[var_name] = \ | ||
self._get_hist_scaling_factor(hist, hist_edeges) | ||
|
||
def _update_program(self): | ||
''' | ||
|
@@ -712,10 +783,10 @@ def _update_program(self): | |
add_quant_dequant_pass.apply(graph) | ||
|
||
# save abs_max or KL threshold to scale var node | ||
if self._algo == "KL": | ||
scale_dict = self._quantized_var_kl_threshold | ||
if self._algo in ["KL", "hist"]: | ||
scale_dict = self._quantized_var_threshold | ||
else: | ||
scale_dict = self._quantized_var_abs_max | ||
scale_dict = self._quantized_threshold | ||
for key, val in scale_dict.items(): | ||
_set_variable_data( | ||
self._scope, | ||
|
@@ -734,6 +805,7 @@ def _update_program(self): | |
freeze_pass = QuantizationFreezePass( | ||
scope=self._scope, | ||
place=self._place, | ||
bias_correct=self._bias_correct, | ||
weight_bits=self._weight_bits, | ||
activation_bits=self._activation_bits, | ||
weight_quantize_type=self._weight_quantize_type, | ||
|
@@ -762,19 +834,29 @@ def analysis_and_save_info(op_node, out_var_name): | |
if self._algo == "KL": | ||
# For compatibility, we save output threshold by two methods. | ||
save_info(op_node, out_var_name, | ||
self._quantized_var_kl_threshold, "out_threshold", | ||
self._quantized_var_threshold, "out_threshold", | ||
"post_kl") | ||
save_info( | ||
op_node, out_var_name, self._quantized_var_kl_threshold, | ||
op_node, out_var_name, self._quantized_var_threshold, | ||
argname_index[0] + str(argname_index[1]) + "_threshold", | ||
"post_kl") | ||
elif self._algo == "abs_max": | ||
save_info(op_node, out_var_name, self._quantized_var_abs_max, | ||
"out_threshold", "post_abs_max") | ||
elif self._algo == "hist": | ||
# For compatibility, we save output threshold by two methods. | ||
save_info(op_node, out_var_name, | ||
self._quantized_var_threshold, "out_threshold", | ||
"post_hist") | ||
save_info( | ||
op_node, out_var_name, self._quantized_var_abs_max, | ||
op_node, out_var_name, self._quantized_var_threshold, | ||
argname_index[0] + str(argname_index[1]) + "_threshold", | ||
"post_kl") | ||
"post_hist") | ||
|
||
elif self._algo in ["avg", "abs_max", "mse"]: | ||
save_info(op_node, out_var_name, self._quantized_threshold, | ||
"out_threshold", "post_absmax") | ||
save_info( | ||
op_node, out_var_name, self._quantized_threshold, | ||
argname_index[0] + str(argname_index[1]) + "_threshold", | ||
"post_absmax") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 三种方法有点点差别,建议区分开来 |
||
elif self._algo == "min_max": | ||
save_info(op_node, out_var_name, self._quantized_var_min, | ||
"out_min", "post_min_max") | ||
|
@@ -817,10 +899,27 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type): | |
op._set_attr("quantization_type", quantization_type) | ||
op._set_attr("bit_length", self._weight_bits) | ||
|
||
def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255): | ||
def _get_hist_scaling_factor(self, hist, hist_edges): | ||
''' | ||
Using the hist method to get the scaling factor. | ||
''' | ||
threshold_rate = self._hist_perc | ||
hist = hist / float(sum(hist)) | ||
hist_sum = 0 | ||
hist_index = 0 | ||
for i in range(len(hist)): | ||
hist_sum += hist[i] | ||
if hist_sum >= threshold_rate: | ||
hist_index = i + 1 | ||
break | ||
bin_width = hist_edges[1] - hist_edges[0] | ||
return (hist_index - 0.5) * bin_width | ||
|
||
def _get_kl_scaling_factor(self, hist, hist_edeges): | ||
''' | ||
Using the KL-divergenc method to get the more precise scaling factor. | ||
''' | ||
num_quantized_bins = 2 ** (self._activation_bits - 1) - 1 | ||
ending_iter = self._histogram_bins - 1 | ||
starting_iter = int(ending_iter * 0.7) | ||
bin_width = hist_edeges[1] - hist_edeges[0] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议用完整的hist_percent