diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md
index 5dd99e3432..67791117e1 100644
--- a/docs/en_US/Compressor/Quantizer.md
+++ b/docs/en_US/Compressor/Quantizer.md
@@ -1,6 +1,5 @@
 Quantizer on NNI Compressor
 ===
-
 ## Naive Quantizer
 We provide Naive Quantizer to quantizer weight to default 8 bits, you can use it to test quantize algorithm without any configure.
@@ -53,11 +52,24 @@ You can view example for more information
 #### User configuration for QAT Quantizer
 * **quant_types:** : list of string
-type of quantization you want to apply, currently support 'weight', 'input', 'output'
+
+type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.
+
+* **op_types:** list of string
+
+specify the types of modules that will be quantized, e.g. 'Conv2d'.
+
+* **op_names:** list of string
+
+specify the names of modules that will be quantized, e.g. 'conv1'.
+
 * **quant_bits:** int or dict of {str : int}
-bits length of quantization, key is the quantization type, value is the length, eg. {'weight', 8},
-when the type is int, all quantization types share same bits length
+
+bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
+When an int is given, all quantization types share the same bit length.
+
 * **quant_start_step:** int
+
 disable quantization until model are run by certain number of steps, this allows the network to enter a more stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
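+
+As a minimal sketch, a configuration combining the options above might look like the following; the bit widths, `quant_start_step` value and module types are illustrative, and `QAT_Quantizer(model, config_list)` follows the same usage pattern as the other quantizers in this document:
+
+```python
+config_list = [{
+    'quant_types': ['weight', 'output'],
+    'quant_bits': {'weight': 8, 'output': 8},  # per-type bit lengths
+    'op_types': ['Conv2d', 'Linear'],
+    'quant_start_step': 1000                   # keep full precision for the first 1000 steps
+}]
+quantizer = QAT_Quantizer(model, config_list)
+quantizer.compress()
+```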
@@ -71,17 +83,14 @@ In [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bit
 ### Usage
 To implement DoReFa Quantizer, you can add code below before your training code

-Tensorflow code
-```python
-from nni.compressors.tensorflow import DoReFaQuantizer
-config_list = [{ 'q_bits': 8, 'op_types': 'default' }]
-quantizer = DoReFaQuantizer(tf.get_default_graph(), config_list)
-quantizer.compress()
-```
 PyTorch code
 ```python
 from nni.compressors.torch import DoReFaQuantizer
-config_list = [{ 'q_bits': 8, 'op_types': 'default' }]
+config_list = [{
+    'quant_types': ['weight'],
+    'quant_bits': 8,
+    'op_types': 'default'
+}]
 quantizer = DoReFaQuantizer(model, config_list)
 quantizer.compress()
 ```
@@ -89,4 +98,79 @@ quantizer.compress()
 You can view example for more information

 #### User configuration for DoReFa Quantizer
-* **q_bits:** This is to specify the q_bits operations to be quantized to
+* **quant_types:** list of string
+
+type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.
+
+* **op_types:** list of string
+
+specify the types of modules that will be quantized, e.g. 'Conv2d'.
+
+* **op_names:** list of string
+
+specify the names of modules that will be quantized, e.g. 'conv1'.
+
+* **quant_bits:** int or dict of {str : int}
+
+bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
+When an int is given, all quantization types share the same bit length.
+
+
+## BNN Quantizer
+In [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830),
+
+>We introduce a method to train Binarized Neural Networks (BNNs) - neural networks with binary weights and activations at run-time. At training-time the binary weights and activations are used for computing the parameters gradients. During the forward pass, BNNs drastically reduce memory size and accesses, and replace most arithmetic operations with bit-wise operations, which is expected to substantially improve power-efficiency.
+
+
+### Usage
+
+PyTorch code
+```python
+from nni.compression.torch import BNNQuantizer
+model = VGG_Cifar10(num_classes=10)
+
+configure_list = [{
+    'quant_types': ['weight'],
+    'quant_bits': 1,
+    'op_types': ['Conv2d', 'Linear'],
+    'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3']
+}, {
+    'quant_types': ['output'],
+    'quant_bits': 1,
+    'op_types': ['Hardtanh'],
+    'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
+}]
+
+quantizer = BNNQuantizer(model, configure_list)
+model = quantizer.compress()
+```
+
+You can view the example [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) for more information.
+
+#### User configuration for BNN Quantizer
+* **quant_types:** list of string
+
+type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.
+
+* **op_types:** list of string
+
+specify the types of modules that will be quantized, e.g. 'Conv2d'.
+
+* **op_names:** list of string
+
+specify the names of modules that will be quantized, e.g. 'conv1'.
+
+* **quant_bits:** int or dict of {str : int}
+
+bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
+When an int is given, all quantization types share the same bit length.
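+
+For reference, a minimal sketch of the binarization rule that `BNNQuantizer` applies to both weights and outputs (mirroring `quantize_weight` and `quantize_output` in `builtin_quantizers.py`; zeros produced by `torch.sign` are mapped to +1):
+
+```python
+import torch
+
+def binarize(x):
+    out = torch.sign(x)
+    out[out == 0] = 1  # sign(0) == 0, so map zeros to +1
+    return out
+
+print(binarize(torch.tensor([-0.5, 0.0, 2.3])))  # tensor([-1., 1., 1.])
+```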
+
+### Experiment
+We implemented one of the experiments in [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830): we quantized the **VGGNet** for CIFAR-10 used in the paper. Our experiment results are as follows:
+
+| Model | Accuracy |
+| ------------- | --------- |
+| VGGNet | 86.93% |
+
+
+The experiment code can be found at [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py)
\ No newline at end of file
diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py
new file mode 100644
index 0000000000..d4908885c3
--- /dev/null
+++ b/examples/model_compress/BNN_quantizer_cifar10.py
@@ -0,0 +1,155 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+from nni.compression.torch import BNNQuantizer
+
+
+class VGG_Cifar10(nn.Module):
+    def __init__(self, num_classes=1000):
+        super(VGG_Cifar10, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(128, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True),
+
+            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.BatchNorm2d(128, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True),
+
+            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True),
+
+            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True),
+
+            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
+            nn.Hardtanh(inplace=True)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 4 * 4, 1024, bias=False),
+            nn.BatchNorm1d(1024),
+            nn.Hardtanh(inplace=True),
+            nn.Linear(1024, 1024, bias=False),
+            nn.BatchNorm1d(1024),
+            nn.Hardtanh(inplace=True),
+            nn.Linear(1024, num_classes),  # do not quantize output
+            nn.BatchNorm1d(num_classes, affine=False)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(-1, 512 * 4 * 4)
+        x = self.classifier(x)
+        return x
+
+
+def train(model, device, train_loader, optimizer):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.cross_entropy(output, target)
+        loss.backward()
+        optimizer.step()
+        # BNN keeps real-valued weights in 'old_weight'; clip them to [-1, 1] after each update
+        for name, param in model.named_parameters():
+            if name.endswith('old_weight'):
+                param.data = param.data.clamp(-1, 1)
+        if batch_idx % 100 == 0:
+            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            # the model outputs raw logits, so accumulate cross-entropy for the test loss
+            test_loss += F.cross_entropy(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    test_loss /= len(test_loader.dataset)
+    acc = 100 * correct / len(test_loader.dataset)
+
+    print('Loss: {}  Accuracy: {}%\n'.format(test_loss, acc))
+    return acc
+
+
+def adjust_learning_rate(optimizer, epoch):
+    update_list = [55, 100, 150, 200, 400, 600]
+    if epoch in update_list:
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = param_group['lr'] * 0.1
+    return
+
+
+def main():
+    torch.manual_seed(0)
+    device = torch.device('cuda')
+    train_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('./data.cifar10', train=True, download=True,
+                         transform=transforms.Compose([
+                             transforms.ToTensor(),
+                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+                         ])),
+        batch_size=64, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+        ])),
+        batch_size=200, shuffle=False)
+
+    model = VGG_Cifar10(num_classes=10)
+    model.to(device)
+
+    configure_list = [{
+        'quant_types': ['weight'],
+        'quant_bits': 1,
+        'op_types': ['Conv2d', 'Linear'],
+        'op_names': ['features.3', 'features.7', 'features.10', 'features.14', 'classifier.0', 'classifier.3']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': 1,
+        'op_types': ['Hardtanh'],
+        'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
+    }]
+
+    quantizer = BNNQuantizer(model, configure_list)
+    model = quantizer.compress()
+
+    print('=' * 10 + 'train' + '=' * 10)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+    best_top1 = 0
+    for epoch in range(400):
+        print('# Epoch {} #'.format(epoch))
+        train(model, device, train_loader, optimizer)
+        adjust_learning_rate(optimizer, epoch)
+        top1 = test(model, device, test_loader)
+        if top1 > best_top1:
+            best_top1 = top1
+            print(best_top1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
index 7f9c3b144a..2204428574 100644
--- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
+++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -3,7 +3,7 @@
 import logging
 import torch
-from .compressor import Quantizer
+from .compressor import Quantizer, QuantGrad, QuantType

 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']

@@ -240,4 +240,34 @@ def quantize_weight(self, weight, config, **kwargs):
     def quantize(self, input_ri, q_bits):
         scale = pow(2, q_bits)-1
         output = torch.round(input_ri*scale)/scale
-        return output
\ No newline at end of file
+        return output
+
+
+class ClipGrad(QuantGrad):
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type):
+        # straight-through estimator: cancel the gradient of binarized outputs where |x| > 1
+        if quant_type == QuantType.QUANT_OUTPUT:
+            grad_output[torch.abs(tensor) > 1] = 0
+        return grad_output
+
+
+class BNNQuantizer(Quantizer):
+    """Binarized Neural Networks, as defined in:
+    Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
+    (https://arxiv.org/abs/1602.02830)
+    """
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list)
+        self.quant_grad = ClipGrad
+
+    def quantize_weight(self, weight, config, **kwargs):
+        out = torch.sign(weight)
+        # remove zeros: sign(0) == 0, map them to +1
+        out[out == 0] = 1
+        return out
+
+    def quantize_output(self, output, config, **kwargs):
+        out = torch.sign(output)
+        # remove zeros: sign(0) == 0, map them to +1
+        out[out == 0] = 1
+        return out