Skip to content

Commit

Permalink
add new post-quant methods (#32208)
Browse files Browse the repository at this point in the history
  • Loading branch information
XGZhang11 authored Apr 14, 2021
1 parent cb81826 commit 4281eb4
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _set_variable_data(scope, place, var_name, np_value):
Set the value of var node by name, if the node exists,
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
'The type of value should be numpy array.'
var_node = scope.find_var(var_name)
if var_node != None:
tensor = var_node.get_tensor()
Expand Down Expand Up @@ -138,8 +138,10 @@ def __init__(self,
batch_size=10,
batch_nums=None,
algo="KL",
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
bias_correction=False,
activation_bits=8,
weight_bits=8,
activation_quantize_type='range_abs_max',
Expand Down Expand Up @@ -180,14 +182,22 @@ def __init__(self,
get the KL threshold for quantized activations and get the abs_max
value for quantized weights. If algo='abs_max', get the abs max
value for activations and weights. If algo= 'min_max', get the min
and max value for quantized activations and weights. Default is KL.
and max value for quantized activations and weights. If algo='avg',
get the average value among the max values for activations. If
algo= 'hist', get the value of 'hist_percent' quantile as the threshold.
If algo='mse', get the value which makes the quantization mse loss
minimal. Default is KL.
hist_percent(float, optional): The threshold of algo 'hist' for activations.
Default is 0.99999.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantize(bool, optional): If is_full_quantize is True,
apply quantization to all supported quantizable op types. If
is_full_quantize is False, only apply quantization to the op types
given by the input quantizable_op_type. Default is False.
bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
activation_bits(int): quantization bit number for activation.
weight_bits(int, optional): quantization bit number for weights.
activation_quantize_type(str): quantization type for activation,
Expand Down Expand Up @@ -255,7 +265,9 @@ def __init__(self,
'range_abs_max', 'moving_average_abs_max', 'abs_max'
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = ['KL', 'abs_max', 'min_max']
self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
]
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
Expand All @@ -270,7 +282,7 @@ def __init__(self,
"cannot be None in the same time."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
"The algo should be KL, abs_max or min_max."
"The algo should be KL, hist, mse, avg, abs_max or min_max."
assert activation_quantize_type in self._support_activation_quantize_type, \
"The activation_quantize_type ({}) should in ({}).".format(
activation_quantize_type, self._support_activation_quantize_type)
Expand All @@ -279,6 +291,7 @@ def __init__(self,
weight_quantize_type, self._support_weight_quantize_type)

# Save input params
self._bias_correction = bias_correction
self._executor = executor
self._scope = global_scope() if scope == None else scope
self._model_dir = model_dir
Expand All @@ -289,6 +302,7 @@ def __init__(self,
self._batch_size = batch_size
self._batch_nums = batch_nums
self._algo = algo
self._hist_percent = hist_percent
self._activation_bits = activation_bits
self._weight_bits = weight_bits
self._activation_quantize_type = activation_quantize_type
Expand All @@ -314,17 +328,21 @@ def __init__(self,
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._weight_op_pairs = {}
# The vars for alog = KL
# The vars for algo = KL or hist
self._sampling_act_abs_min_max = {}
self._sampling_act_histogram = {}
self._sampling_data = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_threshold = {}
self._histogram_bins = 2048
# The vars for algo = min_max
self._quantized_var_min = {}
self._quantized_var_max = {}
# The vars for algo = abs_max
self._quantized_var_abs_max = {}
# The vars for algo = avg
self._quantized_var_avg = {}
# The best loss of algo = mse
self._best_mse_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}

def quantize(self):
'''
Expand All @@ -341,7 +359,7 @@ def quantize(self):
self._collect_target_varnames()
self._set_activation_persistable()

if self._algo == "KL":
if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
for data in self._data_loader():
Expand Down Expand Up @@ -374,13 +392,14 @@ def quantize(self):
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))

self._reset_activation_persistable()

if self._algo == "KL":
self._calculate_kl_threshold()

if self._algo in ["KL", "abs_max"]:
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
self._quantized_threshold[var_name] = \
np.array(self._quantized_var_avg[var_name]).mean()
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
self._update_program()
else:
self._save_input_threhold()
Expand Down Expand Up @@ -526,14 +545,84 @@ def _sampling(self):
'''
if self._algo == "abs_max":
self._sample_abs_max()
elif self._algo == "avg":
self._sample_avg()
elif self._algo == "min_max":
self._sample_min_max()
elif self._algo == "KL":
elif self._algo == "mse":
self._sample_mse()
elif self._algo in ["KL", "hist"]:
self._sample_histogram()

def _sample_mse(self):
    '''
    Sample one batch for algo = 'mse'.

    Weights: on the first call only (while self._quantized_threshold is
    still empty) store the abs max of every quantized weight, either as a
    single float ('abs_max') or as a list with one float per channel
    ('channel_wise_abs_max').

    Activations: for every quantized activation, try the candidate scales
    0.30, 0.32, ..., 1.00 times the abs max of the current tensor, simulate
    quantize-dequantize with self._activation_bits bits and keep the scale
    with the smallest mean squared error. The best loss is remembered in
    self._best_mse_loss, so a later batch only replaces the threshold when
    it reaches a loss that is not larger than the best one seen so far.
    '''
    # Only calculate abs_max value for weight for once
    if not self._quantized_threshold:
        for var_name in self._quantized_weight_var_name:
            var_tensor = _load_variable_data(self._scope, var_name)
            if self._weight_quantize_type == "abs_max":
                abs_max_value = float(np.max(np.abs(var_tensor)))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                # The channel axis is 1 for the ops listed in
                # _channelwise_quant_axis1_ops and 0 for all other ops.
                if self._weight_op_pairs[
                        var_name] in _channelwise_quant_axis1_ops:
                    abs_max_value = [
                        float(np.max(np.abs(var_tensor[:, i])))
                        for i in range(var_tensor.shape[1])
                    ]
                else:
                    abs_max_value = [
                        float(np.max(np.abs(var_tensor[i])))
                        for i in range(var_tensor.shape[0])
                    ]
            self._quantized_threshold[var_name] = abs_max_value

    _logger.info("MSE searching stage ...")
    # Number of positive quantization levels; it does not depend on the
    # variable or on the candidate scale, so compute it once.
    bins = 2**(self._activation_bits - 1) - 1
    for var_name in self._quantized_act_var_name:
        var_tensor = _load_variable_data(self._scope, var_name).flatten()
        abs_max_value = float(np.max(np.abs(var_tensor)))
        if var_name not in self._best_mse_loss:
            self._best_mse_loss[var_name] = float('inf')
        if abs_max_value == 0.0:
            # Every candidate scale would be 0 and the loss 0/0 = nan, which
            # can never win the comparison below. Skip the batch instead of
            # raising numpy runtime warnings.
            continue
        # Step with integers (30%, 32%, ..., 100%): accumulating 0.02 in a
        # float lets rounding error decide whether the last candidate, the
        # unclipped abs max itself, is evaluated at all.
        for percent in range(30, 101, 2):
            scale = percent / 100.0 * abs_max_value
            # The threshold comes from abs(var_tensor), i.e. it describes a
            # symmetric range, so clip to [-scale, scale]. Clipping to
            # [0, scale] would zero every negative activation and charge
            # its full magnitude to the loss of every candidate.
            quant_dequant_var = np.round(
                np.clip(var_tensor, -scale, scale) / scale *
                bins) / bins * scale
            mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
            # "<=" keeps the larger scale when two candidates tie.
            if mse_loss <= self._best_mse_loss[var_name]:
                self._best_mse_loss[var_name] = mse_loss
                self._quantized_threshold[var_name] = scale

def _sample_avg(self):
    '''
    Sample one batch for algo = 'avg'.

    Weights: on the first call only (while self._quantized_threshold is
    still empty) store the abs max of every quantized weight, either as a
    single float ('abs_max') or as a list with one float per channel
    ('channel_wise_abs_max').

    Activations: append one value per call to
    self._quantized_var_avg[var_name]. The value is the mean, over the
    first axis of the tensor, of the abs max of each slice along that axis.
    This method only collects the per-batch values; it does not write the
    activation thresholds itself.
    '''
    # Only calculate abs_max value for weight for once
    if not self._quantized_threshold:
        for var_name in self._quantized_weight_var_name:
            var_tensor = _load_variable_data(self._scope, var_name)
            if self._weight_quantize_type == "abs_max":
                abs_max_value = float(np.max(np.abs(var_tensor)))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                # The channel axis is 1 for the ops listed in
                # _channelwise_quant_axis1_ops and 0 for all other ops.
                if self._weight_op_pairs[
                        var_name] in _channelwise_quant_axis1_ops:
                    abs_max_value = [
                        float(np.max(np.abs(var_tensor[:, i])))
                        for i in range(var_tensor.shape[1])
                    ]
                else:
                    abs_max_value = [
                        float(np.max(np.abs(var_tensor[i])))
                        for i in range(var_tensor.shape[0])
                    ]
            self._quantized_threshold[var_name] = abs_max_value

    for var_name in self._quantized_act_var_name:
        var_tensor = _load_variable_data(self._scope, var_name)
        if var_name not in self._quantized_var_avg:
            self._quantized_var_avg[var_name] = []
        # Flatten everything but the first axis, take the abs max of each
        # row and average the rows.
        # NOTE(review): presumably axis 0 is the batch axis - confirm.
        per_slice_abs_max = np.max(
            np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=1)
        abs_avg_value = float(np.mean(per_slice_abs_max))
        self._quantized_var_avg[var_name].append(abs_avg_value)

def _sample_abs_max(self):
# Only calculate abs_max value for weight for once
if self._quantized_var_abs_max == {}:
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
Expand All @@ -549,14 +638,14 @@ def _sample_abs_max(self):
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_var_abs_max[var_name] = abs_max_value
self._quantized_threshold[var_name] = abs_max_value

for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_abs_max) or \
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
if (var_name not in self._quantized_threshold) or \
(abs_max_value > self._quantized_threshold[var_name]):
self._quantized_threshold[var_name] = abs_max_value

def _sample_min_max(self):
if self._quantized_var_min == {} and self._quantized_var_max == {}:
Expand Down Expand Up @@ -646,12 +735,12 @@ def _init_sampling_act_histogram(self):
[], bins=self._histogram_bins, range=(min_val, max_val))
self._sampling_act_histogram[var_name] = [hist, hist_edeges]

def _calculate_kl_threshold(self):
def _calculate_kl_hist_threshold(self):
'''
Calculate the KL threshold of quantized variables.
Calculate the KL or hist threshold of quantized variables.
'''
_logger.info("Calculate KL threshold ...")
assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
_logger.info("Calculate {} threshold ...".format(self._algo))
assert self._algo in ["KL", "hist"], "The algo should be KL or hist."

# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
Expand All @@ -669,18 +758,22 @@ def _calculate_kl_threshold(self):
for i in range(weight_data.shape[0]):
weight_threshold.append(
float(np.max(np.abs(weight_data[i]))))
self._quantized_var_kl_threshold[var_name] = weight_threshold
self._quantized_var_threshold[var_name] = weight_threshold

for var_name in self._quantized_act_var_name:
hist, hist_edeges = self._sampling_act_histogram[var_name]
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges)
if self._algo == "KL":
self._quantized_var_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges)
elif self._algo == "hist":
self._quantized_var_threshold[var_name] = \
self._get_hist_scaling_factor(hist, hist_edeges)

def _update_program(self):
'''
Use QuantizationTransformPass and AddQuantDequantPass to insert
fake_quantize, fake_dequantize and fake_quant_dequant op.
Besides, save all kl threshold to the scale var node.
Besides, save all threshold to the scale var node.
'''
_logger.info("Update the program ...")
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
Expand Down Expand Up @@ -711,11 +804,11 @@ def _update_program(self):
quantizable_op_type=minor_quantizable_op_types)
add_quant_dequant_pass.apply(graph)

# save abs_max or KL threshold to scale var node
if self._algo == "KL":
scale_dict = self._quantized_var_kl_threshold
# save threshold to scale var node
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
scale_dict = self._quantized_var_abs_max
scale_dict = self._quantized_threshold
for key, val in scale_dict.items():
_set_variable_data(
self._scope,
Expand All @@ -734,6 +827,7 @@ def _update_program(self):
freeze_pass = QuantizationFreezePass(
scope=self._scope,
place=self._place,
bias_correction=self._bias_correction,
weight_bits=self._weight_bits,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
Expand Down Expand Up @@ -761,20 +855,28 @@ def analysis_and_save_info(op_node, out_var_name):
out_var_name + " is not the output of the op"
if self._algo == "KL":
# For compatibility, we save output threshold by two methods.
save_info(op_node, out_var_name,
self._quantized_var_kl_threshold, "out_threshold",
"post_kl")
save_info(op_node, out_var_name, self._quantized_var_threshold,
"out_threshold", "post_kl")
save_info(
op_node, out_var_name, self._quantized_var_kl_threshold,
op_node, out_var_name, self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl")
elif self._algo == "abs_max":
save_info(op_node, out_var_name, self._quantized_var_abs_max,
"out_threshold", "post_abs_max")
elif self._algo == "hist":
# For compatibility, we save output threshold by two methods.
save_info(op_node, out_var_name, self._quantized_var_threshold,
"out_threshold", "post_hist")
save_info(
op_node, out_var_name, self._quantized_var_abs_max,
op_node, out_var_name, self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl")
"post_hist")

elif self._algo in ["avg", "abs_max", "mse"]:
save_info(op_node, out_var_name, self._quantized_threshold,
"out_threshold", "post_" + str(self._algo))
save_info(
op_node, out_var_name, self._quantized_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_" + str(self._algo))
elif self._algo == "min_max":
save_info(op_node, out_var_name, self._quantized_var_min,
"out_min", "post_min_max")
Expand Down Expand Up @@ -817,10 +919,27 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
op._set_attr("quantization_type", quantization_type)
op._set_attr("bit_length", self._weight_bits)

def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
def _get_hist_scaling_factor(self, hist, hist_edges):
    '''
    Using the hist method to get the scaling factor.

    The histogram is normalized to sum to 1 and accumulated from the first
    bin. The threshold is taken at the first bin whose cumulative share
    reaches self._hist_percent.

    Args:
        hist(numpy.ndarray): The bin counts of the sampled histogram.
        hist_edges(numpy.ndarray): The bin edges. Only the first two are
            used, to get the bin width.

    Returns:
        The threshold (hist_index - 0.5) * bin_width, where hist_index is
        the 1-based index of the selected bin. If the histogram is empty
        or the cumulative share never reaches self._hist_percent (the
        floating point sum may stop just below 1.0), the last bin is
        selected, so the result is never negative for a positive bin
        width.
    '''
    # NOTE(review): the threshold is measured from 0, hist_edges[0] is not
    # added - this presumably relies on the histogram range starting close
    # to 0; confirm against the code that builds the histogram.
    hist = np.asarray(hist)
    bin_width = hist_edges[1] - hist_edges[0]

    # Fall back to the last bin. Leaving the index at 0 would yield the
    # negative threshold -0.5 * bin_width.
    hist_index = len(hist)
    total = float(hist.sum())
    if total > 0:
        cumulative = np.cumsum(hist / total)
        reached = np.flatnonzero(cumulative >= self._hist_percent)
        if reached.size > 0:
            hist_index = int(reached[0]) + 1
    return (hist_index - 0.5) * bin_width

def _get_kl_scaling_factor(self, hist, hist_edeges):
'''
Using the KL-divergence method to get the more precise scaling factor.
'''
num_quantized_bins = 2**(self._activation_bits - 1) - 1
ending_iter = self._histogram_bins - 1
starting_iter = int(ending_iter * 0.7)
bin_width = hist_edeges[1] - hist_edeges[0]
Expand Down
Loading

0 comments on commit 4281eb4

Please sign in to comment.