Fix compressor op_types (#1670) #40

Merged · 1 commit · Oct 30, 2019
6 changes: 4 additions & 2 deletions docs/en_US/Compressor/AutoCompression.md
@@ -8,12 +8,14 @@ You can easily compress a model with NNI compression. Take pruning for example,

```python
from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }]
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list)
pruner(model)
```

```{ 'sparsity': 0.8, 'op_types': 'default' }``` means that **all layers with weights will be compressed with the same 0.8 sparsity**. When ```pruner(model)``` is called, the model is compressed with masks; after that you can fine-tune the model normally, and the **pruned weights, which have been masked, won't be updated**.
The 'default' op_type stands for the module types defined in [default_layers.py](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/default_layers.py) for PyTorch.

Therefore ```{ 'sparsity': 0.8, 'op_types': ['default'] }``` means that **all layers with the specified op_types will be compressed with the same 0.8 sparsity**. When ```pruner(model)``` is called, the model is compressed with masks; after that you can fine-tune the model normally, and the **pruned weights, which have been masked, won't be updated**.

## Then, make this automatic

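The selection rule described above can be sketched in plain Python. The layer names and types below are made up for illustration, and `DEFAULT_TYPES` is a two-entry stand-in for the real list in `default_layers.weighted_modules`:

```python
# Hypothetical sketch of how one config_list entry selects layers.
# DEFAULT_TYPES is a stand-in for default_layers.weighted_modules.
DEFAULT_TYPES = ['Conv2d', 'Linear']

def selected_layers(config, layers):
    """Return names of layers whose type matches config['op_types']."""
    op_types = []
    for op_type in config.get('op_types', []):
        op_types.extend(DEFAULT_TYPES if op_type == 'default' else [op_type])
    return [name for name, layer_type in layers if layer_type in op_types]

layers = [('conv1', 'Conv2d'), ('relu1', 'ReLU'), ('fc1', 'Linear')]
print(selected_layers({'sparsity': 0.8, 'op_types': ['default']}, layers))
# ['conv1', 'fc1'] -- every weighted layer gets the same 0.8 sparsity
```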
10 changes: 5 additions & 5 deletions docs/en_US/Compressor/Overview.md
@@ -22,15 +22,15 @@ We use a simple example to show how to modify your trial code in order to apply
Tensorflow code
```python
from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }]
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list)
pruner(tf.get_default_graph())
```

PyTorch code
```python
from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }]
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list)
pruner(model)
```
@@ -58,7 +58,7 @@ A simple example of configuration is shown below:
[
{
'sparsity': 0.8,
'op_types': 'default'
'op_types': ['default']
},
{
'sparsity': 0.6,
@@ -115,7 +115,7 @@ class YourPruner(nni.compression.tensorflow.Pruner):
def calc_mask(self, weight, config, **kwargs):
# weight is the target weight tensor
# config is the selected dict object in config_list for this layer
# kwargs contains op, op_type, and op_name
# kwargs contains op, op_types, and op_name
# design your mask and return your mask
return your_mask

@@ -158,7 +158,7 @@ class YourPruner(nni.compression.tensorflow.Quantizer):
def quantize_weight(self, weight, config, **kwargs):
# weight is the target weight tensor
# config is the selected dict object in config_list for this layer
# kwargs contains op, op_type, and op_name
# kwargs contains op, op_types, and op_name
# design your quantizer and return new weight
return new_weight

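For intuition on which entry a layer gets when several `config_list` entries match, here is a stand-alone sketch of the selection loop (layer names and types are made up; it mirrors `_select_config` in compressor.py, where each later match overwrites the earlier one, so the last matching entry wins):

```python
# Hypothetical sketch (made-up layer names/types) of the selection loop:
# every matching entry overwrites ret, so the last match in config_list wins.
# The 'default' expansion and 'exclude' handling are omitted for brevity.
def select_config(config_list, layer_type, layer_name):
    ret = None
    for config in config_list:
        if config.get('op_types') and layer_type not in config['op_types']:
            continue
        if config.get('op_names') and layer_name not in config['op_names']:
            continue
        ret = config
    return ret

config_list = [
    {'sparsity': 0.8, 'op_types': ['Conv2d', 'Linear']},
    {'sparsity': 0.6, 'op_names': ['fc1']},  # later entry overrides fc1
]
print(select_config(config_list, 'Linear', 'fc1')['sparsity'])    # 0.6
print(select_config(config_list, 'Conv2d', 'conv1')['sparsity'])  # 0.8
```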
4 changes: 2 additions & 2 deletions docs/en_US/Compressor/Pruner.md
@@ -12,15 +12,15 @@ We first sort the weights in the specified layer by their absolute values. And t
Tensorflow code
```
from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }]
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list)
pruner(model_graph)
```

PyTorch code
```
from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }]
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list)
pruner(model)
```
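The level-pruning rule described above (sort weights by absolute value, mask out the smallest fraction) can be sketched without any framework. `calc_mask` here is a plain-Python stand-in operating on a flat list, not the NNI implementation:

```python
# Hypothetical sketch of level pruning: zero out the config['sparsity']
# fraction of weights with the smallest magnitude, keep the rest.
def calc_mask(weight, config):
    k = int(len(weight) * config['sparsity'])
    # indices of the k smallest-magnitude weights
    pruned = set(sorted(range(len(weight)), key=lambda i: abs(weight[i]))[:k])
    return [0.0 if i in pruned else 1.0 for i in range(len(weight))]

print(calc_mask([0.1, -0.9, 0.05, 0.7, -0.2], {'sparsity': 0.6}))
# [0.0, 1.0, 0.0, 1.0, 0.0] -- only the two largest-magnitude weights survive
```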
4 changes: 2 additions & 2 deletions docs/en_US/Compressor/Quantizer.md
@@ -31,14 +31,14 @@ You can quantize your model to 8 bits with the code below before your training c
Tensorflow code
```python
from nni.compressors.tensorflow import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': 'default' }]
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(config_list)
quantizer(tf.get_default_graph())
```
PyTorch code
```python
from nni.compressors.torch import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': 'default' }]
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(config_list)
quantizer(model)
```
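As a rough illustration of what `q_bits` means, here is a framework-free sketch of linear quantization onto `2**q_bits` levels. This is an assumption for illustration only; the actual QAT_Quantizer algorithm differs:

```python
# Hypothetical sketch: linearly map weights onto 2**q_bits - 1 steps and back.
def quantize_weight(weight, config):
    levels = 2 ** config['q_bits'] - 1
    lo, hi = min(weight), max(weight)
    scale = (hi - lo) / levels
    if scale == 0:
        scale = 1.0  # all weights equal; nothing to quantize
    return [round((w - lo) / scale) * scale + lo for w in weight]

print(quantize_weight([0.0, 0.5, 1.0], {'q_bits': 8}))
```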
2 changes: 1 addition & 1 deletion examples/model_compress/README.md
@@ -20,7 +20,7 @@ configure_list = [{
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_type': 'default'
'op_types': ['default']
}]
pruner = AGP_Pruner(configure_list)
```
2 changes: 1 addition & 1 deletion examples/model_compress/configure_example.yaml
@@ -6,4 +6,4 @@ AGPruner:
frequency: 1
initial_sparsity: 0.05
final_sparsity: 0.8
op_type: 'default'
op_types: ['default']
2 changes: 1 addition & 1 deletion examples/model_compress/main_tf_pruner.py
@@ -91,7 +91,7 @@ def main():
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_type': 'default'
'op_types': ['default']
}]
pruner = AGP_Pruner(configure_list)
# if you want to load from yaml file
2 changes: 1 addition & 1 deletion examples/model_compress/main_tf_quantizer.py
@@ -82,7 +82,7 @@ def main():
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(tf.get_default_graph())
'''
configure_list = [{'q_bits':8, 'op_type':'default'}]
configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(configure_list)
quantizer(tf.get_default_graph())
# you can also use compress(model) or compress_default_graph()
2 changes: 1 addition & 1 deletion examples/model_compress/main_torch_pruner.py
@@ -76,7 +76,7 @@ def main():
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_type': 'default'
'op_types': ['default']
}]

pruner = AGP_Pruner(configure_list)
2 changes: 1 addition & 1 deletion examples/model_compress/main_torch_quantizer.py
@@ -68,7 +68,7 @@ def main():
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
configure_list = [{'q_bits':8, 'op_type':'default'}]
configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(configure_list)
quantizer(model)
# you can also use compress(model) method
20 changes: 12 additions & 8 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -58,10 +58,8 @@ def _instrument_layer(self, layer, config):
def _select_config(self, layer):
ret = None
for config in self._config_list:
op_types = config.get('op_types')
if op_types == 'default':
op_types = default_layers.weighted_modules
if op_types and layer.type not in op_types:
config['op_types'] = self._expand_config_op_types(config)
if layer.type not in config['op_types']:
continue
if config.get('op_names') and layer.name not in config['op_names']:
continue
@@ -70,6 +68,16 @@ def _select_config(self, layer):
return None
return ret

def _expand_config_op_types(self, config):
if config is None:
return []
expanded_op_types = []
for op_type in config.get('op_types', []):
if op_type == 'default':
expanded_op_types.extend(default_layers.weighted_modules)
else:
expanded_op_types.append(op_type)
return expanded_op_types

class Pruner(Compressor):
"""
@@ -112,10 +120,6 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""

def __call__(self, model):
self.compress(model)
return model

def quantize_weight(self, weight, config, op, op_type, op_name):
"""user should know where dequantize goes and implement it in quantize method
we now do not provide dequantize method
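The new `_expand_config_op_types` logic can be exercised stand-alone; `WEIGHTED_MODULES` below stubs `default_layers.weighted_modules`. Note the expansion is idempotent: once `'default'` has been replaced, re-expanding changes nothing, so overwriting `config['op_types']` on every `_select_config` call is harmless:

```python
WEIGHTED_MODULES = ['Conv2d', 'Linear']  # stand-in for default_layers.weighted_modules

def expand_op_types(config):
    """Replace the 'default' placeholder with concrete weighted module types."""
    if config is None:
        return []
    expanded = []
    for op_type in config.get('op_types', []):
        if op_type == 'default':
            expanded.extend(WEIGHTED_MODULES)
        else:
            expanded.append(op_type)
    return expanded

config = {'sparsity': 0.8, 'op_types': ['default', 'BatchNorm2d']}
config['op_types'] = expand_op_types(config)
print(config['op_types'])  # ['Conv2d', 'Linear', 'BatchNorm2d']
assert expand_op_types(config) == config['op_types']  # idempotent
```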
8 changes: 4 additions & 4 deletions src/sdk/pynni/tests/test_compressor.py
@@ -100,21 +100,21 @@ def forward(self, x):
class CompressorTestCase(TestCase):
def test_tf_pruner(self):
model = TfMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(configure_list).compress_default_graph()

def test_tf_quantizer(self):
model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph()
tf_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress_default_graph()

def test_torch_pruner(self):
model = TorchMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(configure_list).compress(model)

def test_torch_quantizer(self):
model = TorchMnist()
torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model)
torch_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress(model)


if __name__ == '__main__':