This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Refactor model pruning framework #2504

Merged
merged 80 commits into microsoft:master on Jun 12, 2020

Conversation

chicm-ms
Contributor

@chicm-ms chicm-ms commented May 28, 2020

Before refactor:

A pruner class/instance includes two parts:

  1. the pruning algorithm implementation (how to prune a specified weight to a specified sparsity)
  2. facilities supporting the algorithm implementation:
    • module wrapper
    • config data and config validation
    • hooks for optimizer.step()
    • nn.Module forward hook for collecting activations (layer output)

Problem: within one Pruner, it is not convenient to reuse another Pruner's pruning algorithm, especially for algorithms that need some initialization, such as TaylorFO and APoZ.

After refactor:

The Pruner interface for users is unchanged; a new pruning algorithm interface, WeightMasker, is introduced for code reuse.
A previous pruner class is split into two classes:

  • A Masker class implements the pruning algorithm behind a common Masker interface.
  • A Pruner class is responsible for the facilities and integrates a Masker class through the general Masker interface, so any Masker class can easily be reused in any Pruner.
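A minimal sketch of the described split, assuming simplified names (the real NNI `WeightMasker.calc_mask` takes a module wrapper and config rather than a raw tensor, so the signatures here are illustrative only):

```python
import torch

class WeightMasker:
    """Pruning algorithm: how to compute a mask for a weight at a given sparsity."""
    def calc_mask(self, weight, sparsity):
        raise NotImplementedError

class LevelMasker(WeightMasker):
    """Magnitude pruning: zero out the smallest-magnitude fraction of weights."""
    def calc_mask(self, weight, sparsity):
        k = int(weight.numel() * sparsity)
        if k == 0:
            return torch.ones_like(weight)
        # Threshold is the largest magnitude among the k smallest weights.
        threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
        return torch.gt(weight.abs(), threshold).type_as(weight)

class Pruner:
    """Facilities: holds config and applies any WeightMasker to a module,
    so the same masker can be reused by different pruners."""
    def __init__(self, masker, sparsity):
        self.masker = masker
        self.sparsity = sparsity

    def prune(self, module):
        mask = self.masker.calc_mask(module.weight.data, self.sparsity)
        module.weight.data.mul_(mask)
        return mask
```

With this shape, a new pruning algorithm only needs to subclass WeightMasker; the Pruner facilities stay untouched.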

chicm-ms added 30 commits August 6, 2019 11:19
Filter prune algo implementation (microsoft#1655)
document the dispatcher working dir (microsoft#1866)
@QuanluZhang
Contributor

@chicm-ms this refactor looks great. please update doc (i.e., tutorial) about how to write a new pruner.

@chicm-ms
Contributor Author

> @chicm-ms this refactor looks great. please update doc (i.e., tutorial) about how to write a new pruner.

Thanks, the doc is updated.

return {'weight_mask': torch.ones(weight.shape).type_as(weight)}
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight_mask': mask_weight}
Contributor

so bias has no mask in our implementation?

Contributor Author

@chicm-ms chicm-ms Jun 11, 2020

no, maybe add bias mask later?
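The snippet under review computes a magnitude-based weight mask (with the bias left unmasked, as discussed). A self-contained version of that logic, with the surrounding function assumed:

```python
import torch

def calc_level_mask(weight, sparsity):
    # Number of weights to prune; with k == 0 nothing is pruned,
    # so return an all-ones mask (the early-return branch in the diff).
    w_abs = weight.abs()
    k = int(weight.numel() * sparsity)
    if k == 0:
        return {'weight_mask': torch.ones(weight.shape).type_as(weight)}
    # Threshold = largest magnitude among the k smallest; weights strictly
    # above it are kept, everything else is zeroed. No bias mask is produced.
    threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
    mask_weight = torch.gt(w_abs, threshold).type_as(weight)
    return {'weight_mask': mask_weight}
```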

docs/en_US/Compressor/Framework.md (outdated review thread, resolved)
docs/en_US/Compressor/Framework.md (outdated review thread, resolved)
### Set wrapper attribute
Sometimes `calc_mask` must save some state data, so users can use the `set_wrappers_attribute` API to register attributes, just as buffers are registered in PyTorch modules. These buffers are registered to the `module wrapper`, and users can access them through the `module wrapper`.
You can refer to the NNI-provided [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations when implementing your own weight masker.
Contributor

A hard-coded link to source code is not a good idea; I recommend linking to the API docs instead. Still, keep it if you feel it is necessary.

Contributor Author

Let's keep it for now.
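The `set_wrappers_attribute` behavior described in the doc excerpt can be sketched as follows (the wrapper class and helper here are hypothetical stand-ins for NNI's module wrapper, shown only to illustrate the buffer-registration idea):

```python
import torch

class ModuleWrapper(torch.nn.Module):
    """Hypothetical stand-in for NNI's module wrapper around a pruned layer."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)

def set_wrappers_attribute(wrappers, name, value):
    # Register the value as a buffer on every wrapper, the same way buffers
    # are registered in PyTorch modules; a stateful masker can then read and
    # write the attribute through the wrapper (e.g. an 'if_calculated' flag).
    for wrapper in wrappers:
        wrapper.register_buffer(name, torch.tensor(value))

wrappers = [ModuleWrapper(torch.nn.Linear(2, 2))]
set_wrappers_attribute(wrappers, 'if_calculated', False)
```

Because the state lives on the wrapper as a buffer, it travels with the module's `state_dict` like any other PyTorch buffer.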

docs/en_US/Compressor/Framework.md (outdated review thread, resolved)
docs/en_US/Compressor/Framework.md (outdated review thread, resolved)

logger = logging.getLogger('torch pruner')

class AGP_Pruner(Pruner):
Contributor

Suggest AGPPruner.

Contributor Author

Maybe change it later; this PR is not intended to update the interface.


__all__ = ['AGP_Pruner']

logger = logging.getLogger('torch pruner')
Contributor

Suggest using the package name. If 'torch pruner' is actually used by all existing pruners, feel free to keep it.

Contributor Author

OK, let's keep it for now.

src/sdk/pynni/nni/compression/torch/pruning/constants.py (outdated review thread, resolved)
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
Contributor

Could span be 0 in the default setting?
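The expression under review is the AGP (automated gradual pruning) cubic schedule. A self-contained sketch, assuming span is derived from hypothetical start/end epochs (the actual pruner also factors in pruning frequency, and config validation should guarantee span > 0):

```python
def agp_target_sparsity(initial_sparsity, final_sparsity,
                        now_epoch, start_epoch, end_epoch):
    # Cubic ramp from initial_sparsity at start_epoch to final_sparsity
    # at end_epoch; this assert is what guards against span == 0.
    span = end_epoch - start_epoch
    assert span > 0
    progress = (now_epoch - start_epoch) / span
    return (final_sparsity +
            (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3)
```

For example, halfway through a 0 → 0.5 ramp the target is 0.4375, reflecting how the cubic curve front-loads pruning into the early epochs.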

@chicm-ms chicm-ms merged commit 89fa23c into microsoft:master Jun 12, 2020
@chicm-ms chicm-ms deleted the refactor-compressor-framework branch June 12, 2020 04:14
@suiguoxin suiguoxin mentioned this pull request Jun 29, 2020