merge master #254

Merged
merged 4 commits into from
Jun 17, 2020
23 changes: 23 additions & 0 deletions docs/en_US/Compressor/CompressionReference.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Python API Reference of Compression Utilities

```eval_rst
.. contents::
```

## Sensitivity Utilities

```eval_rst
.. autoclass:: nni.compression.torch.utils.sensitivity_analysis.SensitivityAnalysis
    :members:

```

## Topology Utilities

```eval_rst
.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency
    :members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.MaskConflict
    :members:
```
123 changes: 123 additions & 0 deletions docs/en_US/Compressor/CompressionUtils.md
@@ -0,0 +1,123 @@
# Analysis Utils for Model Compression

```eval_rst
.. contents::
```

We provide several easy-to-use tools for users to analyze their model during model compression.

## Sensitivity Analysis
First, we provide a sensitivity analysis tool (**SensitivityAnalysis**) for users to analyze the sensitivity of each convolutional layer in their model. Specifically, SensitivityAnalysis gradually prunes each layer of the model and tests the accuracy of the model at each step. Note that SensitivityAnalysis prunes only one layer at a time; the other layers keep their original weights. From the accuracies of the different convolutional layers under different sparsities, we can easily find out which layers the model accuracy is more sensitive to.

### Usage

The following code shows the basic usage of SensitivityAnalysis.
```python
import os

import torch
from nni.compression.torch.utils.sensitivity_analysis import SensitivityAnalysis

# `net`, `val_loader`, `outdir` and `filename` are assumed to be defined by the user.
def val(model):
    """Return the accuracy of `model` on the validation set."""
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batchid, (data, label) in enumerate(val_loader):
            data, label = data.cuda(), label.cuda()
            out = model(data)
            _, predicted = out.max(1)
            total += data.size(0)
            correct += predicted.eq(label).sum().item()
    return correct / total

s_analyzer = SensitivityAnalysis(model=net, val_func=val)
sensitivity = s_analyzer.analysis(val_args=[net])
# create the output directory if it does not exist yet
os.makedirs(outdir, exist_ok=True)
s_analyzer.export(os.path.join(outdir, filename))
```

The two key parameters of SensitivityAnalysis are `model` and `val_func`. `model` is the neural network to be analyzed, and `val_func` is the validation function that returns the model accuracy, loss, or another metric on the validation dataset. Because different scenarios may calculate the loss/accuracy in different ways, users should prepare a function that returns the model accuracy/loss on the dataset and pass it to SensitivityAnalysis.
SensitivityAnalysis can export the sensitivity results as a csv file; the usage is shown in the example above.

Furthermore, users can specify the sparsity values used to prune each layer with the optional parameter `sparsities`.
```python
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75])
```
SensitivityAnalysis will then prune 25%, 50%, and 75% of the weights of each layer in turn and record the model's accuracy each time (SensitivityAnalysis prunes only one layer at a time; the other layers keep their original weights). If `sparsities` is not set, SensitivityAnalysis uses `numpy.arange(0.1, 1.0, 0.1)` as the default sparsity values.
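
The per-layer sweep can be pictured with a toy sketch. This is not NNI's implementation: the "model" is a plain dict of weight lists, `prune_smallest` and `sweep` are hypothetical helpers, and zeroing the smallest-magnitude weights is an assumed pruning criterion chosen only for illustration.

```python
import copy

def prune_smallest(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    n_prune = int(len(weights) * sparsity)
    # indices sorted by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned

def sweep(model, val_func, sparsities):
    """Prune one layer at a time and record the metric for each sparsity."""
    results = {}
    for name, original in model.items():
        results[name] = {}
        for sparsity in sparsities:
            # work on a copy so every other layer keeps its original weights
            trial = copy.deepcopy(model)
            trial[name] = prune_smallest(original, sparsity)
            results[name][sparsity] = val_func(trial)
    return results

# Toy "model" and "metric": the metric is just the sum of absolute weights.
toy_model = {"conv1": [0.9, -0.1, 0.5, 0.2], "conv2": [0.3, 0.3, -0.3, 0.3]}

def toy_metric(model):
    return sum(abs(w) for weights in model.values() for w in weights)

report = sweep(toy_model, toy_metric, sparsities=[0.25, 0.5, 0.75])
```

Each entry `report[layer][sparsity]` plays the role of one cell in the exported csv file: the metric of the whole model when only that layer is pruned to that sparsity.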

Users can also speed up the sensitivity analysis with the `early_stop_mode` and `early_stop_value` options. By default, SensitivityAnalysis tests the accuracy under all sparsities for each layer. In contrast, when `early_stop_mode` and `early_stop_value` are set, the sensitivity analysis for a layer stops as soon as the accuracy/loss meets the threshold set by `early_stop_value`. We support four early stop modes: minimize, maximize, dropped, and raised.

minimize: The analysis stops when the validation metric returned by `val_func` is lower than `early_stop_value`.

maximize: The analysis stops when the validation metric returned by `val_func` is larger than `early_stop_value`.

dropped: The analysis stops when the validation metric has dropped by `early_stop_value`.

raised: The analysis stops when the validation metric has risen by `early_stop_value`.

```python
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75], early_stop_mode='dropped', early_stop_value=0.1)
```
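
The four early stop modes can be summarized as a single predicate. The sketch below is an illustration, not NNI's code: `should_stop` is a hypothetical helper, and it assumes that the comparisons are strict and that `dropped`/`raised` are measured relative to the metric of the unpruned model (`baseline`).

```python
def should_stop(mode, value, current, baseline):
    """Return True when the early-stop condition for `mode` is met.

    `baseline` is the metric of the unpruned model (an assumption of this
    sketch); `current` is the metric after pruning one layer to the
    sparsity under test.
    """
    if mode is None:
        return False
    if mode == 'minimize':
        return current < value
    if mode == 'maximize':
        return current > value
    if mode == 'dropped':
        return current < baseline - value
    if mode == 'raised':
        return current > baseline + value
    raise ValueError('unknown early stop mode: %s' % mode)

# With early_stop_mode='dropped' and early_stop_value=0.1, a layer whose
# accuracy falls from 0.55 to 0.40 would not be tested at higher sparsities.
stop_now = should_stop('dropped', 0.1, current=0.40, baseline=0.55)
```
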
If users only want to analyze several specified convolutional layers, they can specify the target conv layers with the `specified_layers` argument of the `analysis` function. `specified_layers` is a list of the PyTorch module names of the conv layers. For example:
```python
sensitivity = s_analyzer.analysis(val_args=[net], specified_layers=['Conv1'])
```
In this example, only the `Conv1` layer is analyzed. In addition, users can easily parallelize the analysis by launching multiple processes and assigning different conv layers of the same model to each process.
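
One simple way to divide the layers among processes is a round-robin split. The helper `split_layers` below is hypothetical and only prepares the per-process layer lists; launching the processes and merging their csv files is left to the user.

```python
def split_layers(layer_names, n_workers):
    """Assign layer names to workers in round-robin order."""
    return [layer_names[i::n_workers] for i in range(n_workers)]

layers = ['features.0', 'features.3', 'features.6', 'features.8', 'features.10']
chunks = split_layers(layers, 2)
# Each worker process would then pass its own chunk as `specified_layers`
# and export its results to a separate file.
```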


### Output example
The following lines are an example csv file exported from SensitivityAnalysis. The first line consists of 'layername' and the list of sparsities. Here, the sparsity value means how much of each layer's weights SensitivityAnalysis prunes. Each line below records the model accuracy when that layer is pruned to the different sparsities. Note that, due to the early_stop option, some layers may not have model accuracies/losses under all sparsities, for example because the accuracy drop has already exceeded the threshold set by the user.
```
layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184
features.3,0.54878,0.51184,0.37978,0.19814,0.07178,0.02114,0.00438,0.00442,0.00142
features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236
features.8,0.55696,0.54194,0.48892,0.42986,0.33048,0.2266,0.09566,0.02348,0.0056
features.10,0.55468,0.5394,0.49576,0.4291,0.3591,0.28138,0.14256,0.05446,0.01578
```
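
A csv file in this layout can be post-processed with the standard library. The sketch below uses a hypothetical helper, `max_safe_sparsity`, that picks for each layer the largest sparsity whose metric is still at or above a chosen threshold. How early-stopped rows are written (shorter rows or empty cells) is not specified here, so the sketch tolerates both.

```python
import csv
import io

def max_safe_sparsity(csv_text, min_metric):
    """For each layer, return the largest sparsity whose metric is >= min_metric."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    sparsities = [float(s) for s in rows[0][1:]]
    safe = {}
    for row in rows[1:]:
        if not row:
            continue  # skip blank lines
        name, values = row[0], row[1:]
        best = None
        # zip stops at the shorter list, so truncated rows are handled
        for sparsity, value in zip(sparsities, values):
            if value != '' and float(value) >= min_metric:
                best = sparsity if best is None else max(best, sparsity)
        safe[name] = best
    return safe

example = """layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184
features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236
"""
result = max_safe_sparsity(example, min_metric=0.4)
```

With a threshold of 0.4, `features.0` tolerates a sparsity of 0.1 while `features.6` tolerates 0.3, which matches the intuition that the first layer is the more sensitive one.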

## Topology Analysis
We also provide several tools for topology analysis during model compression. These tools help users compress their models better. Because of the complex topology of a network, users often need to spend a lot of effort checking whether a compression configuration is reasonable. We provide these topology analysis tools to reduce that burden.

### ChannelDependency
Complicated models may contain residual connections or concat operations. When users prune such models, they need to be careful about the channel-count dependencies between the convolution layers in the model. Take the following residual block in resnet18 as an example. The output features of `layer2.0.conv2` and `layer2.0.downsample.0` are added together, so the numbers of output channels of `layer2.0.conv2` and `layer2.0.downsample.0` must be the same, or there may be a tensor shape conflict.

![](../../img/channel_dependency_example.jpg)


If layers that have a channel dependency are assigned different sparsities (here we only discuss structured pruning by L1FilterPruner/L2FilterPruner), there will be a shape conflict between these layers. Even if the pruned model with masks works fine, it cannot be sped up directly into the final model that runs on devices, because there will be a shape conflict when the model tries to add/concat the outputs of these layers. This tool finds the layers that have channel-count dependencies to help users prune their models better.

#### Usage
```python
import torch
from nni.compression.torch.utils.shape_dependency import ChannelDependency
data = torch.ones(1, 3, 224, 224).cuda()
channel_depen = ChannelDependency(net, data)
channel_depen.export('dependency.csv')
```

#### Output Example
The following lines are the output example for torchvision.models.resnet18 exported by ChannelDependency. Layers on the same line have output channel dependencies with each other. For example, layer1.1.conv2, conv1, and layer1.0.conv2 have output channel dependencies with each other, which means the numbers of output channels (filters) of these three layers should be the same; otherwise, the model may have a shape conflict.
```
Dependency Set,Convolutional Layers
Set 1,layer1.1.conv2,layer1.0.conv2,conv1
Set 2,layer1.0.conv1
Set 3,layer1.1.conv1
Set 4,layer2.0.conv1
Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0
Set 6,layer2.1.conv1
Set 7,layer3.0.conv1
Set 8,layer3.0.downsample.0,layer3.1.conv2,layer3.0.conv2
Set 9,layer3.1.conv1
Set 10,layer4.0.conv1
Set 11,layer4.0.downsample.0,layer4.1.conv2,layer4.0.conv2
Set 12,layer4.1.conv1
```
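
Dependency sets like these can be thought of as connected components: whenever the outputs of two layers meet at the same add/concat, the two layers belong to the same set. The sketch below groups layers with a small union-find. It is an illustration only: ChannelDependency derives the dependencies from the traced model, whereas here the pairs are listed by hand and `dependency_sets` is a hypothetical helper.

```python
def dependency_sets(layers, shared_pairs):
    """Group layers into sets; layers joined by a pair end up in the same set."""
    parent = {name: name for name in layers}

    def find(name):
        # follow parent links to the representative of the set
        while parent[name] != name:
            parent[name] = parent[parent[name]]  # path halving
            name = parent[name]
        return name

    for a, b in shared_pairs:
        parent[find(a)] = find(b)

    groups = {}
    for name in layers:
        groups.setdefault(find(name), set()).add(name)
    return list(groups.values())

layers = ['layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.downsample.0',
          'layer2.1.conv1', 'layer2.1.conv2']
# Hand-written pairs: outputs that meet at the same residual add.
pairs = [('layer2.0.conv2', 'layer2.0.downsample.0'),
         ('layer2.1.conv2', 'layer2.0.conv2')]
sets = dependency_sets(layers, pairs)
```

For this fragment of resnet18 the result has three groups, corresponding to Set 4, Set 5, and Set 6 in the output above.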

### MaskConflict
When the masks of different layers in a model conflict (for example, when different sparsities are assigned to layers that have a channel dependency), we can fix the mask conflict with MaskConflict. Specifically, MaskConflict loads the masks exported by the pruners (L1FilterPruner, etc.) and checks whether there is a mask conflict; if so, MaskConflict sets the conflicting masks to the same value.

```python
from nni.compression.torch.utils.mask_conflict import MaskConflict
mc = MaskConflict('./resnet18_mask', net, data)
mc.fix_mask_conflict()
mc.export('./resnet18_fixed_mask')
```
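
Judging from `fix_mask_conflict` in this pull request, conflicting masks are reconciled by keeping every filter that at least one layer in the dependency set keeps. The sketch below shows that rule on simplified masks, with one 0/1 flag per output channel instead of full weight tensors; `fix_conflict` is a hypothetical helper, not part of NNI.

```python
def fix_conflict(masks):
    """Keep a channel in every layer if any layer in the set keeps it.

    `masks` maps layer name -> list of 0/1 flags, one per output channel.
    All layers are assumed to belong to the same dependency set.
    """
    kept = set()
    for flags in masks.values():
        kept.update(i for i, flag in enumerate(flags) if flag == 1)
    return {name: [1 if i in kept else 0 for i in range(len(flags))]
            for name, flags in masks.items()}

conflicting = {
    'layer2.0.conv2':        [1, 0, 1, 0],
    'layer2.0.downsample.0': [1, 1, 0, 0],
}
fixed = fix_conflict(conflicting)
```

Both layers end up with the mask `[1, 1, 1, 0]`. Because the rule takes the union of the kept channels, the sparsity after fixing can be lower than the sparsity originally configured for each layer.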
1 change: 1 addition & 0 deletions docs/en_US/model_compression.rst
@@ -22,3 +22,4 @@ For details, please refer to the following tutorials:
Model Speedup <Compressor/ModelSpeedup>
Automatic Model Compression <Compressor/AutoCompression>
Implementation <Compressor/Framework>
Compression Utilities <Compressor/CompressionUtils>
3 changes: 2 additions & 1 deletion docs/en_US/sdk_reference.rst
@@ -7,4 +7,5 @@ Python API Reference
:maxdepth: 1

Auto Tune <autotune_ref>
NAS <NAS/NasReference>
NAS <NAS/NasReference>
Compression Utilities <Compressor/CompressionReference>
Binary file added docs/img/channel_dependency_example.jpg
4 changes: 2 additions & 2 deletions examples/model_compress/model_prune_torch.py
@@ -32,7 +32,7 @@
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
- 'op_types': ['default']
+ 'op_types': ['Conv2d']
}]
},
'slim': {
@@ -79,7 +79,7 @@
'pruner_class': ActivationAPoZRankFilterPruner,
'config_list': [{
'sparsity': 0.5,
- 'op_types': ['default'],
+ 'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
}
2 changes: 1 addition & 1 deletion src/nni_manager/training_service/pai/paiK8S/paiK8SData.ts
@@ -33,4 +33,4 @@ export const PAI_K8S_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& NNI_CODE_DIR={6} && mkdir -p $NNI_SYS_DIR/code && cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR/code && sh $NNI_SYS_DIR/install_nni.sh \
&& cd $NNI_SYS_DIR/code && python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' \
- --nni_manager_version '{10}' --log_collection '{11}'`;
+ --nni_manager_version '{10}' --log_collection '{11}' | tee $NNI_OUTPUT_DIR/trial.log`;
96 changes: 96 additions & 0 deletions src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import numpy as np
from .shape_dependency import ChannelDependency
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger('FixMaskConflict')

class MaskConflict:
    def __init__(self, mask_file, model=None, dummy_input=None, graph=None):
        """
        MaskConflict fixes the mask conflict between the layers that
        have channel dependency with each other.

        Parameters
        ----------
        model : torch.nn.Module
            model to fix the mask conflict
        dummy_input : torch.Tensor
            input example to trace the model
        mask_file : str
            the path of the original mask file
        graph : torch._C.Graph
            the traced graph of the target model. If this parameter is not None,
            we do not use the model and dummy_input to get the trace graph.
        """
        # check if the parameters are valid
        parameter_valid = False
        if graph is not None:
            parameter_valid = True
        elif (model is not None) and (dummy_input is not None):
            parameter_valid = True
        if not parameter_valid:
            raise Exception('The input parameters is invalid!')
        self.model = model
        self.dummy_input = dummy_input
        self.graph = graph
        self.mask_file = mask_file
        self.masks = torch.load(self.mask_file)

    def fix_mask_conflict(self):
        """
        Fix the mask conflict before the mask inference for the layers that
        have shape dependencies. This function should be called before the
        mask inference of the 'speedup' module.
        """
        channel_depen = ChannelDependency(self.model, self.dummy_input, self.graph)
        depen_sets = channel_depen.dependency_sets
        for dset in depen_sets:
            if len(dset) == 1:
                # This layer has no channel dependency with other layers
                continue
            channel_remain = set()
            fine_grained = False
            for name in dset:
                if name not in self.masks:
                    # this layer is not pruned
                    continue
                w_mask = self.masks[name]['weight']
                shape = w_mask.size()
                count = np.prod(shape[1:])
                all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
                all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
                if len(all_ones) + len(all_zeros) < w_mask.size(0):
                    # In fine-grained pruning, there is no need to check
                    # the shape conflict
                    _logger.info('Layers %s using fine-grained pruning', ','.join(dset))
                    fine_grained = True
                    break
                channel_remain.update(all_ones)
                _logger.debug('Layer: %s ', name)
                _logger.debug('Original pruned filters: %s', str(all_zeros))
            # Update the masks for the layers in the dependency set
            if fine_grained:
                continue
            ori_channels = 0
            for name in dset:
                mask = self.masks[name]
                w_shape = mask['weight'].size()
                ori_channels = w_shape[0]
                for i in channel_remain:
                    mask['weight'][i] = torch.ones(w_shape[1:])
                    # `mask` is a dict, so look up the key instead of using hasattr
                    if mask.get('bias') is not None:
                        mask['bias'][i] = 1
            _logger.info(','.join(dset))
            _logger.info('Pruned Filters after fixing conflict:')
            pruned_filters = set(list(range(ori_channels)))-channel_remain
            _logger.info(str(sorted(pruned_filters)))
        return self.masks

    def export(self, path):
        """
        Export the masks after fixing the conflict to file.
        """
        torch.save(self.masks, path)