add new post-quant methods #32208

Merged · 3 commits · Apr 14, 2021
Changes from 1 commit
@@ -55,7 +55,7 @@ def _set_variable_data(scope, place, var_name, np_value):
Set the value of var node by name, if the node exists,
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = scope.find_var(var_name)
if var_node != None:
tensor = var_node.get_tensor()
@@ -138,8 +138,10 @@ def __init__(self,
batch_size=10,
batch_nums=None,
algo="KL",
hist_perc=0.99999,
Contributor: Suggest using the full name hist_percent.

quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
bias_correct=False,
Contributor: bias_correction?

activation_bits=8,
weight_bits=8,
activation_quantize_type='range_abs_max',
@@ -180,14 +182,22 @@ def __init__(self,
get the KL threshold for quantized activations and get the abs_max
value for quantized weights. If algo='abs_max', get the abs max
value for activations and weights. If algo= 'min_max', get the min
and max value for quantized activations and weights. Default is KL.
and max value for quantized activations and weights. If algo='avg',
get the average of the per-batch max values for activations. If
algo='hist', get the value at the 'hist_perc' quantile as the threshold.
If algo='mse', get the value that minimizes the quantization mse loss.
Default is KL.
hist_perc(float, optional): The quantile used by the 'hist' algo to pick
the activation threshold. Default is 0.99999.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantize(bool, optional): If set as True, apply quantization
to all supported quantizable op types. If set as False, only apply
quantization to the op types given in the input quantizable_op_type.
Default is False.
bias_correct(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
XGZhang11 marked this conversation as resolved.
activation_bits(int): quantization bit number for activation.
weight_bits(int, optional): quantization bit number for weights.
activation_quantize_type(str): quantization type for activation,
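
For a concrete sense of how these algo choices differ, here is a minimal standalone sketch (illustrative only, not part of this diff; toy numpy tensors) of the thresholds abs_max, avg, and hist would produce for the same sampled activations:

import numpy as np

np.random.seed(0)
# Toy stand-in for activations sampled over four calibration batches.
batches = [np.random.randn(8, 16).astype("float32") for _ in range(4)]

# algo='abs_max': running max of |x| over all batches.
abs_max = max(float(np.max(np.abs(b))) for b in batches)

# algo='avg': per batch, the mean over samples of each sample's abs-max,
# then the average of those per-batch values (see _sample_abs_max_avg).
per_batch = [float(np.mean(np.max(np.abs(b.reshape(b.shape[0], -1)), axis=1)))
             for b in batches]
avg = float(np.mean(per_batch))

# algo='hist': roughly the hist_perc quantile of |x| (the PR uses a
# 2048-bin histogram rather than an exact quantile).
all_abs = np.abs(np.concatenate([b.ravel() for b in batches]))
hist = float(np.quantile(all_abs, 0.99999))

print(abs_max, avg, hist)  # three candidate activation thresholds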
@@ -255,7 +265,7 @@ def __init__(self,
'range_abs_max', 'moving_average_abs_max', 'abs_max'
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = ['KL', 'abs_max', 'min_max']
self._support_algo_type = ['KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max']
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
@@ -270,7 +280,7 @@ def __init__(self,
"cannot be None in the same time."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
"The algo should be KL, abs_max or min_max."
"The algo should be KL, hist, mse, avg, abs_max or min_max."
assert activation_quantize_type in self._support_activation_quantize_type, \
"The activation_quantize_type ({}) should in ({}).".format(
activation_quantize_type, self._support_activation_quantize_type)
@@ -279,6 +289,7 @@ def __init__(self,
weight_quantize_type, self._support_weight_quantize_type)

# Save input params
self._bias_correct = bias_correct
self._executor = executor
self._scope = global_scope() if scope == None else scope
self._model_dir = model_dir
@@ -289,6 +300,7 @@ def __init__(self,
self._batch_size = batch_size
self._batch_nums = batch_nums
self._algo = algo
self._hist_perc = hist_perc
self._activation_bits = activation_bits
self._weight_bits = weight_bits
self._activation_quantize_type = activation_quantize_type
@@ -318,13 +330,17 @@ def __init__(self,
self._sampling_act_abs_min_max = {}
self._sampling_act_histogram = {}
self._sampling_data = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_threshold = {}
self._histogram_bins = 2048
# The vars for algo = min_max
self._quantized_var_min = {}
self._quantized_var_max = {}
# The vars for algo = abs_max
self._quantized_var_abs_max = {}
# The vars for algo = avg
self._quantized_var_avg = {}
# The best loss of algo = mse
self._best_mse_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}

def quantize(self):
'''
@@ -341,7 +357,7 @@ def quantize(self):
self._collect_target_varnames()
self._set_activation_persistable()

if self._algo == "KL":
if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
for data in self._data_loader():
@@ -373,14 +389,17 @@ def quantize(self):
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break

if self._algo == 'avg':
Contributor: This is threshold-computation logic; it should go in the threshold-calculation part later on.

for var_name in self._quantized_act_var_name:
self._quantized_threshold[var_name] = np.array(self._quantized_var_avg[var_name]).mean()
_logger.info("Finish sampling stage, all batch: " + str(batch_id))

self._reset_activation_persistable()

if self._algo == "KL":
self._calculate_kl_threshold()

if self._algo in ["KL", "abs_max"]:
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
self._update_program()
else:
self._save_input_threhold()
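
For algo in ["KL", "hist"], the preparation stage above first fixes each activation's value range and allocates a fixed-bin histogram, which the sampling stage then accumulates batch by batch. A hedged standalone sketch of that two-pass pattern (toy data; 2048 bins as in the PR's _histogram_bins):

import numpy as np

histogram_bins = 2048
batches = [np.random.randn(1000).astype("float32") for _ in range(5)]

# Pass 1 (preparation stage): find the global range of |x| to fix the bins.
max_abs = max(float(np.max(np.abs(b))) for b in batches)
hist, hist_edges = np.histogram([], bins=histogram_bins, range=(0.0, max_abs))

# Pass 2 (sampling stage): accumulate the per-batch histograms.
for b in batches:
    h, _ = np.histogram(np.abs(b), bins=histogram_bins, range=(0.0, max_abs))
    hist = hist + h
# 'hist' is then post-processed by _calculate_kl_hist_threshold.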
@@ -524,16 +543,18 @@ def _sampling(self):
'''
Sample the min/max, abs_max or histogram in every iteration.
'''
if self._algo == "abs_max":
self._sample_abs_max()
if self._algo in ["avg", "abs_max"]:
self._sample_abs_max_avg()
Contributor: These are two different sampling methods; split them into two functions.

elif self._algo == "min_max":
self._sample_min_max()
elif self._algo == "KL":
elif self._algo == "mse":
self._sample_mse()
elif self._algo in ["KL", "hist"]:
self._sample_histogram()

def _sample_abs_max(self):
def _sample_mse(self):
Contributor: Use """xxx""" style for the docstring.

# Only calculate the abs_max value for weights once
if self._quantized_var_abs_max == {}:
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
@@ -549,14 +570,60 @@ def _sample_abs_max(self):
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_var_abs_max[var_name] = abs_max_value
self._quantized_threshold[var_name] = abs_max_value
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
s = 0.3
best_scale = 0.0
if var_name not in self._best_mse_loss:
self._best_mse_loss[var_name] = 100000.0
Contributor: Do not use a hard-coded value; use = float("inf").

while s <= 1.0:
scale = s * abs_max_value
s += 0.02
bins = 2 ** (self._activation_bits - 1) - 1
quant_dequant_var = np.round(np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var) ** 2).mean()
if mse_loss <= self._best_mse_loss[var_name]:
self._best_mse_loss[var_name] = mse_loss
best_scale = scale
if best_scale > 0.0:
Contributor: This check is unnecessary; self._quantized_threshold[var_name] = best_scale can be moved into the if mse_loss <= self._best_mse_loss[var_name]: branch.

self._quantized_threshold[var_name] = best_scale
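
The MSE search above can be exercised standalone: it scans scale = s * abs_max for s from 0.3 to 1.0 in steps of 0.02 and keeps the scale with the lowest quantize-dequantize error. A self-contained sketch on toy data, adopting the reviewer's float("inf") suggestion and folding the scale update into the comparison; the search itself matches the diff:

import numpy as np

activation_bits = 8
x = np.random.randn(10000).astype("float32")
abs_max_value = float(np.max(np.abs(x)))

best_loss, best_scale = float("inf"), 0.0
s = 0.3
while s <= 1.0:
    scale = s * abs_max_value
    s += 0.02
    bins = 2 ** (activation_bits - 1) - 1
    # Quantize to integer levels, dequantize, and measure the error.
    quant_dequant = np.round(np.clip(x, 0.0, scale) / scale * bins) / bins * scale
    mse_loss = float(((x - quant_dequant) ** 2).mean())
    if mse_loss < best_loss:
        best_loss, best_scale = mse_loss, scale
print(best_scale)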

def _sample_abs_max_avg(self):
Contributor: Implement avg and abs_max as two separate functions.

# Only calculate the abs_max value for weights once
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
abs_max_value = []
if self._weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[:, i]))))
else:
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value

for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_abs_max) or \
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
if self._algo == 'avg':
if (var_name not in self._quantized_var_avg):
self._quantized_var_avg[var_name] = []
abs_avg_value = float(np.mean(np.max(np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=(1))))
Contributor: Watch the length of each code line.

self._quantized_var_avg[var_name].append(abs_avg_value)
continue
if (var_name not in self._quantized_threshold) or \
(abs_max_value > self._quantized_threshold[var_name]):
self._quantized_threshold[var_name] = abs_max_value
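
Both sampling functions share the one-time weight pass shown above: 'abs_max' keeps a single scalar per weight tensor, while 'channel_wise_abs_max' keeps one value per output channel, taken along axis 1 for ops in _channelwise_quant_axis1_ops and along axis 0 otherwise. A sketch of the two layouts (the shapes and example ops are assumptions for illustration):

import numpy as np

w_conv = np.random.randn(32, 16, 3, 3)  # e.g. conv2d: output channels on axis 0
w_mul = np.random.randn(128, 64)        # e.g. mul: output channels on axis 1

per_tensor = float(np.max(np.abs(w_conv)))
per_channel_axis0 = [float(np.max(np.abs(w_conv[i])))
                     for i in range(w_conv.shape[0])]
per_channel_axis1 = [float(np.max(np.abs(w_mul[:, i])))
                     for i in range(w_mul.shape[1])]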

def _sample_min_max(self):
if self._quantized_var_min == {} and self._quantized_var_max == {}:
@@ -646,12 +713,12 @@ def _init_sampling_act_histogram(self):
[], bins=self._histogram_bins, range=(min_val, max_val))
self._sampling_act_histogram[var_name] = [hist, hist_edeges]

def _calculate_kl_threshold(self):
def _calculate_kl_hist_threshold(self):
'''
Calculate the KL threshold of quantized variables.
Calculate the KL or hist threshold of quantized variables.
'''
_logger.info("Calculate KL threshold ...")
assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
_logger.info("Calculate {} threshold ...".format(self._algo))
assert self._algo in ["KL", "hist"], "The algo should be KL or hist to calculate the threshold."

# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
Expand All @@ -669,12 +736,16 @@ def _calculate_kl_threshold(self):
for i in range(weight_data.shape[0]):
weight_threshold.append(
float(np.max(np.abs(weight_data[i]))))
self._quantized_var_kl_threshold[var_name] = weight_threshold
self._quantized_var_threshold[var_name] = weight_threshold

for var_name in self._quantized_act_var_name:
hist, hist_edeges = self._sampling_act_histogram[var_name]
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges)
if self._algo == "KL":
self._quantized_var_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges)
elif self._algo == "hist":
self._quantized_var_threshold[var_name] = \
self._get_hist_scaling_factor(hist, hist_edeges)

def _update_program(self):
'''
@@ -712,10 +783,10 @@ def _update_program(self):
add_quant_dequant_pass.apply(graph)

# save abs_max or KL threshold to scale var node
if self._algo == "KL":
scale_dict = self._quantized_var_kl_threshold
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
scale_dict = self._quantized_var_abs_max
scale_dict = self._quantized_threshold
for key, val in scale_dict.items():
_set_variable_data(
self._scope,
@@ -734,6 +805,7 @@ def _update_program(self):
freeze_pass = QuantizationFreezePass(
scope=self._scope,
place=self._place,
bias_correct=self._bias_correct,
weight_bits=self._weight_bits,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
@@ -762,19 +834,29 @@ def analysis_and_save_info(op_node, out_var_name):
if self._algo == "KL":
# For compatibility, we save output threshold by two methods.
save_info(op_node, out_var_name,
self._quantized_var_kl_threshold, "out_threshold",
self._quantized_var_threshold, "out_threshold",
"post_kl")
save_info(
op_node, out_var_name, self._quantized_var_kl_threshold,
op_node, out_var_name, self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl")
elif self._algo == "abs_max":
save_info(op_node, out_var_name, self._quantized_var_abs_max,
"out_threshold", "post_abs_max")
elif self._algo == "hist":
# For compatibility, we save output threshold by two methods.
save_info(op_node, out_var_name,
self._quantized_var_threshold, "out_threshold",
"post_hist")
save_info(
op_node, out_var_name, self._quantized_var_abs_max,
op_node, out_var_name, self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl")
"post_hist")

elif self._algo in ["avg", "abs_max", "mse"]:
save_info(op_node, out_var_name, self._quantized_threshold,
"out_threshold", "post_absmax")
save_info(
op_node, out_var_name, self._quantized_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_absmax")
Contributor: The three methods differ slightly; it is better to distinguish them.

elif self._algo == "min_max":
save_info(op_node, out_var_name, self._quantized_var_min,
"out_min", "post_min_max")
@@ -817,10 +899,27 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
op._set_attr("quantization_type", quantization_type)
op._set_attr("bit_length", self._weight_bits)

def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
def _get_hist_scaling_factor(self, hist, hist_edges):
'''
Using the hist method to get the scaling factor.
'''
threshold_rate = self._hist_perc
hist = hist / float(sum(hist))
hist_sum = 0
hist_index = 0
for i in range(len(hist)):
hist_sum += hist[i]
if hist_sum >= threshold_rate:
hist_index = i + 1
break
bin_width = hist_edges[1] - hist_edges[0]
return (hist_index - 0.5) * bin_width
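
As a sanity check (illustrative only, assuming a standalone copy of the loop above with hist_perc as a plain parameter), the returned value should land close to the exact quantile of the sampled data:

import numpy as np

def hist_scaling_factor(hist, hist_edges, hist_perc=0.99999):
    # Standalone copy of the method above.
    hist = hist / float(sum(hist))
    hist_sum, hist_index = 0.0, 0
    for i in range(len(hist)):
        hist_sum += hist[i]
        if hist_sum >= hist_perc:
            hist_index = i + 1
            break
    bin_width = hist_edges[1] - hist_edges[0]
    return (hist_index - 0.5) * bin_width

x = np.abs(np.random.randn(200000).astype("float32"))
h, edges = np.histogram(x, bins=2048, range=(0.0, float(x.max())))
print(hist_scaling_factor(h, edges), float(np.quantile(x, 0.99999)))  # close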

def _get_kl_scaling_factor(self, hist, hist_edeges):
'''
Using the KL-divergenc method to get the more precise scaling factor.
'''
num_quantized_bins = 2 ** (self._activation_bits - 1) - 1
ending_iter = self._histogram_bins - 1
starting_iter = int(ending_iter * 0.7)
bin_width = hist_edeges[1] - hist_edeges[0]
# ... (rest of _get_kl_scaling_factor not shown in this hunk)
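
The KL hunk is truncated here. For orientation only, below is a simplified reconstruction of the classic KL-threshold search this function follows, not the PR's exact code: for each candidate clip index i, fold the tail mass into the last kept bin, downsample the clipped histogram to num_quantized_bins levels, expand it back, and keep the i with the smallest KL divergence.

import numpy as np

def kl_scaling_factor(hist, hist_edges, activation_bits=8):
    # Hedged sketch of a KL threshold search; details beyond the lines
    # shown above (num_quantized_bins, 0.7 start, bin_width) are assumptions.
    num_quantized_bins = 2 ** (activation_bits - 1) - 1
    bin_width = hist_edges[1] - hist_edges[0]
    total = float(hist.sum())
    ending_iter = len(hist) - 1
    starting_iter = int(ending_iter * 0.7)
    best_kl, best_i = float("inf"), ending_iter

    for i in range(starting_iter, ending_iter + 1):
        # Reference distribution P: first i+1 bins, tail folded into bin i.
        p = hist[:i + 1].astype(np.float64)
        p[i] += float(hist[i + 1:].sum())
        # Candidate Q: P downsampled to num_quantized_bins levels, expanded back.
        q = np.zeros_like(p)
        for idx in np.array_split(np.arange(i + 1), num_quantized_bins):
            mass = p[idx].sum()
            nonzero = int((p[idx] != 0).sum())
            if nonzero:
                q[idx] = np.where(p[idx] != 0, mass / nonzero, 0.0)
        # KL(P || Q) over bins where both have mass.
        mask = (p > 0) & (q > 0)
        pp, qq = p[mask] / total, q[mask] / total
        kl = float(np.sum(pp * np.log(pp / qq)))
        if kl < best_kl:
            best_kl, best_i = kl, i
    return (best_i + 0.5) * bin_width

x = np.abs(np.random.laplace(size=100000)).astype("float32")
h, edges = np.histogram(x, bins=2048, range=(0.0, float(x.max())))
print(kl_scaling_factor(h, edges))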