From e3979c09f2bc208459a3f8d237029c143bdde7b4 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Tue, 3 Dec 2019 15:40:50 +0800 Subject: [PATCH 01/20] update call to new API --- src/sdk/pynni/nni/compression/torch/builtin_quantizers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index ce6c8bc902..d5847f9090 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -234,9 +234,10 @@ def __init__(self, model, config_list): super().__init__(model, config_list) def quantize_weight(self, weight, config, **kwargs): + weight_bits = get_bits_length(config, 'weight') out = weight.tanh() out = out / (2 * out.abs().max()) + 0.5 - out = self.quantize(out, config['q_bits']) + out = self.quantize(out, weight_bits) out = 2 * out -1 return out From fd85395fa869b69c94592a421374dfb2845a7da3 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Thu, 5 Dec 2019 17:14:17 +0800 Subject: [PATCH 02/20] refactor framework --- .../model_compress/BNN_quantizer_cifar10.py | 173 ++++++++++++++++++ .../compression/torch/builtin_quantizers.py | 25 +++ .../pynni/nni/compression/torch/compressor.py | 22 +-- 3 files changed, 204 insertions(+), 16 deletions(-) create mode 100644 examples/model_compress/BNN_quantizer_cifar10.py diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py new file mode 100644 index 0000000000..e9ee94351d --- /dev/null +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -0,0 +1,173 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.compression.torch import BNNQuantizer + + +class vgg(nn.Module): + def __init__(self, init_weights=True): + super(vgg, self).__init__() + cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] + self.cfg = cfg + self.feature = self.make_layers(cfg, True) + num_classes = 10 + self.classifier = nn.Sequential( + nn.Linear(cfg[-1], 512), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Linear(512, num_classes) + ) + if init_weights: + self._initialize_weights() + + def make_layers(self, cfg, batch_norm=True): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + x = nn.AvgPool2d(2)(x) + x = x.view(x.size(0), -1) + y = self.classifier(x) + return y + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(0.5) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def train(model, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + acc = 100 * correct / len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, acc)) + return acc + + +def main(): + torch.manual_seed(0) + device = torch.device('cuda') + train_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=True, download=True, + transform=transforms.Compose([ + transforms.Pad(4), + transforms.RandomCrop(32), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=200, shuffle=False) + + model = vgg() + model.to(device) + + # Train the base VGG-16 model + print('=' * 10 + 'Train the unpruned base model' + '=' * 10) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0) + for epoch in range(160): + train(model, device, train_loader, optimizer) + test(model, device, test_loader) + lr_scheduler.step(epoch) + torch.save(model.state_dict(), 'vgg16_cifar10.pth') + + # Test base model accuracy + print('=' * 10 + 'Test on the original model' + '=' * 10) + model.load_state_dict(torch.load('vgg16_cifar10.pth')) + test(model, device, test_loader) + # top1 = 93.51% + + # Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', + # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' + configure_list = [{ + 'sparsity': 0.5, + 'op_types': ['default'], + 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] + }] + + # Prune model and test accuracy without fine tuning. + print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10) + pruner = L1FilterPruner(model, configure_list) + model = pruner.compress() + test(model, device, test_loader) + # top1 = 88.19% + + # Fine tune the pruned model for 40 epochs and test accuracy + print('=' * 10 + 'Fine tuning' + '=' * 10) + optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) + best_top1 = 0 + for epoch in range(40): + pruner.update_epoch(epoch) + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer_finetune) + top1 = test(model, device, test_loader) + if top1 > best_top1: + best_top1 = top1 + # Export the best model, 'model_path' stores state_dict of the pruned model, + # mask_path stores mask_dict of the pruned model + pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth') + + # Test the exported model + print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10) + new_model = vgg() + new_model.to(device) + new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth')) + test(new_model, device, test_loader) + # top1 = 93.53% + + +if __name__ == '__main__': + main() diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index d5847f9090..3b72100b3b 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -245,3 +245,28 @@ def quantize(self, input_ri, q_bits): scale = pow(2, q_bits)-1 output = torch.round(input_ri*scale)/scale return output + +class BNNQuantizer(Quantizer): + """Quantizer using the DoReFa scheme, as defined in: + Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients + (https://arxiv.org/abs/1606.06160) + """ + def __init__(self, model, config_list): + """ + config_list: supported keys: + - q_bits + """ + super().__init__(model, config_list) + + def quantize_weight(self, weight, config, **kwargs): + out = torch.sign(weight) + # remove zeros + out[out == 0] = 1 + return out + + def quantize_output(self, output, config, **kwargs): + # dont quantize last layer output + out = torch.sign(output) + # remove zeros + out[out == 0] = 1 + return out \ No newline at end of file diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 18a2d15fe1..5e8b6a47dc 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -321,11 +321,11 @@ def _instrument_layer(self, layer, config): def new_forward(*inputs): if 'input' in config["quant_types"]: - inputs = straight_through_quantize_input.apply(inputs, self, config, layer) + inputs = straight_through_estimator.apply(inputs, self.quantize_input, config, layer) if 'weight' in config["quant_types"] and _check_weight(layer.module): weight = layer.module.weight.data - new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name) + new_weight = straight_through_estimator.apply(weight, self.quantize_weight, config, layer) layer.module.weight.data = new_weight result = layer._forward(*inputs) layer.module.weight.data = weight @@ -333,26 +333,16 @@ def new_forward(*inputs): result = layer._forward(*inputs) if 'output' in config["quant_types"]: - result = straight_through_quantize_output.apply(result, self, config, layer) + result = straight_through_estimator.apply(result, self.quantize_output, config, layer) return result layer.module.forward = new_forward -class straight_through_quantize_output(torch.autograd.Function): +class straight_through_estimator(torch.autograd.Function): @staticmethod - def forward(ctx, output, quantizer, config, layer): - return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name) - - @staticmethod - def backward(ctx, grad_output): - # Straight-through estimator - return grad_output, None, None, None - -class straight_through_quantize_input(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, quantizer, config, layer): - return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name) + def forward(ctx, output, func, config, layer): + return func(output, config, op=layer.module, op_type=layer.type, op_name=layer.name) @staticmethod def backward(ctx, grad_output): From 8539b173a5b4f02a4444fc2e25202ed582c0f02d Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Mon, 9 Dec 2019 10:27:34 +0800 Subject: [PATCH 03/20] update framework --- .../model_compress/BNN_quantizer_cifar10.py | 173 ------------------ .../compression/torch/builtin_quantizers.py | 20 +- .../pynni/nni/compression/torch/compressor.py | 28 ++- 3 files changed, 33 insertions(+), 188 deletions(-) delete mode 100644 examples/model_compress/BNN_quantizer_cifar10.py diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py deleted file mode 100644 index e9ee94351d..0000000000 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ /dev/null @@ -1,173 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import datasets, transforms -from nni.compression.torch import BNNQuantizer - - -class vgg(nn.Module): - def __init__(self, init_weights=True): - super(vgg, self).__init__() - cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] - self.cfg = cfg - self.feature = self.make_layers(cfg, True) - num_classes = 10 - self.classifier = nn.Sequential( - nn.Linear(cfg[-1], 512), - nn.BatchNorm1d(512), - nn.ReLU(inplace=True), - nn.Linear(512, num_classes) - ) - if init_weights: - self._initialize_weights() - - def make_layers(self, cfg, batch_norm=True): - layers = [] - in_channels = 3 - for v in cfg: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) - if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] - else: - layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v - return nn.Sequential(*layers) - - def forward(self, x): - x = self.feature(x) - x = nn.AvgPool2d(2)(x) - x = x.view(x.size(0), -1) - y = self.classifier(x) - return y - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(0.5) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() - - -def train(model, device, train_loader, optimizer): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.cross_entropy(output, target) - loss.backward() - optimizer.step() - if batch_idx % 100 == 0: - print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - test_loss /= len(test_loader.dataset) - acc = 100 * correct / len(test_loader.dataset) - - print('Loss: {} Accuracy: {}%)\n'.format( - test_loss, acc)) - return acc - - -def main(): - torch.manual_seed(0) - device = torch.device('cuda') - train_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('./data.cifar10', train=True, download=True, - transform=transforms.Compose([ - transforms.Pad(4), - transforms.RandomCrop(32), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ])), - batch_size=64, shuffle=True) - test_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ])), - batch_size=200, shuffle=False) - - model = vgg() - model.to(device) - - # Train the base VGG-16 model - print('=' * 10 + 'Train the unpruned base model' + '=' * 10) - optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0) - for epoch in range(160): - train(model, device, train_loader, optimizer) - test(model, device, test_loader) - lr_scheduler.step(epoch) - torch.save(model.state_dict(), 'vgg16_cifar10.pth') - - # Test base model accuracy - print('=' * 10 + 'Test on the original model' + '=' * 10) - model.load_state_dict(torch.load('vgg16_cifar10.pth')) - test(model, device, test_loader) - # top1 = 93.51% - - # Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', - # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' - configure_list = [{ - 'sparsity': 0.5, - 'op_types': ['default'], - 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] - }] - - # Prune model and test accuracy without fine tuning. - print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10) - pruner = L1FilterPruner(model, configure_list) - model = pruner.compress() - test(model, device, test_loader) - # top1 = 88.19% - - # Fine tune the pruned model for 40 epochs and test accuracy - print('=' * 10 + 'Fine tuning' + '=' * 10) - optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) - best_top1 = 0 - for epoch in range(40): - pruner.update_epoch(epoch) - print('# Epoch {} #'.format(epoch)) - train(model, device, train_loader, optimizer_finetune) - top1 = test(model, device, test_loader) - if top1 > best_top1: - best_top1 = top1 - # Export the best model, 'model_path' stores state_dict of the pruned model, - # mask_path stores mask_dict of the pruned model - pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth') - - # Test the exported model - print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10) - new_model = vgg() - new_model.to(device) - new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth')) - test(new_model, device, test_loader) - # top1 = 93.53% - - -if __name__ == '__main__': - main() diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index 3b72100b3b..f22a789686 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -3,9 +3,9 @@ import logging import torch -from .compressor import Quantizer +from .compressor import Quantizer, QuantGrad -__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer'] +__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer'] logger = logging.getLogger(__name__) @@ -246,10 +246,17 @@ def quantize(self, input_ri, q_bits): output = torch.round(input_ri*scale)/scale return output + +class ClipGrad(QuantGrad): + @staticmethod + def quant_backward(tensor, grad_output, quant_type): + if quant_type == 2: + grad_output[torch.abs(tensor) > 1] = 0 + return grad_output + + class BNNQuantizer(Quantizer): - """Quantizer using the DoReFa scheme, as defined in: - Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients - (https://arxiv.org/abs/1606.06160) + """BNNQuantizer """ def __init__(self, model, config_list): """ @@ -257,15 +264,16 @@ def __init__(self, model, config_list): - q_bits """ super().__init__(model, config_list) + self.quant_grad = ClipGrad def quantize_weight(self, weight, config, **kwargs): + # module.weight.data = torch.clamp(weight, -1.0, 1.0) out = torch.sign(weight) # remove zeros out[out == 0] = 1 return out def quantize_output(self, output, config, **kwargs): - # dont quantize last layer output out = torch.sign(output) # remove zeros out[out == 0] = 1 diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 5e8b6a47dc..30d5460353 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -250,6 +250,10 @@ class Quantizer(Compressor): Base quantizer for pytorch quantizer """ + def __init__(self, model, config_list): + super().__init__(model, config_list) + self.quant_grad = QuantGrad + def quantize_weight(self, weight, config, op, op_type, op_name): """ quantize should overload this method to quantize weight. @@ -321,11 +325,11 @@ def _instrument_layer(self, layer, config): def new_forward(*inputs): if 'input' in config["quant_types"]: - inputs = straight_through_estimator.apply(inputs, self.quantize_input, config, layer) + inputs = self.quant_grad.apply(inputs, 0, self.quantize_input, config, layer) if 'weight' in config["quant_types"] and _check_weight(layer.module): weight = layer.module.weight.data - new_weight = straight_through_estimator.apply(weight, self.quantize_weight, config, layer) + new_weight = self.quant_grad.apply(weight, 1, self.quantize_weight, config, layer) # quant_grad is not necessary here layer.module.weight.data = new_weight result = layer._forward(*inputs) layer.module.weight.data = weight @@ -333,21 +337,27 @@ def new_forward(*inputs): result = layer._forward(*inputs) if 'output' in config["quant_types"]: - result = straight_through_estimator.apply(result, self.quantize_output, config, layer) + result = self.quant_grad.apply(result, 2, self.quantize_output, config, layer) return result layer.module.forward = new_forward -class straight_through_estimator(torch.autograd.Function): +class QuantGrad(torch.autograd.Function): @staticmethod - def forward(ctx, output, func, config, layer): - return func(output, config, op=layer.module, op_type=layer.type, op_name=layer.name) + def quant_backward(tensor, grad_output, quant_type): + return grad_output @staticmethod - def backward(ctx, grad_output): - # Straight-through estimator - return grad_output, None, None, None + def forward(ctx, tensor, quant_type, quant_func, config, layer): + ctx.save_for_backward(tensor, torch.tensor(quant_type)) + return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name) + + @classmethod + def backward(cls, ctx, grad_output): + tensor, quant_type = ctx.saved_variables + output = cls.quant_backward(tensor, grad_output, quant_type) + return output, None, None, None, None def _check_weight(module): try: From 6c71a4d46d7b930334cefc67c4e54d0cbe9117b4 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Mon, 9 Dec 2019 14:53:11 +0800 Subject: [PATCH 04/20] enable quant_weight customize backward --- .../model_compress/QAT_torch_quantizer.py | 1 - .../compression/torch/builtin_quantizers.py | 4 ++-- .../pynni/nni/compression/torch/compressor.py | 20 +++++++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/model_compress/QAT_torch_quantizer.py b/examples/model_compress/QAT_torch_quantizer.py index 04747c9f10..f1dead41f9 100644 --- a/examples/model_compress/QAT_torch_quantizer.py +++ b/examples/model_compress/QAT_torch_quantizer.py @@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer): loss = F.nll_loss(output, target) loss.backward() optimizer.step() - quantizer.step() if batch_idx % 100 == 0: print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index f22a789686..75c39a8722 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -250,7 +250,7 @@ def quantize(self, input_ri, q_bits): class ClipGrad(QuantGrad): @staticmethod def quant_backward(tensor, grad_output, quant_type): - if quant_type == 2: + if quant_type == 2: # quant_type is quant_output grad_output[torch.abs(tensor) > 1] = 0 return grad_output @@ -272,7 +272,7 @@ def quantize_weight(self, weight, config, **kwargs): # remove zeros out[out == 0] = 1 return out - + def quantize_output(self, output, config, **kwargs): out = torch.sign(output) # remove zeros diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 30d5460353..373307dcea 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -253,7 +253,7 @@ class Quantizer(Compressor): def __init__(self, model, config_list): super().__init__(model, config_list) self.quant_grad = QuantGrad - + def quantize_weight(self, weight, config, op, op_type, op_name): """ quantize should overload this method to quantize weight. @@ -321,6 +321,11 @@ def _instrument_layer(self, layer, config): if 'weight' in config["quant_types"]: if not _check_weight(layer.module): _logger.warning('Module %s does not have parameter "weight"', layer.name) + else: + layer.module.register_parameter("old_weight", torch.nn.Parameter(layer.module.weight)) + delattr(layer.module, "weight") + layer.module.register_buffer('weight', layer.module.old_weight) + layer._forward = layer.module.forward def new_forward(*inputs): @@ -328,11 +333,9 @@ def new_forward(*inputs): inputs = self.quant_grad.apply(inputs, 0, self.quantize_input, config, layer) if 'weight' in config["quant_types"] and _check_weight(layer.module): - weight = layer.module.weight.data - new_weight = self.quant_grad.apply(weight, 1, self.quantize_weight, config, layer) # quant_grad is not necessary here - layer.module.weight.data = new_weight + new_weight = self.quant_grad.apply(layer.module.old_weight, 1, self.quantize_weight, config, layer) + layer.module.weight = new_weight result = layer._forward(*inputs) - layer.module.weight.data = weight else: result = layer._forward(*inputs) @@ -350,9 +353,9 @@ def quant_backward(tensor, grad_output, quant_type): @staticmethod def forward(ctx, tensor, quant_type, quant_func, config, layer): - ctx.save_for_backward(tensor, torch.tensor(quant_type)) + ctx.save_for_backward(tensor, torch.Tensor([quant_type])) return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name) - + @classmethod def backward(cls, ctx, grad_output): tensor, quant_type = ctx.saved_variables @@ -361,6 +364,7 @@ def backward(cls, ctx, grad_output): def _check_weight(module): try: - return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor) + return isinstance(module.weight.data, torch.Tensor) + # return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor) except AttributeError: return False From 6c5b98555b642da797c1cd25f9b36c300b3b008e Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Thu, 12 Dec 2019 11:02:00 +0800 Subject: [PATCH 05/20] add docstring && add enum class for quant type --- .../compression/torch/builtin_quantizers.py | 4 +-- .../pynni/nni/compression/torch/compressor.py | 36 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index 75c39a8722..a20a4789e0 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -3,7 +3,7 @@ import logging import torch -from .compressor import Quantizer, QuantGrad +from .compressor import Quantizer, QuantGrad, QuantType __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer'] @@ -250,7 +250,7 @@ def quantize(self, input_ri, q_bits): class ClipGrad(QuantGrad): @staticmethod def quant_backward(tensor, grad_output, quant_type): - if quant_type == 2: # quant_type is quant_output + if quant_type == QuantType.QUANT_OUTPUT: grad_output[torch.abs(tensor) > 1] = 0 return grad_output diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 373307dcea..874adaab1c 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -330,25 +330,54 @@ def _instrument_layer(self, layer, config): def new_forward(*inputs): if 'input' in config["quant_types"]: - inputs = self.quant_grad.apply(inputs, 0, self.quantize_input, config, layer) + inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer) if 'weight' in config["quant_types"] and _check_weight(layer.module): - new_weight = self.quant_grad.apply(layer.module.old_weight, 1, self.quantize_weight, config, layer) + new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer) layer.module.weight = new_weight result = layer._forward(*inputs) else: result = layer._forward(*inputs) if 'output' in config["quant_types"]: - result = self.quant_grad.apply(result, 2, self.quantize_output, config, layer) + result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer) return result layer.module.forward = new_forward +class QuantType: + """ + Enum class for quantization type. + """ + QUANT_INPUT = 0 + QUANT_WEIGHT = 1 + QUANT_OUTPUT = 2 class QuantGrad(torch.autograd.Function): + """ + Base class for overriding backward function of quantization operation. + """ @staticmethod def quant_backward(tensor, grad_output, quant_type): + """ + This method should be overrided by subclass to provide customized backward function, + default implementation is Straight-Through Estimator + + Parameters + ---------- + tensor : Tensor + input of quantization operation + grad_output : Tensor + gradient of the output of quantization operation + quant_type : QuantType + the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, + you can define different behavior for different types. + + Returns + ------- + tensor + gradient of the input of quantization operation + """ return grad_output @staticmethod @@ -365,6 +394,5 @@ def backward(cls, ctx, grad_output): def _check_weight(module): try: return isinstance(module.weight.data, torch.Tensor) - # return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor) except AttributeError: return False From ecced2cd0c783bf7a23d73b008ee20dc1327ba8c Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Fri, 13 Dec 2019 11:22:04 +0800 Subject: [PATCH 06/20] add BNN quantizer and cifar10 example --- .../model_compress/BNN_quantizer_cifar10.py | 154 ++++++++++++++++++ .../compression/torch/builtin_quantizers.py | 15 +- .../pynni/nni/compression/torch/compressor.py | 2 +- 3 files changed, 159 insertions(+), 12 deletions(-) create mode 100644 examples/model_compress/BNN_quantizer_cifar10.py diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py new file mode 100644 index 0000000000..1f3ef1f605 --- /dev/null +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -0,0 +1,154 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.compression.torch import BNNQuantizer + + +class VGG_Cifar10(nn.Module): + def __init__(self, num_classes=1000): + super(VGG_Cifar10, self).__init__() + self.infl_ratio=1 # may be set to large number ? + self.features = nn.Sequential( + nn.Conv2d(3, 128*self.infl_ratio, kernel_size=3, stride=1, padding=1, + bias=False), + nn.BatchNorm2d(128*self.infl_ratio), + nn.Hardtanh(inplace=True), + + nn.Conv2d(128*self.infl_ratio, 128*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(128*self.infl_ratio), + nn.Hardtanh(inplace=True), + + nn.Conv2d(128*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(256*self.infl_ratio), + nn.Hardtanh(inplace=True), + + + nn.Conv2d(256*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(256*self.infl_ratio), + nn.Hardtanh(inplace=True), + + + nn.Conv2d(256*self.infl_ratio, 512*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(512*self.infl_ratio), + nn.Hardtanh(inplace=True), + + + nn.Conv2d(512*self.infl_ratio, 512, kernel_size=3, padding=1, bias=False), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(512), + nn.Hardtanh(inplace=True) + ) + + self.classifier = nn.Sequential( + nn.Linear(512 * 4 * 4, 1024, bias=False), + nn.BatchNorm1d(1024), + nn.Hardtanh(inplace=True), + nn.Linear(1024, 1024, bias=False), + nn.BatchNorm1d(1024), + nn.Hardtanh(inplace=True), + nn.Linear(1024, num_classes), # do not quantize output + nn.BatchNorm1d(num_classes, affine=False), + nn.LogSoftmax() + ) + + + def forward(self, x): + x = self.features(x) + x = x.view(-1, 512 * 4 * 4) + x = self.classifier(x) + return x + + +def train(model, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + for name, param in model.named_parameters(): + if name.endswith('old_weight'): + param = param.clamp(-1,1) + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + acc = 100 * correct / len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, acc)) + return acc + + +def main(): + torch.manual_seed(0) + device = torch.device('cuda') + train_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=200, shuffle=False) + + model = VGG_Cifar10(num_classes=10) + model.to(device) + + # print(model) + + configure_list = [{ + 'quant_types': ['weight'], + 'quant_bits': 1, + 'op_types': ['Conv2d', 'Linear'], + 'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3'] + },{'quant_types': ['output'], + 'quant_bits': 1, + 'op_types': ['Hardtanh'], + 'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5'] + }] + + quantizer = BNNQuantizer(model, configure_list) + model = quantizer.compress() + print(model) + # for name, p in model.named_parameters(): + # print(name) + + print('=' * 10 + 'train' + '=' * 10) + optimizer_finetune = torch.optim.Adam(model.parameters(), lr=1e-3) + best_top1 = 0 + for epoch in range(120): + quantizer.update_epoch(epoch) + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer_finetune) + top1 = test(model, device, test_loader) + if top1 > best_top1: + best_top1 = top1 + print(best_top1) + + +if __name__ == '__main__': + main() diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index a20a4789e0..7d1e690b14 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -100,7 +100,7 @@ def get_bits_length(config, quant_type): class QAT_Quantizer(Quantizer): - """Quantizer using the DoReFa scheme, as defined in: + """Quantizer defined in: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf """ @@ -227,10 +227,6 @@ class DoReFaQuantizer(Quantizer): (https://arxiv.org/abs/1606.06160) """ def __init__(self, model, config_list): - """ - config_list: supported keys: - - q_bits - """ super().__init__(model, config_list) def quantize_weight(self, weight, config, **kwargs): @@ -256,18 +252,15 @@ def quant_backward(tensor, grad_output, quant_type): class BNNQuantizer(Quantizer): - """BNNQuantizer + """Binarized Neural Networks, as defined in: + Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1 + (https://arxiv.org/abs/1602.02830) """ def __init__(self, model, config_list): - """ - config_list: supported keys: - - q_bits - """ super().__init__(model, config_list) self.quant_grad = ClipGrad def quantize_weight(self, weight, config, **kwargs): - # module.weight.data = torch.clamp(weight, -1.0, 1.0) out = torch.sign(weight) # remove zeros out[out == 0] = 1 diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 874adaab1c..3f0e9da1fc 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -331,7 +331,7 @@ def _instrument_layer(self, layer, config): def new_forward(*inputs): if 'input' in config["quant_types"]: inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer) - + if 'weight' in config["quant_types"] and _check_weight(layer.module): new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer) layer.module.weight = new_weight From 30307ee6fcf6936a209298222d0e4198903c2fe2 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Fri, 13 Dec 2019 15:31:27 +0800 Subject: [PATCH 07/20] fix pylint --- examples/model_compress/BNN_quantizer_cifar10.py | 10 +++++----- src/sdk/pynni/nni/compression/torch/compressor.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py index 1f3ef1f605..92d7f1cc72 100644 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -1,4 +1,3 @@ -import math import torch import torch.nn as nn import torch.nn.functional as F @@ -9,7 +8,7 @@ class VGG_Cifar10(nn.Module): def __init__(self, num_classes=1000): super(VGG_Cifar10, self).__init__() - self.infl_ratio=1 # may be set to large number ? + self.infl_ratio = 1 # may be set to large number ? self.features = nn.Sequential( nn.Conv2d(3, 128*self.infl_ratio, kernel_size=3, stride=1, padding=1, bias=False), @@ -74,7 +73,7 @@ def train(model, device, train_loader, optimizer): optimizer.step() for name, param in model.named_parameters(): if name.endswith('old_weight'): - param = param.clamp(-1,1) + param = param.clamp(-1, 1) if batch_idx % 100 == 0: print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) @@ -117,7 +116,7 @@ def main(): model = VGG_Cifar10(num_classes=10) model.to(device) - + # print(model) configure_list = [{ @@ -125,7 +124,8 @@ def main(): 'quant_bits': 1, 'op_types': ['Conv2d', 'Linear'], 'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3'] - },{'quant_types': ['output'], + }, { + 'quant_types': ['output'], 'quant_bits': 1, 'op_types': ['Hardtanh'], 'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5'] diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 3f0e9da1fc..874adaab1c 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -331,7 +331,7 @@ def _instrument_layer(self, layer, config): def new_forward(*inputs): if 'input' in config["quant_types"]: inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer) - + if 'weight' in config["quant_types"] and _check_weight(layer.module): new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer) layer.module.weight = new_weight From 0a3954b01e7964e8f4deb3fd3eed3384fa9969ad Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Mon, 16 Dec 2019 10:47:15 +0800 Subject: [PATCH 08/20] update --- .../model_compress/BNN_quantizer_cifar10.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py index 92d7f1cc72..79c1357227 100644 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -8,35 +8,34 @@ class VGG_Cifar10(nn.Module): def __init__(self, num_classes=1000): super(VGG_Cifar10, self).__init__() - self.infl_ratio = 1 # may be set to large number ? self.features = nn.Sequential( - nn.Conv2d(3, 128*self.infl_ratio, kernel_size=3, stride=1, padding=1, + nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(128*self.infl_ratio), + nn.BatchNorm2d(128), nn.Hardtanh(inplace=True), - nn.Conv2d(128*self.infl_ratio, 128*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), - nn.BatchNorm2d(128*self.infl_ratio), + nn.BatchNorm2d(128), nn.Hardtanh(inplace=True), - nn.Conv2d(128*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(256*self.infl_ratio), + nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(256), nn.Hardtanh(inplace=True), - nn.Conv2d(256*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=False), + nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), - nn.BatchNorm2d(256*self.infl_ratio), + nn.BatchNorm2d(256), nn.Hardtanh(inplace=True), - nn.Conv2d(256*self.infl_ratio, 512*self.infl_ratio, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(512*self.infl_ratio), + nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(512), nn.Hardtanh(inplace=True), - nn.Conv2d(512*self.infl_ratio, 512, kernel_size=3, padding=1, bias=False), + nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), nn.BatchNorm2d(512), nn.Hardtanh(inplace=True) @@ -117,8 +116,6 @@ def main(): model = VGG_Cifar10(num_classes=10) model.to(device) - # print(model) - configure_list = [{ 'quant_types': ['weight'], 'quant_bits': 1, @@ -133,9 +130,6 @@ def main(): quantizer = BNNQuantizer(model, configure_list) model = quantizer.compress() - print(model) - # for name, p in model.named_parameters(): - # print(name) print('=' * 10 + 'train' + '=' * 10) optimizer_finetune = torch.optim.Adam(model.parameters(), lr=1e-3) From ecf8dee01be8776b4ce5a38145c338e0ab6158fe Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Mon, 16 Dec 2019 11:06:21 +0800 Subject: [PATCH 09/20] revert --- .../compression/torch/builtin_quantizers.py | 45 ++----------------- 1 file changed, 4 insertions(+), 41 deletions(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index a20a4789e0..7f9c3b144a 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -3,9 +3,9 @@ import logging import torch -from .compressor import Quantizer, QuantGrad, QuantType +from .compressor import Quantizer -__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer'] +__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer'] logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ def get_bits_length(config, quant_type): class QAT_Quantizer(Quantizer): - """Quantizer using the DoReFa scheme, as defined in: + """Quantizer defined in: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf """ @@ -227,10 +227,6 @@ class DoReFaQuantizer(Quantizer): (https://arxiv.org/abs/1606.06160) """ def __init__(self, model, config_list): - """ - config_list: supported keys: - - q_bits - """ super().__init__(model, config_list) def quantize_weight(self, weight, config, **kwargs): @@ -244,37 +240,4 @@ def quantize_weight(self, weight, config, **kwargs): def quantize(self, input_ri, q_bits): scale = pow(2, q_bits)-1 output = torch.round(input_ri*scale)/scale - return output - - -class ClipGrad(QuantGrad): - @staticmethod - def quant_backward(tensor, grad_output, quant_type): - if quant_type == QuantType.QUANT_OUTPUT: - grad_output[torch.abs(tensor) > 1] = 0 - return grad_output - - -class BNNQuantizer(Quantizer): - """BNNQuantizer - """ - def __init__(self, model, config_list): - """ - config_list: supported keys: - - q_bits - """ - super().__init__(model, config_list) - self.quant_grad = ClipGrad - - def quantize_weight(self, weight, config, **kwargs): - # module.weight.data = torch.clamp(weight, -1.0, 1.0) - out = torch.sign(weight) - # remove zeros - out[out == 0] = 1 - return out - - def quantize_output(self, output, config, **kwargs): - out = torch.sign(output) - # remove zeros - out[out == 0] = 1 - return out \ No newline at end of file + return output \ No newline at end of file From 884a4730fec266fa9701051c8703ed9f9438a146 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Thu, 19 Dec 2019 13:32:18 +0800 Subject: [PATCH 10/20] remove update_epoch --- examples/model_compress/BNN_quantizer_cifar10.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py index 79c1357227..fc14303ba5 100644 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -135,7 +135,6 @@ def main(): optimizer_finetune = torch.optim.Adam(model.parameters(), lr=1e-3) best_top1 = 0 for epoch in range(120): - quantizer.update_epoch(epoch) print('# Epoch {} #'.format(epoch)) train(model, device, train_loader, optimizer_finetune) top1 = test(model, device, test_loader) From f5529f25d2376937fc6d8765409813aac0de14a6 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Sun, 22 Dec 2019 22:34:57 +0800 Subject: [PATCH 11/20] update doc --- docs/en_US/Compressor/Quantizer.md | 71 ++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index 5dd99e3432..42c7262bea 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -1,6 +1,5 @@ Quantizer on NNI Compressor === - ## Naive Quantizer We provide Naive Quantizer to quantizer weight to default 8 bits, you can use it to test quantize algorithm without any configure. @@ -53,10 +52,18 @@ You can view example for more information #### User configuration for QAT Quantizer * **quant_types:** : list of string -type of quantization you want to apply, currently support 'weight', 'input', 'output' +type of quantization you want to apply, currently support 'weight', 'input', 'output'. + +* **op_types:** list of string +specify the type of modules that will be quantized. eg. 'Conv2D' + +* **op_names:** list of string +specify the name of modules that will be quantized. eg. 'conv1' + * **quant_bits:** int or dict of {str : int} -bits length of quantization, key is the quantization type, value is the length, eg. {'weight', 8}, -when the type is int, all quantization types share same bits length +bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, +when the type is int, all quantization types share same bits length. + * **quant_start_step:** int disable quantization until model are run by certain number of steps, this allows the network to enter a more stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 @@ -90,3 +97,59 @@ You can view example for more information #### User configuration for DoReFa Quantizer * **q_bits:** This is to specify the q_bits operations to be quantized to + + +## BNN Quantizer +In [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830), + +>We introduce a method to train Binarized Neural Networks (BNNs) - neural networks with binary weights and activations at run-time. At training-time the binary weights and activations are used for computing the parameters gradients. During the forward pass, BNNs drastically reduce memory size and accesses, and replace most arithmetic operations with bit-wise operations, which is expected to substantially improve power-efficiency. + + +### Usage + +PyTorch code +```python +from nni.compression.torch import BNNQuantizer +model = VGG_Cifar10(num_classes=10) + +configure_list = [{ + 'quant_types': ['weight'], + 'quant_bits': 1, + 'op_types': ['Conv2d', 'Linear'], + 'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3'] +}, { + 'quant_types': ['output'], + 'quant_bits': 1, + 'op_types': ['Hardtanh'], + 'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5'] +}] + +quantizer = BNNQuantizer(model, configure_list) +model = quantizer.compress() +``` + +You can view example [examples/model_compress/BNN_quantizer_cifar10.py]( https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) for more information. + +#### User configuration for BNN Quantizer +* **quant_types:** : list of string +type of quantization you want to apply, currently support 'weight', 'input', 'output'. + +* **op_types:** list of string +specify the type of modules that will be quantized. eg. 'Conv2D' + +* **op_names:** list of string +specify the name of modules that will be quantized. eg. 'conv1' + +* **quant_bits:** int or dict of {str : int} +bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, +when the type is int, all quantization types share same bits length. + +### Experiment +We implemented one of the experiments in [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830), we quantized the **VGGNet** for CIFAR-10 in the paper. Our experiments results are as follows: + +| Model | Accuracy | +| ------------- | --------- | +| VGGNet | 85.1% | + + +The experiments code can be found at [examples/model_compress/BNN_quantizer_cifar10.py]( https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) \ No newline at end of file From 768430b451e23a4aca531d1fbf768848847c7725 Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Sun, 22 Dec 2019 22:37:35 +0800 Subject: [PATCH 12/20] update doc --- docs/en_US/Compressor/Quantizer.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index 42c7262bea..00a036981f 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -52,19 +52,24 @@ You can view example for more information #### User configuration for QAT Quantizer * **quant_types:** : list of string + type of quantization you want to apply, currently support 'weight', 'input', 'output'. * **op_types:** list of string + specify the type of modules that will be quantized. eg. 'Conv2D' * **op_names:** list of string + specify the name of modules that will be quantized. eg. 'conv1' * **quant_bits:** int or dict of {str : int} + bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, when the type is int, all quantization types share same bits length. * **quant_start_step:** int + disable quantization until model are run by certain number of steps, this allows the network to enter a more stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 @@ -132,15 +137,19 @@ You can view example [examples/model_compress/BNN_quantizer_cifar10.py]( https:/ #### User configuration for BNN Quantizer * **quant_types:** : list of string + type of quantization you want to apply, currently support 'weight', 'input', 'output'. * **op_types:** list of string + specify the type of modules that will be quantized. eg. 'Conv2D' * **op_names:** list of string + specify the name of modules that will be quantized. eg. 'conv1' * **quant_bits:** int or dict of {str : int} + bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, when the type is int, all quantization types share same bits length. From 940378399bef1edff18eec6fb06515a425b5c59a Mon Sep 17 00:00:00 2001 From: Cjkkkk Date: Sun, 22 Dec 2019 22:43:03 +0800 Subject: [PATCH 13/20] update doc --- docs/en_US/Compressor/Quantizer.md | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index 00a036981f..0037be8f8d 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -83,17 +83,14 @@ In [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bit ### Usage To implement DoReFa Quantizer, you can add code below before your training code -Tensorflow code -```python -from nni.compressors.tensorflow import DoReFaQuantizer -config_list = [{ 'q_bits': 8, 'op_types': 'default' }] -quantizer = DoReFaQuantizer(tf.get_default_graph(), config_list) -quantizer.compress() -``` PyTorch code ```python from nni.compressors.torch import DoReFaQuantizer -config_list = [{ 'q_bits': 8, 'op_types': 'default' }] +config_list = [{ + 'quant_types': ['weight'], + 'quant_bits': 8, + 'op_types': 'default' +}] quantizer = DoReFaQuantizer(model, config_list) quantizer.compress() ``` @@ -101,7 +98,22 @@ quantizer.compress() You can view example for more information #### User configuration for DoReFa Quantizer -* **q_bits:** This is to specify the q_bits operations to be quantized to +* **quant_types:** : list of string + +type of quantization you want to apply, currently support 'weight', 'input', 'output'. + +* **op_types:** list of string + +specify the type of modules that will be quantized. eg. 'Conv2D' + +* **op_names:** list of string + +specify the name of modules that will be quantized. eg. 'conv1' + +* **quant_bits:** int or dict of {str : int} + +bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, +when the type is int, all quantization types share same bits length. ## BNN Quantizer From afa84b5a01399c29fda32d3fc92d65e8d469368b Mon Sep 17 00:00:00 2001 From: xuehui Date: Mon, 23 Dec 2019 14:36:47 +0800 Subject: [PATCH 14/20] update example --- .../model_compress/BNN_quantizer_cifar10.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py index fc14303ba5..56a5590741 100644 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -9,35 +9,34 @@ class VGG_Cifar10(nn.Module): def __init__(self, num_classes=1000): super(VGG_Cifar10, self).__init__() self.features = nn.Sequential( - nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, - bias=False), - nn.BatchNorm2d(128), + nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), - nn.BatchNorm2d(128), + nn.BatchNorm2d(128, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(256), + nn.BatchNorm2d(256, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), - nn.BatchNorm2d(256), + nn.BatchNorm2d(256, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(512), + nn.BatchNorm2d(512, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False), nn.MaxPool2d(kernel_size=2, stride=2), - nn.BatchNorm2d(512), + nn.BatchNorm2d(512, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True) ) @@ -49,8 +48,7 @@ def __init__(self, num_classes=1000): nn.BatchNorm1d(1024), nn.Hardtanh(inplace=True), nn.Linear(1024, num_classes), # do not quantize output - nn.BatchNorm1d(num_classes, affine=False), - nn.LogSoftmax() + nn.BatchNorm1d(num_classes, affine=False) ) @@ -95,6 +93,12 @@ def test(model, device, test_loader): test_loss, acc)) return acc +def adjust_learning_rate(optimizer, epoch): + update_list = [55, 100, 150,200,400,600] + if epoch in update_list: + for param_group in optimizer.param_groups: + param_group['lr'] = param_group['lr'] * 0.1 + return def main(): torch.manual_seed(0) @@ -120,7 +124,7 @@ def main(): 'quant_types': ['weight'], 'quant_bits': 1, 'op_types': ['Conv2d', 'Linear'], - 'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3'] + 'op_names': ['features.3', 'features.7', 'features.10', 'features.14', 'classifier.0', 'classifier.3'] }, { 'quant_types': ['output'], 'quant_bits': 1, @@ -132,11 +136,12 @@ def main(): model = quantizer.compress() print('=' * 10 + 'train' + '=' * 10) - optimizer_finetune = torch.optim.Adam(model.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) best_top1 = 0 - for epoch in range(120): + for epoch in range(400): print('# Epoch {} #'.format(epoch)) - train(model, device, train_loader, optimizer_finetune) + train(model, device, train_loader, optimizer) + adjust_learning_rate(optimizer, epoch) top1 = test(model, device, test_loader) if top1 > best_top1: best_top1 = top1 From 4aa325a37b9dbacab7a6fadcb369ccec8d4b8b4e Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Mon, 23 Dec 2019 16:11:53 +0800 Subject: [PATCH 15/20] fix punc & update comments --- .../pynni/nni/compression/torch/compressor.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index 874adaab1c..e7965d837b 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -266,7 +266,7 @@ def quantize_weight(self, weight, config, op, op_type, op_name): config : dict the configuration for weight quantization """ - raise NotImplementedError("Quantizer must overload quantize_weight()") + raise NotImplementedError('Quantizer must overload quantize_weight()') def quantize_output(self, output, config, op, op_type, op_name): """ @@ -280,7 +280,7 @@ def quantize_output(self, output, config, op, op_type, op_name): config : dict the configuration for output quantization """ - raise NotImplementedError("Quantizer must overload quantize_output()") + raise NotImplementedError('Quantizer must overload quantize_output()') def quantize_input(self, *inputs, config, op, op_type, op_name): """ @@ -294,7 +294,7 @@ def quantize_input(self, *inputs, config, op, op_type, op_name): config : dict the configuration for inputs quantization """ - raise NotImplementedError("Quantizer must overload quantize_input()") + raise NotImplementedError('Quantizer must overload quantize_input()') def _instrument_layer(self, layer, config): @@ -309,37 +309,40 @@ def _instrument_layer(self, layer, config): the configuration for quantization """ assert layer._forward is None, 'Each model can only be compressed once' - assert "quant_types" in config, 'must provide quant_types in config' - assert isinstance(config["quant_types"], list), 'quant_types must be list type' - assert "quant_bits" in config, 'must provide quant_bits in config' - assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type' + assert 'quant_types' in config, 'must provide quant_types in config' + assert isinstance(config['quant_types'], list), 'quant_types must be list type' + assert 'quant_bits' in config, 'must provide quant_bits in config' + assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type' - if isinstance(config["quant_bits"], dict): - for quant_type in config["quant_types"]: - assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type + if isinstance(config['quant_bits'], dict): + for quant_type in config['quant_types']: + assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type - if 'weight' in config["quant_types"]: + if 'weight' in config['quant_types']: if not _check_weight(layer.module): _logger.warning('Module %s does not have parameter "weight"', layer.name) else: - layer.module.register_parameter("old_weight", torch.nn.Parameter(layer.module.weight)) - delattr(layer.module, "weight") + # old_weight is used to store origin weight and weight is used to store quantized weight + # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf + # if weight is leaf , then old_weight can not be updated. + layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight)) + delattr(layer.module, 'weight') layer.module.register_buffer('weight', layer.module.old_weight) layer._forward = layer.module.forward def new_forward(*inputs): - if 'input' in config["quant_types"]: + if 'input' in config['quant_types']: inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer) - if 'weight' in config["quant_types"] and _check_weight(layer.module): + if 'weight' in config['quant_types'] and _check_weight(layer.module): new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer) layer.module.weight = new_weight result = layer._forward(*inputs) else: result = layer._forward(*inputs) - if 'output' in config["quant_types"]: + if 'output' in config['quant_types']: result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer) return result From 49b87ab90d955d173bcbb4ad2956b4294e585c51 Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Mon, 23 Dec 2019 19:31:03 +0800 Subject: [PATCH 16/20] update experiment results --- docs/en_US/Compressor/Quantizer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index 0037be8f8d..bef163cb12 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -170,7 +170,7 @@ We implemented one of the experiments in [Binarized Neural Networks: Training De | Model | Accuracy | | ------------- | --------- | -| VGGNet | 85.1% | +| VGGNet | 86.8% | The experiments code can be found at [examples/model_compress/BNN_quantizer_cifar10.py]( https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) \ No newline at end of file From 841f0835bf51c4315665108f8f838de39483ae5a Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Mon, 23 Dec 2019 23:12:58 +0800 Subject: [PATCH 17/20] update experiment results --- docs/en_US/Compressor/Quantizer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index bef163cb12..87c8e88168 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -170,7 +170,7 @@ We implemented one of the experiments in [Binarized Neural Networks: Training De | Model | Accuracy | | ------------- | --------- | -| VGGNet | 86.8% | +| VGGNet | 86.93% | The experiments code can be found at [examples/model_compress/BNN_quantizer_cifar10.py]( https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) \ No newline at end of file From 38613a2cb45c486f2c01afcbca2c976519cf13ee Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Tue, 24 Dec 2019 11:18:02 +0800 Subject: [PATCH 18/20] update doc --- docs/en_US/Compressor/Quantizer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md index 87c8e88168..67791117e1 100644 --- a/docs/en_US/Compressor/Quantizer.md +++ b/docs/en_US/Compressor/Quantizer.md @@ -170,7 +170,7 @@ We implemented one of the experiments in [Binarized Neural Networks: Training De | Model | Accuracy | | ------------- | --------- | -| VGGNet | 86.93% | +| VGGNet | 86.93% | The experiments code can be found at [examples/model_compress/BNN_quantizer_cifar10.py]( https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) \ No newline at end of file From 471b910ee88cc791ebc86690bb1a0ad5544869d6 Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Tue, 24 Dec 2019 11:38:56 +0800 Subject: [PATCH 19/20] fix test --- src/sdk/pynni/nni/compression/torch/builtin_quantizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py index 5c631cc1c1..2204428574 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_quantizers.py @@ -3,7 +3,7 @@ import logging import torch -from .compressor import Quantizer +from .compressor import Quantizer, QuantGrad, QuantType __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer'] From 54eb9111fd83096f1a7174ddbc0ba939801ea6d6 Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Tue, 24 Dec 2019 11:44:12 +0800 Subject: [PATCH 20/20] add MIT License --- examples/model_compress/BNN_quantizer_cifar10.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/model_compress/BNN_quantizer_cifar10.py b/examples/model_compress/BNN_quantizer_cifar10.py index 56a5590741..d4908885c3 100644 --- a/examples/model_compress/BNN_quantizer_cifar10.py +++ b/examples/model_compress/BNN_quantizer_cifar10.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import torch import torch.nn as nn import torch.nn.functional as F @@ -94,7 +97,7 @@ def test(model, device, test_loader): return acc def adjust_learning_rate(optimizer, epoch): - update_list = [55, 100, 150,200,400,600] + update_list = [55, 100, 150, 200, 400, 600] if epoch in update_list: for param_group in optimizer.param_groups: param_group['lr'] = param_group['lr'] * 0.1