This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add BNN quantization algorithm #1832

Merged · 24 commits · Dec 24, 2019
Changes from 6 commits
1 change: 0 additions & 1 deletion examples/model_compress/QAT_torch_quantizer.py
@@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer):
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
quantizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

40 changes: 37 additions & 3 deletions src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -3,9 +3,9 @@

import logging
import torch
from .compressor import Quantizer
from .compressor import Quantizer, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']

logger = logging.getLogger(__name__)

@@ -234,13 +234,47 @@ def __init__(self, model, config_list):
super().__init__(model, config_list)

def quantize_weight(self, weight, config, **kwargs):
weight_bits = get_bits_length(config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, config['q_bits'])
out = self.quantize(out, weight_bits)
out = 2 * out - 1
return out

def quantize(self, input_ri, q_bits):
scale = pow(2, q_bits)-1
output = torch.round(input_ri*scale)/scale
return output
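
For reference, a minimal standalone sketch of the transformation applied by quantize_weight and quantize above (the helper name and the 2-bit setting are illustrative only, not part of this PR):

import torch

def dorefa_quantize_weight(weight, weight_bits=2):
    # normalize weights into [0, 1] via tanh, as in quantize_weight above
    out = weight.tanh()
    out = out / (2 * out.abs().max()) + 0.5
    # uniform k-bit quantization: round onto 2^k - 1 evenly spaced levels
    scale = pow(2, weight_bits) - 1
    out = torch.round(out * scale) / scale
    # map back to [-1, 1]
    return 2 * out - 1

print(dorefa_quantize_weight(torch.randn(4), weight_bits=2))  # values drawn from {-1, -1/3, 1/3, 1}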


class ClipGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
if quant_type == QuantType.QUANT_OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
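
The effect of ClipGrad is that gradients pass through unchanged except where the pre-quantization output falls outside [-1, 1]; a quick illustration with made-up values:

import torch

tensor = torch.tensor([0.3, -1.5, 2.0, -0.8])  # pre-quantization outputs (made up)
grad_output = torch.ones(4)                    # incoming gradients
grad = grad_output.clone()
grad[torch.abs(tensor) > 1] = 0                # same rule as ClipGrad.quant_backward
print(grad)                                    # tensor([1., 0., 0., 1.])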


class BNNQuantizer(Quantizer):
"""BNNQuantizer
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(model, config_list)
self.quant_grad = ClipGrad

def quantize_weight(self, weight, config, **kwargs):
# module.weight.data = torch.clamp(weight, -1.0, 1.0)
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out

def quantize_output(self, output, config, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
return out
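
A rough usage sketch for the new quantizer (the model and config values are made up, and the exact config keys may differ in later revisions of this PR):

import torch.nn as nn
from nni.compression.torch.builtin_quantizers import BNNQuantizer

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
# binarize both weights and outputs of the Linear layers to {-1, +1}
config_list = [{
    'quant_types': ['weight', 'output'],
    'q_bits': 1,
    'op_types': ['Linear'],
}]
quantizer = BNNQuantizer(model, config_list)
quantizer.compress()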
74 changes: 53 additions & 21 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
Base class for PyTorch quantizers
"""

def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = QuantGrad

def quantize_weight(self, weight, config, op, op_type, op_name):
"""
Quantizers should override this method to quantize weight.
@@ -317,50 +321,78 @@ def _instrument_layer(self, layer, config):
if 'weight' in config["quant_types"]:
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
else:
layer.module.register_parameter("old_weight", torch.nn.Parameter(layer.module.weight))
delattr(layer.module, "weight")
layer.module.register_buffer('weight', layer.module.old_weight)

layer._forward = layer.module.forward

def new_forward(*inputs):
if 'input' in config["quant_types"]:
inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)

if 'weight' in config["quant_types"] and _check_weight(layer.module):
weight = layer.module.weight.data
new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
layer.module.weight.data = new_weight
new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
layer.module.weight = new_weight
result = layer._forward(*inputs)
layer.module.weight.data = weight
else:
result = layer._forward(*inputs)

if 'output' in config["quant_types"]:
result = straight_through_quantize_output.apply(result, self, config, layer)
result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
return result

layer.module.forward = new_forward
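
The weight branch above relies on a re-registration trick: the full-precision weight stays trainable under the name old_weight, while weight becomes a plain buffer that is overwritten with the quantized value on every forward pass, so gradients reach old_weight through quant_grad. A standalone sketch of the same idea on a bare nn.Linear (without the NNI wrapper, so plain torch.sign is used and its gradient is zero; QuantGrad's straight-through backward is what makes the real path useful):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
# keep the full-precision weight as a trainable parameter ...
layer.register_parameter('old_weight', nn.Parameter(layer.weight))
delattr(layer, 'weight')
# ... and expose 'weight' as a buffer that can be overwritten each forward pass
layer.register_buffer('weight', layer.old_weight)

layer.weight = torch.sign(layer.old_weight)  # quantized value used by forward
loss = layer(torch.randn(1, 4)).sum()
loss.backward()
print(layer.old_weight.grad)                 # all zeros here, since sign() has zero gradient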

class QuantType:
"""
Enum class for quantization type.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2

class straight_through_quantize_output(torch.autograd.Function):
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@staticmethod
def forward(ctx, output, quantizer, config, layer):
return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
def quant_backward(tensor, grad_output, quant_type):
"""
This method should be overridden by subclasses to provide a customized backward function;
the default implementation is the Straight-Through Estimator.

@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
Parameters
----------
tensor : Tensor
input of quantization operation
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, or `QuantType.QUANT_OUTPUT`,
and different behavior can be defined for each type.

class straight_through_quantize_input(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, quantizer, config, layer):
return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
Returns
-------
tensor
gradient of the input of quantization operation
"""
return grad_output

@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
def forward(ctx, tensor, quant_type, quant_func, config, layer):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)

@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None
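
To make the default straight-through behavior concrete, here is a self-contained toy autograd function in the same spirit as QuantGrad (not part of this PR): the forward pass rounds, and the backward pass returns the incoming gradient unchanged, exactly what the base quant_backward does:

import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        # quantize in the forward pass
        return torch.round(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: treat round() as the identity
        return grad_output

x = torch.tensor([0.2, 0.7, 1.6], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(y)       # tensor([0., 1., 2.])
print(x.grad)  # tensor([1., 1., 1.])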

def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False