This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Filter prune algo implementation #1655

Merged (47 commits) on Nov 15, 2019

Commits (47):
- 3a45961 Merge pull request #31 from microsoft/master (chicm-ms, Aug 6, 2019)
- 633db43 Merge pull request #32 from microsoft/master (chicm-ms, Sep 9, 2019)
- 3e926f1 Merge pull request #33 from microsoft/master (chicm-ms, Oct 8, 2019)
- f173789 Merge pull request #34 from microsoft/master (chicm-ms, Oct 9, 2019)
- 508850a Merge pull request #35 from microsoft/master (chicm-ms, Oct 9, 2019)
- 5a0e9c9 Merge pull request #36 from microsoft/master (chicm-ms, Oct 10, 2019)
- e7df061 Merge pull request #37 from microsoft/master (chicm-ms, Oct 23, 2019)
- e47c923 fpgm pruner pytorch implementation (chicm-ms, Oct 23, 2019)
- c51f688 updates (chicm-ms, Oct 25, 2019)
- b1165da updates (chicm-ms, Oct 28, 2019)
- cd32a6a updates (chicm-ms, Oct 28, 2019)
- 8717026 updates (chicm-ms, Oct 29, 2019)
- 2175cef Merge pull request #38 from microsoft/master (chicm-ms, Oct 29, 2019)
- 2ccbfbb Merge pull request #39 from microsoft/master (chicm-ms, Oct 30, 2019)
- b29cb0b Merge pull request #40 from microsoft/master (chicm-ms, Oct 30, 2019)
- e25d9be Merge branch 'master' into filter_prune (chicm-ms, Oct 30, 2019)
- 216a9a7 updates (chicm-ms, Oct 31, 2019)
- a42a067 updates (chicm-ms, Oct 31, 2019)
- 8fd58bb updates (chicm-ms, Oct 31, 2019)
- 4a3ba83 Merge pull request #41 from microsoft/master (chicm-ms, Nov 4, 2019)
- c8a1148 Merge pull request #42 from microsoft/master (chicm-ms, Nov 4, 2019)
- 73c6101 Merge pull request #43 from microsoft/master (chicm-ms, Nov 5, 2019)
- fef6ec2 Merge branch 'master' into filter_prune (chicm-ms, Nov 7, 2019)
- ec2b3fb updates per refactored framework (chicm-ms, Nov 7, 2019)
- 3040b6e updates (chicm-ms, Nov 7, 2019)
- 0ca60cf updates (chicm-ms, Nov 7, 2019)
- cd069fd updates (chicm-ms, Nov 7, 2019)
- a4a999b update documents (chicm-ms, Nov 7, 2019)
- bd622a2 updates (chicm-ms, Nov 7, 2019)
- 8a939b4 updates (chicm-ms, Nov 7, 2019)
- 6a518a9 Merge pull request #44 from microsoft/master (chicm-ms, Nov 11, 2019)
- a0d587f Merge pull request #45 from microsoft/master (chicm-ms, Nov 12, 2019)
- 302f1bd tensorflow 2.0 implementation (chicm-ms, Nov 13, 2019)
- a3e4b90 updates (chicm-ms, Nov 13, 2019)
- 20aedfc updates (chicm-ms, Nov 13, 2019)
- 676348b updates (chicm-ms, Nov 13, 2019)
- 2b22a1a updates (chicm-ms, Nov 13, 2019)
- e905bfe Merge pull request #46 from microsoft/master (chicm-ms, Nov 14, 2019)
- 43bf2b7 Merge branch 'master' into filter_prune (chicm-ms, Nov 14, 2019)
- 9e68e2b updates (chicm-ms, Nov 14, 2019)
- eadc941 updates (chicm-ms, Nov 14, 2019)
- 2978b7c updates (chicm-ms, Nov 14, 2019)
- 03d71da updates (chicm-ms, Nov 14, 2019)
- 22b6475 updates (chicm-ms, Nov 14, 2019)
- 053e3d1 updates (chicm-ms, Nov 14, 2019)
- 08f5237 updates (chicm-ms, Nov 14, 2019)
- ec8bb4e updates (chicm-ms, Nov 14, 2019)
1 change: 1 addition & 0 deletions docs/en_US/Compressor/Overview.md
@@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)|
46 changes: 46 additions & 0 deletions docs/en_US/Compressor/Pruner.md
@@ -92,3 +92,49 @@ You can view example for more information

***

## FPGM Pruner
FPGM Pruner is an implementation of the paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf).

>Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance.
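
The selection criterion can be sketched in a few lines of PyTorch. This is only an illustrative approximation of the idea, not the NNI implementation: `fpgm_filter_indices` is a hypothetical helper, and, as in the paper, the filters closest to the geometric median are approximated by the filters with the smallest total distance to all other filters.

```python
import torch

def fpgm_filter_indices(conv_weight, sparsity):
    # conv_weight has shape [OUT, IN, H, W]; flatten each output filter into a vector
    filters = conv_weight.reshape(conv_weight.size(0), -1)
    # pairwise Euclidean distances between all filters
    dist = torch.cdist(filters, filters)
    # filters with the smallest total distance to the others lie closest to the
    # geometric median; FPGM treats them as redundant and selects them for pruning
    num_prune = int(filters.size(0) * sparsity)
    return torch.argsort(dist.sum(dim=1))[:num_prune]
```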

### Usage
First, import the pruner and use it to add masks to the model.

Tensorflow code
```python
from nni.compression.tensorflow import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2D']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```
PyTorch code
```python
from nni.compression.torch import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```
Note: FPGM Pruner is used to prune convolutional layers within deep neural networks, so the `op_types` field supports only convolutional layers.

Second, add the code below to update the epoch number at the beginning of each epoch.

Tensorflow code
```python
pruner.update_epoch(epoch)
```
PyTorch code
```python
pruner.update_epoch(epoch)
```
You can view the example for more information.
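
For orientation, here is a minimal end-to-end PyTorch sketch that puts the pieces above together. `Net` and `train_loader` are placeholders for your own model and data pipeline, not part of the NNI API, and the loss assumes a model that outputs log-probabilities as in the MNIST example.

```python
import torch
import torch.nn.functional as F
from nni.compression.torch import FPGMPruner

model = Net()                                    # placeholder: your own torch.nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = FPGMPruner(model, config_list)
pruner.compress()

for epoch in range(10):
    pruner.update_epoch(epoch)                   # tell the pruner a new epoch started
    for data, target in train_loader:            # placeholder: your own DataLoader
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
```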

#### User configuration for FPGM Pruner
* **sparsity:** The percentage of convolutional filters to be pruned.

***
56 changes: 56 additions & 0 deletions examples/model_compress/fpgm_tf_mnist.py
@@ -0,0 +1,56 @@
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"
import numpy as np
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from nni.compression.tensorflow import FPGMPruner

def get_data():
(X_train_full, y_train_full), _ = keras.datasets.mnist.load_data()
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]

X_mean = X_train.mean(axis=0, keepdims=True)
X_std = X_train.std(axis=0, keepdims=True) + 1e-7
X_train = (X_train - X_mean) / X_std
X_valid = (X_valid - X_mean) / X_std

X_train = X_train[..., np.newaxis]
X_valid = X_valid[..., np.newaxis]

return X_train, X_valid, y_train, y_valid

def get_model():
model = keras.models.Sequential([
Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
MaxPooling2D(pool_size=2),
Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
MaxPooling2D(pool_size=2),
Flatten(),
Dense(units=128, activation='relu'),
Dropout(0.5),
Dense(units=10, activation='softmax'),
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer=keras.optimizers.SGD(lr=1e-3),
metrics=["accuracy"])
return model

def main():
X_train, X_valid, y_train, y_valid = get_data()
model = get_model()

configure_list = [{
'sparsity': 0.5,
'op_types': ['Conv2D']
}]
pruner = FPGMPruner(model, configure_list)
pruner.compress()

update_epoch_callback = keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch))

model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=[update_epoch_callback])


if __name__ == '__main__':
main()
101 changes: 101 additions & 0 deletions examples/model_compress/fpgm_torch_mnist.py
@@ -0,0 +1,101 @@
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms


class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)

def _get_conv_weight_sparsity(self, conv_layer):
num_zero_filters = (conv_layer.weight.data.sum((2,3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1)
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

def print_conv_filter_sparsity(self):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))

def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
model.print_conv_filter_sparsity()
loss.backward()
optimizer.step()

def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def main():
torch.manual_seed(0)
device = torch.device('cpu')

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()
model.print_conv_filter_sparsity()

    '''You can switch to LevelPruner instead, for example:
    pruner = LevelPruner(model, configure_list)
    '''
configure_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]

pruner = FPGMPruner(model, configure_list)
pruner.compress()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)


if __name__ == '__main__':
main()
104 changes: 103 additions & 1 deletion src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py
@@ -1,8 +1,9 @@
import logging
import numpy as np
import tensorflow as tf
from .compressor import Pruner

__all__ = ['LevelPruner', 'AGP_Pruner']
__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']

_logger = logging.getLogger(__name__)

@@ -98,3 +99,104 @@ def update_epoch(self, epoch, sess):
sess.run(tf.assign(self.now_epoch, int(epoch)))
for k in self.if_init_list:
self.if_init_list[k] = True

class FPGMPruner(Pruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
"""

def __init__(self, model, config_list):
"""
Parameters
----------
        model : tensorflow model
            the model the user wants to compress
config_list: list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_dict = {}
self.assign_handler = []
self.epoch_pruned_layers = set()

def calc_mask(self, layer, config):
"""
Supports Conv1D, Conv2D
filter dimensions for Conv1D:
LEN: filter length
IN: number of input channel
OUT: number of output channel

filter dimensions for Conv2D:
H: filter height
W: filter width
IN: number of input channel
OUT: number of output channel

Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
config : dict
the configuration for generating the mask
"""

weight = layer.weight
op_type = layer.type
op_name = layer.name
assert 0 <= config.get('sparsity') < 1
assert op_type in ['Conv1D', 'Conv2D']
assert op_type in config['op_types']

if layer.name in self.epoch_pruned_layers:
assert layer.name in self.mask_dict
return self.mask_dict.get(layer.name)

try:
weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1]))
masks = np.ones(weight.shape)

num_kernels = weight.shape[0] * weight.shape[1]
num_prune = int(num_kernels * config.get('sparsity'))
if num_kernels < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
masks[tuple(idx)] = 0.
finally:
masks = np.transpose(masks, [2, 3, 0, 1])
masks = tf.Variable(masks)
self.mask_dict.update({op_name: masks})
self.epoch_pruned_layers.add(layer.name)

return masks

def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2

dist_list, idx_list = [], []
for in_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]):
dist_sum = self._get_distance_sum(weight, in_i, out_i)
dist_list.append(dist_sum)
idx_list.append([in_i, out_i])
dist_tensor = tf.convert_to_tensor(dist_list)
idx_tensor = tf.constant(idx_list)

_, idx = tf.math.top_k(dist_tensor, k=n)
return tf.gather(idx_tensor, idx)

def _get_distance_sum(self, weight, in_idx, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1])
x = w - anchor_w
x = tf.math.reduce_sum((x*x), (-2, -1))
x = tf.math.sqrt(x)
return tf.math.reduce_sum(x)

def update_epoch(self, epoch):
self.epoch_pruned_layers = set()