Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
add pruner unit test (#1771)
Browse files Browse the repository at this point in the history
* add pruner unit test
* modify pruners compatible with torch0.4.1
  • Loading branch information
tanglang96 authored and chicm-ms committed Nov 25, 2019
1 parent 8ac61b7 commit 503a357
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/en_US/Compressor/SlimPruner.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ We implemented one of the experiments in ['Learning Efficient Convolutional Netw
| Model | Error(paper/ours) | Parameters | Pruned |
| ------------- | ----------------- | ---------- | --------- |
| VGGNet | 6.34/6.40 | 20.04M | |
| Pruned-VGGNet | 6.20/6.39 | 2.03M | 88.5% |
| Pruned-VGGNet | 6.20/6.26 | 2.03M | 88.5% |

The experiments code can be found at [examples/model_compress]( https://github.com/microsoft/nni/tree/master/examples/model_compress/)
2 changes: 1 addition & 1 deletion examples/model_compress/slim_pruner_torch_vgg19.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def main():
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
test(new_model, device, test_loader)
# top1 = 93.61%
# top1 = 93.74%


if __name__ == '__main__':
Expand Down
10 changes: 5 additions & 5 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def calc_mask(self, layer, config):
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: mask})
self.if_init_list.update({op_name: False})
Expand Down Expand Up @@ -108,7 +108,7 @@ def calc_mask(self, layer, config):
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
Expand Down Expand Up @@ -336,7 +336,7 @@ def calc_mask(self, layer, config):
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
Expand Down Expand Up @@ -370,10 +370,10 @@ def __init__(self, model, config_list):
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.clone())
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

def calc_mask(self, layer, config):
"""
Expand Down
95 changes: 85 additions & 10 deletions src/sdk/pynni/tests/test_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
if tf.__version__ >= '2.0':
import nni.compression.tensorflow as tf_compressor


def get_tf_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
Expand All @@ -20,42 +21,49 @@ def get_tf_model():
tf.keras.layers.Dense(units=10, activation='softmax'),
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer=tf.keras.optimizers.SGD(lr=1e-3),
metrics=["accuracy"])
optimizer=tf.keras.optimizers.SGD(lr=1e-3),
metrics=["accuracy"])
return model


class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def tf2(func):
def test_tf2_func(*args):
if tf.__version__ >= '2.0':
func(*args)

return test_tf2_func

k1 = [[1]*3]*3
k2 = [[2]*3]*3
k3 = [[3]*3]*3
k4 = [[4]*3]*3
k5 = [[5]*3]*3

k1 = [[1] * 3] * 3
k2 = [[2] * 3] * 3
k3 = [[3] * 3] * 3
k4 = [[4] * 3] * 3
k5 = [[5] * 3] * 3

w = [[k1, k2, k3, k4, k5]] * 10


class CompressorTestCase(TestCase):
def test_torch_level_pruner(self):
model = TorchModel()
Expand All @@ -74,7 +82,7 @@ def test_torch_naive_quantizer(self):
'quant_bits': {
'weight': 8,
},
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}]
torch_compressor.NaiveQuantizer(model, configure_list).compress()

Expand Down Expand Up @@ -133,6 +141,73 @@ def test_tf_fpgm_pruner(self):

assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))

def test_torch_l1filter_pruner(self):
"""
Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
PRUNING FILTERS FOR EFFICIENT CONVNETS,
https://arxiv.org/abs/1608.08710
So if sparsity is 0.2, the expected masks should mask out filter 0, this can be verified through:
`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
If sparsity is 0.6, the expected masks should mask out filter 0,1,2, this can be verified through:
`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
"""
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)

model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1])
assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))

def test_torch_slim_pruner(self):
"""
Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/pdf/1708.06519.pdf
So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
`all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`
If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
`all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
"""
w = np.array([0, 1, 2, 3, 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(-w).float()
pruner = torch_compressor.SlimPruner(model, config_list)

layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))

config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(w).float()
pruner = torch_compressor.SlimPruner(model, config_list)

layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))


if __name__ == '__main__':
main()

0 comments on commit 503a357

Please sign in to comment.