Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

AutoML for model compression #2573

Merged
merged 125 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from 112 commits
Commits
Show all changes
125 commits
Select commit Hold shift + click to select a range
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
4b266f3
Merge pull request #47 from microsoft/master
chicm-ms Nov 15, 2019
237ff4b
Merge pull request #48 from microsoft/master
chicm-ms Nov 21, 2019
682be01
Merge pull request #49 from microsoft/master
chicm-ms Nov 25, 2019
133af82
Merge pull request #50 from microsoft/master
chicm-ms Nov 25, 2019
71a8a25
Merge pull request #51 from microsoft/master
chicm-ms Nov 26, 2019
d2a73bc
Merge pull request #52 from microsoft/master
chicm-ms Nov 26, 2019
198cf5e
Merge pull request #53 from microsoft/master
chicm-ms Dec 5, 2019
cdbfaf9
Merge pull request #54 from microsoft/master
chicm-ms Dec 6, 2019
7e9b29e
Merge pull request #55 from microsoft/master
chicm-ms Dec 10, 2019
d00c46d
Merge pull request #56 from microsoft/master
chicm-ms Dec 10, 2019
de7d1fa
Merge pull request #57 from microsoft/master
chicm-ms Dec 11, 2019
1835ab0
Merge pull request #58 from microsoft/master
chicm-ms Dec 12, 2019
24fead6
Merge pull request #59 from microsoft/master
chicm-ms Dec 20, 2019
0b7321e
Merge pull request #60 from microsoft/master
chicm-ms Dec 23, 2019
60058d4
Merge pull request #61 from microsoft/master
chicm-ms Dec 23, 2019
b111a55
Merge pull request #62 from microsoft/master
chicm-ms Dec 24, 2019
611c337
Merge pull request #63 from microsoft/master
chicm-ms Dec 30, 2019
4a1f14a
Merge pull request #64 from microsoft/master
chicm-ms Jan 10, 2020
7a9e604
Merge pull request #65 from microsoft/master
chicm-ms Jan 14, 2020
b8035b0
Merge pull request #66 from microsoft/master
chicm-ms Feb 4, 2020
47567d3
Merge pull request #67 from microsoft/master
chicm-ms Feb 10, 2020
614d427
Merge pull request #68 from microsoft/master
chicm-ms Feb 10, 2020
a0d9ed6
Merge pull request #69 from microsoft/master
chicm-ms Feb 11, 2020
22dc1ad
Merge pull request #70 from microsoft/master
chicm-ms Feb 19, 2020
0856813
Merge pull request #71 from microsoft/master
chicm-ms Feb 22, 2020
9e97bed
Merge pull request #72 from microsoft/master
chicm-ms Feb 25, 2020
16a1b27
Merge pull request #73 from microsoft/master
chicm-ms Mar 3, 2020
e246633
Merge pull request #74 from microsoft/master
chicm-ms Mar 4, 2020
0439bc1
Merge pull request #75 from microsoft/master
chicm-ms Mar 17, 2020
8b5613a
Merge pull request #76 from microsoft/master
chicm-ms Mar 18, 2020
43e8d31
Merge pull request #77 from microsoft/master
chicm-ms Mar 22, 2020
aae448e
Merge pull request #78 from microsoft/master
chicm-ms Mar 25, 2020
7095716
Merge pull request #79 from microsoft/master
chicm-ms Mar 25, 2020
c51263a
Merge pull request #80 from microsoft/master
chicm-ms Apr 11, 2020
9953c70
Merge pull request #81 from microsoft/master
chicm-ms Apr 14, 2020
f9136c4
Merge pull request #82 from microsoft/master
chicm-ms Apr 16, 2020
b384ad2
Merge pull request #83 from microsoft/master
chicm-ms Apr 20, 2020
ff592dd
Merge pull request #84 from microsoft/master
chicm-ms May 12, 2020
0b5378f
Merge pull request #85 from microsoft/master
chicm-ms May 18, 2020
a53e0b0
Merge pull request #86 from microsoft/master
chicm-ms May 25, 2020
3ea0b89
Merge pull request #87 from microsoft/master
chicm-ms May 28, 2020
cf3fb20
Merge pull request #88 from microsoft/master
chicm-ms May 28, 2020
7f4cdcd
Merge pull request #89 from microsoft/master
chicm-ms Jun 4, 2020
574db2c
Merge pull request #90 from microsoft/master
chicm-ms Jun 15, 2020
e2373ce
Original amc
chicm-ms Jun 16, 2020
e93dc9b
updates
chicm-ms Jun 16, 2020
eae9df6
first commit
chicm-ms Jun 17, 2020
e9e57b6
updates
chicm-ms Jun 18, 2020
295d6b7
move val_func out to user code
chicm-ms Jun 18, 2020
c1a4629
updates
chicm-ms Jun 18, 2020
32bedcc
Merge pull request #91 from microsoft/master
chicm-ms Jun 21, 2020
6155aa4
Merge pull request #92 from microsoft/master
chicm-ms Jun 22, 2020
8139c9c
Merge pull request #93 from microsoft/master
chicm-ms Jun 23, 2020
43419d7
Merge pull request #94 from microsoft/master
chicm-ms Jun 28, 2020
6b6ee55
Merge pull request #95 from microsoft/master
chicm-ms Jun 28, 2020
1b975e0
Merge pull request #96 from microsoft/master
chicm-ms Jun 28, 2020
c8f3c5d
Merge pull request #97 from microsoft/master
chicm-ms Jun 29, 2020
4c306f0
Merge pull request #98 from microsoft/master
chicm-ms Jun 30, 2020
64de4c2
Merge pull request #99 from microsoft/master
chicm-ms Jun 30, 2020
0e5d3ac
Merge pull request #100 from microsoft/master
chicm-ms Jul 1, 2020
4a52608
Merge pull request #101 from microsoft/master
chicm-ms Jul 3, 2020
208b1ee
Merge pull request #102 from microsoft/master
chicm-ms Jul 8, 2020
e7b1a2e
Merge pull request #103 from microsoft/master
chicm-ms Jul 10, 2020
57bcc85
Merge pull request #104 from microsoft/master
chicm-ms Jul 22, 2020
fbd7805
use nni pruner
chicm-ms Jul 23, 2020
1632e22
change input prune to output prune
chicm-ms Jul 23, 2020
15a0709
prune output channel
chicm-ms Jul 23, 2020
1a9ee26
updates
chicm-ms Jul 27, 2020
f5aae3d
updates
chicm-ms Jul 27, 2020
b63cba0
updates
chicm-ms Jul 27, 2020
39957cb
updates
chicm-ms Jul 27, 2020
4219df5
updates
chicm-ms Jul 27, 2020
01d252f
updates
chicm-ms Jul 27, 2020
4833229
updates
chicm-ms Jul 27, 2020
c626131
AMC weight masker implementation
Jul 29, 2020
f112178
Support linear reconstruct
chicm-ms Jul 29, 2020
030f5ef
Merge pull request #105 from microsoft/master
chicm-ms Jul 29, 2020
00010c3
Merge branch 'master' into amc
chicm-ms Jul 29, 2020
4e9fa3d
support export and mobilenetv2
chicm-ms Aug 1, 2020
773b2e5
refactor export
chicm-ms Aug 1, 2020
1bd16a7
refactor parameters
chicm-ms Aug 1, 2020
0e06bb0
updates
chicm-ms Aug 1, 2020
7c2915d
support cpu
chicm-ms Aug 2, 2020
0d73011
unit test
chicm-ms Aug 2, 2020
d04598c
fix pylint
chicm-ms Aug 2, 2020
058c8b7
Merge pull request #106 from microsoft/master
chicm-ms Aug 2, 2020
34b3dee
Merge branch 'master' into amc
chicm-ms Aug 2, 2020
4a49101
install tensorboardX for UT
chicm-ms Aug 2, 2020
0203c9a
updates
chicm-ms Aug 2, 2020
7afdad0
fix export layer meta-data
chicm-ms Aug 4, 2020
b17234a
refactor finetune
chicm-ms Aug 4, 2020
0814585
Fix pylint error
chicm-ms Aug 4, 2020
eb4999a
fix pylint errors
chicm-ms Aug 4, 2020
16f8ab8
documentation
chicm-ms Aug 5, 2020
d7dec5e
refactor export
chicm-ms Aug 5, 2020
b19493b
fix doc build error
chicm-ms Aug 6, 2020
9abd8c8
Merge pull request #107 from microsoft/master
chicm-ms Aug 10, 2020
eabe041
updates per comments
chicm-ms Aug 10, 2020
06e9adc
updates
chicm-ms Aug 10, 2020
26e2cd2
updates
chicm-ms Aug 10, 2020
e3c09ae
updates
chicm-ms Aug 11, 2020
12b3dbf
updates
chicm-ms Aug 11, 2020
bd34a8b
updates
chicm-ms Aug 11, 2020
d486f82
updates
chicm-ms Aug 11, 2020
2d05b6b
fix docstring format
chicm-ms Aug 11, 2020
e1500ca
updates
chicm-ms Aug 11, 2020
13c6623
Merge pull request #108 from microsoft/master
chicm-ms Aug 11, 2020
33ae5d6
Merge branch 'master' into amc
chicm-ms Aug 11, 2020
159e3b7
updates
chicm-ms Aug 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
set -e
sudo apt-get install -y pandoc
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee thop --user
Expand Down Expand Up @@ -68,6 +69,7 @@ jobs:
- script: |
set -e
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx peewee --user
Expand Down Expand Up @@ -117,6 +119,7 @@ jobs:
set -e
# pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3
rm -f /usr/local/bin/swig
Expand Down Expand Up @@ -144,6 +147,7 @@ jobs:
python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorboardX==1.9
python -m pip install tensorflow==1.15.2 --user
displayName: 'Install dependencies'
- script: |
Expand Down
34 changes: 34 additions & 0 deletions docs/en_US/Compressor/Pruner.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [NetAdapt Pruner](#netadapt-pruner)
* [SimulatedAnnealing Pruner](#simulatedannealing-pruner)
* [AutoCompress Pruner](#autocompress-pruner)
* [AutoML for Model Compression Pruner](#automl-for-model-compression-pruner)

**Others**
* [ADMM Pruner](#admm-pruner)
Expand Down Expand Up @@ -497,6 +498,39 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AutoCompressPruner
```

## AutoML for Model Compression Pruner

AutoML for Model Compression Pruner (AMCPruner) leverages reinforcement learning to provide the model compression policy.
This learning-based compression policy outperforms conventional rule-based compression policy by having higher compression ratio,
better preserving the accuracy and freeing human labor.

![](../../img/amc_pruner.jpg)

For more details, please refer to [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf).


#### Usage

PyTorch code

```python
from nni.compression.torch import AMCPruner
config_list = [{
'op_types': ['Conv2d', 'Linear']
}]
pruner = AMCPruner(model, config_list, evaluator, val_loader, sparsity=0.5)
pruner.compress()
```

You can view [example](https://github.com/microsoft/nni/blob/master/examples/model_compress/amc/) for more information.

#### User configuration for AutoCompress Pruner

##### PyTorch

```eval_rst
.. autoclass:: nni.compression.torch.AMCPruner
```

## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
Expand Down
Binary file added docs/img/amc_pruner.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
134 changes: 134 additions & 0 deletions examples/model_compress/amc/amc_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import time

import torch
import torch.nn as nn

from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy

device = None
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved

def parse_args():
parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='model to prune')
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset to use (cifar/imagenet)')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
parser.add_argument('--sparsity', default=0.5, type=float, help='sparsity of the model')
parser.add_argument('--lbound', default=0., type=float, help='minimum sparsity')
parser.add_argument('--rbound', default=0.8, type=float, help='maximum sparsity')
parser.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint')

parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--data_bsize', default=50, type=int, help='number of data batch size')
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument('--export', action='store_true', help='search best pruning policy or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--export_source_path', default=None, type=str, help='path for searched best wrapped model')

return parser.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet':
from mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif model == 'mobilenetv2' and dataset == 'imagenet':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'mobilenet' and dataset == 'cifar10':
from mobilenet import MobileNet
net = MobileNet(n_class=10)
elif model == 'mobilenetv2' and dataset == 'cifar10':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10)
else:
raise NotImplementedError
if checkpoint_path:
print('loading {}...'.format(checkpoint_path))
sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
if 'state_dict' in sd: # a checkpoint but not a state_dict
sd = sd['state_dict']
sd = {k.replace('module.', ''): v for k, v in sd.items()}
net.load_state_dict(sd)

if torch.cuda.is_available() and n_gpu > 0:
net = net.cuda()
if n_gpu > 1:
net = torch.nn.DataParallel(net, range(n_gpu))

return net

def init_data(args):
# split the train set into train + val
# for CIFAR, split 5k for val
# for ImageNet, split 3k for val
val_size = 5000 if 'cifar' in args.dataset else 3000
train_loader, val_loader, _ = get_split_dataset(
args.dataset, args.data_bsize,
args.n_worker, val_size,
data_root=args.data_root,
shuffle=False
) # same sampling
return train_loader, val_loader

def validate(val_loader, model, verbose=False):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

criterion = nn.CrossEntropyLoss().cuda()
# switch to evaluate mode
model.eval()
end = time.time()

t1 = time.time()
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
target = target.to(device)
input_var = torch.autograd.Variable(input).to(device)
target_var = torch.autograd.Variable(target).to(device)

# compute output
output = model(input_var)
loss = criterion(output, target_var)

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
t2 = time.time()
if verbose:
print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' %
(losses.avg, top1.avg, top5.avg, t2 - t1))
return top5.avg


if __name__ == "__main__":
args = parse_args()

device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu')

model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu)
_, val_loader = init_data(args)

config_list = [{
'op_types': ['Conv2d', 'Linear']
}]
pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type,
train_episode=args.train_episode, export=args.export, export_path=args.export_path,
export_source_path=args.export_source_path,
sparsity=args.sparsity, lbound=args.lbound, rbound=args.rbound)
pruner.compress()
Loading