This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

add quantization backward support (#1854)
Cjkkkk authored and chicm-ms committed Dec 24, 2019
1 parent 7a55811 commit 4f3ee9c
Showing 3 changed files with 74 additions and 43 deletions.
1 change: 0 additions & 1 deletion examples/model_compress/QAT_torch_quantizer.py
@@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer):
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
quantizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
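For orientation, here is a minimal sketch of the training loop after this change, following the example file's names: quantization now happens inside the instrumented forward pass and its gradient is handled by the quantizer's autograd function, so the explicit quantizer.step() call is gone.

import torch.nn.functional as F

def train(model, quantizer, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()    # gradients flow through QuantGrad.backward (straight-through by default)
        optimizer.step()   # updates the full-precision parameters; no quantizer.step() needed anymore
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))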

11 changes: 4 additions & 7 deletions src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):


class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
"""Quantizer defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
@@ -227,20 +227,17 @@ class DoReFaQuantizer(Quantizer):
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(model, config_list)

def quantize_weight(self, weight, config, **kwargs):
weight_bits = get_bits_length(config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, config['q_bits'])
out = self.quantize(out, weight_bits)
out = 2 * out -1
return out

def quantize(self, input_ri, q_bits):
scale = pow(2, q_bits)-1
output = torch.round(input_ri*scale)/scale
return output
return output
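As a side note, here is a small worked example of the weight transform above, assuming weight_bits = 2 so that scale = 2**2 - 1 = 3; the input values are purely illustrative:

import torch

w = torch.tensor([-1.2, 0.1, 0.7])
out = w.tanh()                             # squash into (-1, 1)
out = out / (2 * out.abs().max()) + 0.5    # normalize into [0, 1]
out = torch.round(out * 3) / 3             # 2-bit quantization onto {0, 1/3, 2/3, 1}
out = 2 * out - 1                          # map back into [-1, 1]
print(out)                                 # approximately tensor([-1.0000, 0.3333, 1.0000])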
105 changes: 70 additions & 35 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""

def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = QuantGrad

def quantize_weight(self, weight, config, op, op_type, op_name):
"""
quantize should overload this method to quantize weight.
@@ -262,7 +266,7 @@ def quantize_weight(self, weight, config, op, op_type, op_name):
config : dict
the configuration for weight quantization
"""
raise NotImplementedError("Quantizer must overload quantize_weight()")
raise NotImplementedError('Quantizer must overload quantize_weight()')

def quantize_output(self, output, config, op, op_type, op_name):
"""
@@ -276,7 +280,7 @@ def quantize_output(self, output, config, op, op_type, op_name):
config : dict
the configuration for output quantization
"""
raise NotImplementedError("Quantizer must overload quantize_output()")
raise NotImplementedError('Quantizer must overload quantize_output()')

def quantize_input(self, *inputs, config, op, op_type, op_name):
"""
@@ -290,7 +294,7 @@ def quantize_input(self, *inputs, config, op, op_type, op_name):
config : dict
the configuration for inputs quantization
"""
raise NotImplementedError("Quantizer must overload quantize_input()")
raise NotImplementedError('Quantizer must overload quantize_input()')
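To illustrate how these hooks are meant to be used, here is a minimal sketch of a custom quantizer that only overrides quantize_weight; the symmetric 8-bit rounding scheme is a hypothetical example, not part of this commit:

import torch

class Naive8bitQuantizer(Quantizer):
    def quantize_weight(self, weight, config, op, op_type, op_name):
        # Hypothetical scheme: round to 255 evenly spaced levels, symmetric around zero.
        scale = weight.abs().max() / 127
        if scale == 0:
            return weight
        return torch.round(weight / scale) * scale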


def _instrument_layer(self, layer, config):
@@ -305,62 +309,93 @@ def _instrument_layer(self, layer, config):
the configuration for quantization
"""
assert layer._forward is None, 'Each model can only be compressed once'
assert "quant_types" in config, 'must provide quant_types in config'
assert isinstance(config["quant_types"], list), 'quant_types must be list type'
assert "quant_bits" in config, 'must provide quant_bits in config'
assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
assert 'quant_types' in config, 'must provide quant_types in config'
assert isinstance(config['quant_types'], list), 'quant_types must be list type'
assert 'quant_bits' in config, 'must provide quant_bits in config'
assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'

if isinstance(config["quant_bits"], dict):
for quant_type in config["quant_types"]:
assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
if isinstance(config['quant_bits'], dict):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

if 'weight' in config["quant_types"]:
if 'weight' in config['quant_types']:
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
else:
# old_weight is used to store the original weight and weight is used to store the quantized weight
# weight is a buffer instead of a parameter because in pytorch a parameter is a leaf node;
# if weight were a leaf node, old_weight could not be updated.
layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
delattr(layer.module, 'weight')
layer.module.register_buffer('weight', layer.module.old_weight)

layer._forward = layer.module.forward

def new_forward(*inputs):
if 'input' in config["quant_types"]:
inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
if 'input' in config['quant_types']:
inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)

if 'weight' in config["quant_types"] and _check_weight(layer.module):
weight = layer.module.weight.data
new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
layer.module.weight.data = new_weight
if 'weight' in config['quant_types'] and _check_weight(layer.module):
new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
layer.module.weight = new_weight
result = layer._forward(*inputs)
layer.module.weight.data = weight
else:
result = layer._forward(*inputs)

if 'output' in config["quant_types"]:
result = straight_through_quantize_output.apply(result, self, config, layer)
if 'output' in config['quant_types']:
result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
return result

layer.module.forward = new_forward
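For reference, a config entry that satisfies the checks at the top of _instrument_layer, using the dict form of quant_bits; when quant_bits is a dict, every type listed in quant_types must have its own bit width. The values and the op_names selector are illustrative:

config = {
    'quant_types': ['input', 'weight', 'output'],
    'quant_bits': {'input': 8, 'weight': 8, 'output': 8},   # one bit width per quant type
    'op_names': ['conv1', 'fc1']                            # illustrative layer selector
}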

class QuantType:
"""
Enum class for quantization type.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2

class straight_through_quantize_output(torch.autograd.Function):
    @staticmethod
    def forward(ctx, output, quantizer, config, layer):
        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator
        return grad_output, None, None, None

class straight_through_quantize_input(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, quantizer, config, layer):
        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator
        return grad_output, None, None, None

class QuantGrad(torch.autograd.Function):
    """
    Base class for overriding backward function of quantization operation.
    """
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        """
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is the Straight-Through Estimator.

        Parameters
        ----------
        tensor : Tensor
            input of quantization operation
        grad_output : Tensor
            gradient of the output of quantization operation
        quant_type : QuantType
            the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, or `QuantType.QUANT_OUTPUT`,
            and you can define different behavior for different types.

        Returns
        -------
        tensor
            gradient of the input of quantization operation
        """
        return grad_output

    @staticmethod
    def forward(ctx, tensor, quant_type, quant_func, config, layer):
        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)

    @classmethod
    def backward(cls, ctx, grad_output):
        tensor, quant_type = ctx.saved_variables
        output = cls.quant_backward(tensor, grad_output, quant_type)
        return output, None, None, None, None
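To show how the new hook can be used, here is a hedged sketch of overriding the default straight-through behavior: a hypothetical ClippedQuantGrad that zeroes weight gradients outside [-1, 1] and is wired in through the quant_grad attribute added to Quantizer.__init__ above.

class ClippedQuantGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        # Clip gradients for quantized weights; other quant types keep the plain
        # straight-through estimator. quant_type arrives as the one-element tensor
        # saved in forward(), so the comparison below yields a single boolean.
        if quant_type == QuantType.QUANT_WEIGHT:
            return grad_output * ((tensor >= -1) & (tensor <= 1)).type_as(grad_output)
        return grad_output

# Usage sketch:
# quantizer = QAT_Quantizer(model, config_list)
# quantizer.quant_grad = ClippedQuantGrad   # new_forward then calls ClippedQuantGrad.apply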

def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
