Analysis utils #2435
Changes from 41 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Python API Reference of Compression Utilities | ||
|
||
```eval_rst | ||
.. contents:: | ||
``` | ||
|
||
## Sensitivity Utilities | ||
|
||
```eval_rst | ||
.. autoclass:: nni.compression.torch.utils.sensitivity_analysis.SensitivityAnalysis | ||
:members: | ||
|
||
``` | ||
|
||
## Topology Utilities | ||
|
||
```eval_rst | ||
.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency | ||
:members: | ||
|
||
.. autoclass:: nni.compression.torch.utils.mask_conflict.MaskConflict | ||
:members: | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Analysis Utils for Model Compression | ||
We provide several easy-to-use tools for users to analyze their model during model compression. | ||
|
||
## Sensitivity | ||
First, we provide a sensitivity analysis tool (**SensitivityAnalysis**) for users to analyze the sensitivity of each convolutional layer in their model. Specifically, SensitivityAnalysis gradually prunes each layer of the model and tests the accuracy of the model at each step. Note that SensitivityAnalysis prunes only one layer at a time, while the other layers keep their original weights. From the accuracies of the different convolutional layers under different sparsities, we can easily find out which layers the model accuracy is most sensitive to. | ||
|
||
### Usage | ||
|
||
The following code shows the basic usage of SensitivityAnalysis. | ||
```python | ||
import os | ||
|
||
import torch | ||
from nni.compression.torch.utils.sensitivity_analysis import SensitivityAnalysis | ||
|
||
# `net`, `val_loader`, `outdir` and `filename` are assumed to be defined by the user. | ||
def val(model): | ||
    # Return the top-1 accuracy of the model on the validation set. | ||
    model.eval() | ||
    total = 0 | ||
    correct = 0 | ||
    with torch.no_grad(): | ||
        for batchid, (data, label) in enumerate(val_loader): | ||
            data, label = data.cuda(), label.cuda() | ||
            out = model(data) | ||
            _, predicted = out.max(1) | ||
            total += data.size(0) | ||
            correct += predicted.eq(label).sum().item() | ||
    return correct / total | ||
|
||
s_analyzer = SensitivityAnalysis(model=net, val_func=val) | ||
sensitivity = s_analyzer.analysis(val_args=[net]) | ||
# os.makedirs creates the output directory (os.makedir does not exist) | ||
os.makedirs(outdir, exist_ok=True) | ||
s_analyzer.export(os.path.join(outdir, filename)) | ||
``` | ||
|
||
The two key parameters of SensitivityAnalysis are model and val_func. 'model' is the neural network to be analyzed, and 'val_func' is the validation function that returns the model's accuracy, loss, or another metric on the validation dataset. Because different scenarios may calculate the loss/accuracy in different ways, users should prepare a function that returns the model accuracy/loss on the dataset and pass it to SensitivityAnalysis. | ||
Review comment: 'model' ->
Reply: Thanks~ I have fixed it, please review on the latest version~ |
||
SensitivityAnalysis can export the sensitivity results as a CSV file; the usage is shown in the example above. | ||
|
||
Furthermore, users can specify the sparsity values used to prune each layer with the optional parameter 'sparsities'. | ||
Review comment: optinal -> optional
Reply: I have fixed several grammar errors, please review on the latest version, thanks~ |
||
```python | ||
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75]) | ||
``` | ||
With this setting, SensitivityAnalysis prunes 25%, 50%, and 75% of the weights of each layer in turn and records the model's accuracy at each step (SensitivityAnalysis prunes only one layer at a time; the other layers keep their original weights). If sparsities is not set, SensitivityAnalysis uses numpy.arange(0.1, 1.0, 0.1) as the default sparsity values. | ||
|
||
Users can also speed up the sensitivity analysis with the early_stop/min_threshold/max_threshold options. By default, SensitivityAnalysis tests the accuracy under all sparsities for each layer. In contrast, when early_stop is set, the analysis of a layer stops once the accuracy/loss has dropped/risen by the value of early_stop. If min_threshold/max_threshold is set, the analysis stops when the validation metric returned by val_func is lower/larger than the threshold. | ||
```python | ||
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75], early_stop=0.1) | ||
``` | ||
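The early_stop behaviour described above can be pictured with a small, self-contained sketch. It is not the SensitivityAnalysis implementation: `analyze_layer`, `TOY_ACCURACY`, and the baseline value are hypothetical names and numbers, and the stopping rule (stop once the accuracy has dropped by at least early_stop from the unpruned baseline) is only one plausible reading of the description.

```python
# Hypothetical accuracies of one layer at each sparsity (illustrative numbers only).
TOY_ACCURACY = {0.25: 0.85, 0.5: 0.6, 0.75: 0.3}


def analyze_layer(evaluate, sparsities, baseline, early_stop=None):
    """Record the accuracy at each sparsity; stop early once the drop reaches early_stop."""
    results = {}
    for sparsity in sparsities:
        accuracy = evaluate(sparsity)
        results[sparsity] = accuracy
        # Assumed rule: compare the drop against the unpruned baseline accuracy.
        if early_stop is not None and baseline - accuracy >= early_stop:
            break
    return results


full = analyze_layer(TOY_ACCURACY.get, [0.25, 0.5, 0.75], baseline=0.9)
stopped = analyze_layer(TOY_ACCURACY.get, [0.25, 0.5, 0.75], baseline=0.9, early_stop=0.1)
print(sorted(full))     # every sparsity is evaluated
print(sorted(stopped))  # evaluation ends after the first large drop
```

In this toy run the sparsity 0.75 is never evaluated once early_stop is set, which is where the time saving comes from.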
If users only want to analyze several specific convolutional layers, they can pass the PyTorch module names of the target layers through the 'specified_layers' parameter of the analysis function. For example: | ||
Review comment: how to specify a conv layer, using module name or module type?
Reply: Sensitivity analysis only analyzes the conv layers, so we use the module name to specify the layers.
Review comment: then better to make it clear that the layers are specified through PyTorch module name.
Reply: Sure, I'll update the doc and make this clear, thanks. |
||
```python | ||
sensitivity = s_analyzer.analysis(val_args=[net], specified_layers=['Conv1']) | ||
``` | ||
In this example, only the Conv1 layer is analyzed. | ||
|
||
|
||
### Output example | ||
The following lines show an example CSV file exported from SensitivityAnalysis. The first line consists of 'layername' followed by the list of sparsities. Here, a sparsity value is the fraction of weights that SensitivityAnalysis prunes in each layer. Each subsequent line records the model accuracy when the corresponding layer is pruned to the different sparsities. Note that, due to the early_stop option, some layers may not have accuracies for all sparsities, because their accuracy drop has already exceeded the threshold set by the user. | ||
``` | ||
layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95 | ||
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184 | ||
features.3,0.54878,0.51184,0.37978,0.19814,0.07178,0.02114,0.00438,0.00442,0.00142 | ||
features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236 | ||
features.8,0.55696,0.54194,0.48892,0.42986,0.33048,0.2266,0.09566,0.02348,0.0056 | ||
features.10,0.55468,0.5394,0.49576,0.4291,0.3591,0.28138,0.14256,0.05446,0.01578 | ||
``` | ||
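As a rough illustration of how such a file might be consumed, the sketch below reads the CSV layout shown above and reports, for each layer, the largest sparsity whose accuracy stays at or above a chosen floor. `max_safe_sparsity` and the 0.4 floor are hypothetical and not part of NNI; the sample rows are copied from the output above.

```python
import csv
import io

# First three data rows of the example output above.
SAMPLE_CSV = """layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184
features.3,0.54878,0.51184,0.37978,0.19814,0.07178,0.02114,0.00438,0.00442,0.00142
features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236
"""


def max_safe_sparsity(csv_text, min_accuracy):
    """Map each layer to the largest sparsity with accuracy >= min_accuracy (or None)."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    sparsities = [float(value) for value in header[1:]]
    result = {}
    for row in reader:
        if not row:
            continue
        # Rows cut short by early stopping simply yield fewer (sparsity, accuracy) pairs.
        safe = [s for s, cell in zip(sparsities, row[1:])
                if cell and float(cell) >= min_accuracy]
        result[row[0]] = max(safe) if safe else None
    return result


safe_sparsity = max_safe_sparsity(SAMPLE_CSV, min_accuracy=0.4)
for layer, sparsity in safe_sparsity.items():
    print(layer, sparsity)
```

Layers with a small "safe" sparsity are the ones the model accuracy is most sensitive to, so they are natural candidates for gentler pruning.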
|
||
## Topology | ||
Review comment: -> Topology Analysis |
||
We also provide several tools for topology analysis during model compression. | ||
|
||
### ChannelDependency | ||
Complicated models may have residual connections or concat operations. When users prune such models, they need to be careful about the channel-count dependencies between the convolutional layers. If layers that have a channel dependency are assigned different sparsities (here we only discuss structured pruning by L1FilterPruner/L2FilterPruner), the pruned model with masks still works fine, but it cannot be sped up into the final model that runs on devices, because there will be a shape conflict when the model tries to add/concat the outputs of these layers. This tool finds the layers that have channel-count dependencies to help users better prune their models. | ||
|
||
#### Usage | ||
```python | ||
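import torch | ||
from nni.compression.torch.utils.shape_dependency import ChannelDependency | ||
# `net` is assumed to be a user-defined torch.nn.Module that already lives on the GPU | ||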
data = torch.ones(1, 3, 224, 224).cuda() | ||
channel_depen = ChannelDependency(net, data) | ||
channel_depen.export('dependency.csv') | ||
``` | ||
|
||
#### Output Example | ||
The following lines are the output example for torchvision.models.resnet18 exported by ChannelDependency. The layers on the same line have output channel dependencies with each other. For example, layer1.1.conv2, conv1, and layer1.0.conv2 have output channel dependencies with each other, which means the numbers of output channels (filters) of these three layers should be the same; otherwise, the model may have a shape conflict. | ||
``` | ||
Dependency Set,Convolutional Layers | ||
Set 1,layer1.1.conv2,layer1.0.conv2,conv1 | ||
Set 2,layer1.0.conv1 | ||
Set 3,layer1.1.conv1 | ||
Set 4,layer2.0.conv1 | ||
Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0 | ||
Set 6,layer2.1.conv1 | ||
Set 7,layer3.0.conv1 | ||
Set 8,layer3.0.downsample.0,layer3.1.conv2,layer3.0.conv2 | ||
Set 9,layer3.1.conv1 | ||
Set 10,layer4.0.conv1 | ||
Set 11,layer4.0.downsample.0,layer4.1.conv2,layer4.0.conv2 | ||
Set 12,layer4.1.conv1 | ||
``` | ||
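A minimal sketch of reading this file back, assuming the CSV layout shown above: `load_dependency_sets` and `shared_groups` are hypothetical helper names, not part of NNI, and the sample rows are copied from the output above. The second helper keeps only the sets with more than one layer, which are the ones that need a common sparsity.

```python
import csv
import io

# A few rows of the example output above.
SAMPLE_DEPENDENCY_CSV = """Dependency Set,Convolutional Layers
Set 1,layer1.1.conv2,layer1.0.conv2,conv1
Set 2,layer1.0.conv1
Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0
"""


def load_dependency_sets(csv_text):
    """Return {set name: set of layer names} from the exported CSV text."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header line
    return {row[0]: set(row[1:]) for row in reader if row}


def shared_groups(dependency_sets):
    """Keep only the sets in which several layers must share a channel count."""
    return {name: layers for name, layers in dependency_sets.items() if len(layers) > 1}


groups = shared_groups(load_dependency_sets(SAMPLE_DEPENDENCY_CSV))
for name, layers in groups.items():
    print(name, sorted(layers))
```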
|
||
### MaskConflict | ||
When the masks of different layers in a model conflict, we can fix the conflict with MaskConflict. Specifically, MaskConflict loads the masks exported by the pruners (L1FilterPruner, etc.) and checks whether there is a mask conflict; if so, MaskConflict sets the conflicting masks to the same value. | ||
|
||
```python | ||
from nni.compression.torch.utils.mask_conflict import MaskConflict | ||
mc = MaskConflict('./resnet18_mask', net, data) | ||
mc.fix_mask_conflict() | ||
mc.export('./resnet18_fixed_mask') | ||
``` |
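To make "sets the conflicting masks to the same value" concrete, here is a simplified, pure-Python sketch of the idea, assuming one 0/1 flag per filter instead of real weight tensors. `unify_filter_masks` is a hypothetical name; the sketch mirrors the union-of-kept-filters logic in the mask_conflict.py diff of this PR, but it is not that implementation.

```python
def unify_filter_masks(filter_masks, dependency_set):
    """Give every masked layer in the set the union of the filters kept by any of them."""
    kept = set()
    for name in dependency_set:
        if name not in filter_masks:
            continue  # this layer is not pruned
        kept.update(i for i, keep in enumerate(filter_masks[name]) if keep)
    fixed = dict(filter_masks)
    for name in dependency_set:
        if name in filter_masks:
            # A filter survives if at least one dependent layer kept it.
            fixed[name] = [1 if i in kept else 0 for i in range(len(filter_masks[name]))]
    return fixed


# Two dependent layers that were pruned differently (illustrative masks).
toy_masks = {"conv1": [1, 0, 1, 0], "layer1.0.conv2": [1, 1, 0, 0]}
fixed_masks = unify_filter_masks(toy_masks, ["conv1", "layer1.0.conv2"])
print(fixed_masks)
```

Taking the union keeps more filters than either pruner asked for, so the shapes agree at the cost of a somewhat lower overall sparsity.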
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
import logging | ||
import torch | ||
import numpy as np | ||
from .shape_dependency import ChannelDependency | ||
# logging.basicConfig(level = logging.DEBUG) | ||
_logger = logging.getLogger('FixMaskConflict') | ||
|
||
class MaskConflict: | ||
    def __init__(self, mask_file, model=None, dummy_input=None, graph=None): | ||
        """ | ||
        MaskConflict fixes the mask conflict between the layers that | ||
        have a channel dependency with each other. | ||
|
||
        Parameters | ||
        ---------- | ||
        model : torch.nn.Module | ||
            model to fix the mask conflict | ||
        dummy_input : torch.Tensor | ||
            input example to trace the model | ||
        mask_file : str | ||
            the path of the original mask file | ||
        graph : torch._C.Graph | ||
            the traced graph of the target model. If this parameter is not None, | ||
            we do not use the model and dummy_input to get the traced graph. | ||
        """ | ||
        # check if the parameters are valid | ||
        parameter_valid = False | ||
        if graph is not None: | ||
            parameter_valid = True | ||
        elif (model is not None) and (dummy_input is not None): | ||
            parameter_valid = True | ||
        if not parameter_valid: | ||
            raise Exception('The input parameters are invalid!') | ||
        self.model = model | ||
        self.dummy_input = dummy_input | ||
        self.graph = graph | ||
        self.mask_file = mask_file | ||
        self.masks = torch.load(self.mask_file) | ||
|
||
    def fix_mask_conflict(self): | ||
        """ | ||
        Fix the mask conflict before the mask inference for the layers that | ||
        have shape dependencies. This function should be called before the | ||
        mask inference of the 'speedup' module. | ||
        """ | ||
        channel_depen = ChannelDependency(self.model, self.dummy_input, self.graph) | ||
        depen_sets = channel_depen.dependency_sets | ||
        for dset in depen_sets: | ||
            if len(dset) == 1: | ||
                # This layer has no channel dependency with other layers | ||
                continue | ||
            channel_remain = set() | ||
            fine_grained = False | ||
            for name in dset: | ||
                if name not in self.masks: | ||
                    # this layer is not pruned | ||
                    continue | ||
                w_mask = self.masks[name]['weight'] | ||
                shape = w_mask.size() | ||
                count = np.prod(shape[1:]) | ||
                # indices of the filters whose mask is entirely one / entirely zero | ||
                all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() | ||
                all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist() | ||
                if len(all_ones) + len(all_zeros) < w_mask.size(0): | ||
                    # In fine-grained pruning, there is no need to check | ||
                    # the shape conflict | ||
                    _logger.info('Layers %s using fine-grained pruning', ','.join(dset)) | ||
                    fine_grained = True | ||
                    break | ||
                channel_remain.update(all_ones) | ||
                _logger.debug('Layer: %s ', name) | ||
                _logger.debug('Original pruned filters: %s', str(all_zeros)) | ||
            # Update the masks for the layers in the dependency set | ||
            if fine_grained: | ||
                continue | ||
            ori_channels = 0 | ||
            for name in dset: | ||
                mask = self.masks[name] | ||
                w_shape = mask['weight'].size() | ||
                ori_channels = w_shape[0] | ||
                for i in channel_remain: | ||
                    mask['weight'][i] = torch.ones(w_shape[1:]) | ||
                    # mask is indexed like a dict, so test for the key; hasattr() would always be False | ||
                    if 'bias' in mask and mask['bias'] is not None: | ||
                        mask['bias'][i] = 1 | ||
            _logger.info(','.join(dset)) | ||
            _logger.info('Pruned Filters after fixing conflict:') | ||
            pruned_filters = set(list(range(ori_channels)))-channel_remain | ||
            _logger.info(str(sorted(pruned_filters))) | ||
        return self.masks | ||
|
||
    def export(self, path): | ||
        """ | ||
        Export the masks after fixing the conflict to file. | ||
        """ | ||
        torch.save(self.masks, path) |
could add
after this line
in order for users to easily get what is the content of this doc