Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[Model Compression / TensorFlow] Support exporting pruned model (#3487)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Apr 9, 2021
1 parent f0e3c58 commit b7062b5
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 6 deletions.
90 changes: 86 additions & 4 deletions nni/compression/tensorflow/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def _instrument(self, layer):

return layer

def _uninstrument(self, layer):
    """
    Recover the original (un-wrapped) layer or model from an instrumented one.

    Wrappers are peeled off recursively and container models are traversed.
    The ``self._wrappers`` cache is deliberately not cleared here, so the same
    wrapper objects will be recovered by the next ``self._instrument()`` call.
    """
    # The most specific types are tested first: ``LayerWrapper`` derives from
    # ``tf.keras.Model``, and so does ``tf.keras.Sequential``.
    if isinstance(layer, LayerWrapper):
        layer._instrumented = False
        result = self._uninstrument(layer.layer)
    elif isinstance(layer, tf.keras.Sequential):
        result = self._uninstrument_sequential(layer)
    elif isinstance(layer, tf.keras.Model):
        result = self._uninstrument_model(layer)
    else:
        # plain layer: nothing to undo
        result = layer
    return result

def _instrument_sequential(self, seq):
layers = list(seq.layers) # seq.layers is read-only property
need_rebuild = False
Expand All @@ -97,6 +109,16 @@ def _instrument_sequential(self, seq):
need_rebuild = True
return tf.keras.Sequential(layers) if need_rebuild else seq

def _uninstrument_sequential(self, seq):
    """
    Un-instrument every layer of a ``Sequential`` model.

    Returns ``seq`` itself when none of its layers changed; otherwise returns
    a new ``tf.keras.Sequential`` built from the recovered layers.
    """
    wrapped = list(seq.layers)  # snapshot; seq.layers is a property
    restored = [self._uninstrument(item) for item in wrapped]
    unchanged = all(new is old for new, old in zip(restored, wrapped))
    return seq if unchanged else tf.keras.Sequential(restored)

def _instrument_model(self, model):
for key, value in list(model.__dict__.items()): # avoid "dictionary keys changed during iteration"
if isinstance(value, tf.keras.layers.Layer):
Expand All @@ -109,6 +131,17 @@ def _instrument_model(self, model):
value[i] = self._instrument(item)
return model

def _uninstrument_model(self, model):
    """
    Un-instrument, in place, the layers held directly by ``model``.

    Layer-valued attributes and layers stored inside list-valued attributes
    are both handled. The same ``model`` object is returned.
    """
    # Iterate over a snapshot, because attributes may be reassigned below.
    attributes = list(model.__dict__.items())
    for name, attr in attributes:
        if isinstance(attr, tf.keras.layers.Layer):
            restored = self._uninstrument(attr)
            if restored is not attr:
                setattr(model, name, restored)
            continue
        if not isinstance(attr, list):
            continue
        for index, element in enumerate(attr):
            if isinstance(element, tf.keras.layers.Layer):
                attr[index] = self._uninstrument(element)
    return model

def _select_config(self, layer):
# Find the last matching config block for given layer.
Expand All @@ -129,6 +162,17 @@ def _select_config(self, layer):
return last_match


class LayerWrapper(tf.keras.Model):
    """
    Common abstract parent of all layer wrappers.

    It exists so that wrapped layers can be recognized with ``isinstance``;
    every concrete layer wrapper class must derive from it.
    """

    def __init__(self):
        super().__init__()
        # A freshly constructed wrapper starts out flagged as instrumented.
        self._instrumented = True


class Pruner(Compressor):
"""
Base class for pruning algorithms.
Expand Down Expand Up @@ -167,6 +211,43 @@ def compress(self):
self._update_mask()
return self.compressed_model

def export_model(self, model_path, mask_path=None):
    """
    Export pruned model and optionally mask tensors.

    The compressed model is temporarily un-instrumented so that the saved
    model contains only original layers, and is re-instrumented afterwards
    (also when building or saving fails).

    Parameters
    ----------
    model_path : path-like
        The path passed to ``Model.save()``.
        You can use ".h5" extension name to export HDF5 format.
    mask_path : path-like or None
        Export masks to the path when set.
        Because Keras cannot save tensors without a ``Model``,
        this will create a model, set all masks as its weights, and then save that model.
        Masks in saved model will be named by corresponding layer name in compressed model.

    Returns
    -------
    None
    """
    _logger.info('Saving model to %s', model_path)
    # ``_build_input_shape`` is a private attribute (cannot find a public API),
    # so read it defensively instead of crashing if it is absent.
    input_shape = getattr(self.compressed_model, '_build_input_shape', None)
    model = self._uninstrument(self.compressed_model)
    try:
        if input_shape:
            model.build(input_shape)
        model.save(model_path)
    finally:
        # Always restore the wrappers; otherwise a failed build/save would
        # silently leave the compressed model un-instrumented.
        self._instrument(model)

    if mask_path is not None:
        _logger.info('Saving masks to %s', mask_path)
        # can't find "save raw weights" API in tensorflow, so build a simple model
        mask_model = tf.keras.Model()
        for wrapper in self.wrappers:
            setattr(mask_model, wrapper.layer.name, wrapper.masks)
        mask_model.save_weights(mask_path)

    _logger.info('Done')

def calc_masks(self, wrapper, **kwargs):
"""
Abstract method to be overridden by algorithm. End users should ignore it.
Expand Down Expand Up @@ -199,7 +280,7 @@ def _update_mask(self):
wrapper.masks = masks


class PrunerLayerWrapper(tf.keras.Model):
class PrunerLayerWrapper(LayerWrapper):
"""
Instrumented TF layer.
Expand All @@ -210,8 +291,6 @@ class PrunerLayerWrapper(tf.keras.Model):
Attributes
----------
layer_info : LayerInfo
All static information of the original layer.
layer : tf.keras.layers.Layer
The original layer.
config : JSON object
Expand All @@ -233,6 +312,10 @@ def __init__(self, layer, config, pruner):
_logger.info('Layer detected to compress: %s', self.layer.name)

def call(self, *inputs):
    """
    Run the wrapped layer on ``inputs`` after refreshing its weights.

    ``self._update_weights()`` is invoked first, then the call is forwarded
    to ``self.layer`` and its result is returned unchanged.
    """
    self._update_weights()
    outputs = self.layer(*inputs)
    return outputs

def _update_weights(self):
new_weights = []
for weight in self.layer.weights:
mask = self.masks.get(weight.name)
Expand All @@ -243,7 +326,6 @@ def call(self, *inputs):
if new_weights and not hasattr(new_weights[0], 'numpy'):
raise RuntimeError('NNI: Compressed model can only run in eager mode')
self.layer.set_weights([weight.numpy() for weight in new_weights])
return self.layer(*inputs)


# TODO: designed to replace `patch_optimizer`
Expand Down
33 changes: 31 additions & 2 deletions test/ut/sdk/test_compressor_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path
import tempfile
import unittest

import numpy as np
Expand All @@ -27,6 +29,9 @@
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10 = tf.constant([[1.0] * 10])

# This tensor is used as input of CNN models
image_tensor = tf.zeros([1, 10, 10, 3])


@unittest.skipIf(tf.__version__[0] != '2', 'Skip TF 1.x setup')
class TfCompressorTestCase(unittest.TestCase):
Expand All @@ -42,13 +47,37 @@ def _test_layer_detection_on_model(self, model):
layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers)
assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types

def test_level_pruner(self):
def test_level_pruner_and_export_correctness(self):
    """Check level pruner output before export, after export, and on the reloaded model."""
    # prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
    model = build_naive_model()
    pruner = pruners['level'](model)
    model = pruner.compress()

    x = model(tensor1x10)
    assert x.numpy() == 94.5

    # Export into a private directory that is removed afterwards, instead of
    # fixed names in the shared system temp dir (which collide between runs
    # and are never cleaned up).
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        pruner.export_model(temp_dir / 'model', temp_dir / 'mask')

        # because exporting will uninstrument and re-instrument the model,
        # we must test the model again
        x = model(tensor1x10)
        assert x.numpy() == 94.5

        # load and test exported model
        exported_model = tf.keras.models.load_model(temp_dir / 'model')
        x = exported_model(tensor1x10)
        assert x.numpy() == 94.5

def test_export_not_crash(self):
    """Smoke test: exporting CNN and sequential models must not raise."""
    for model in [CnnModel(), build_sequential_model()]:
        pruner = pruners['level'](model)
        model = pruner.compress()
        # cannot use model.build(image_tensor.shape) here
        # it fails even without compression
        # seems TF's bug, not ours
        model(image_tensor)
        # Keep the TemporaryDirectory alive for the duration of the export so
        # the directory is not deleted early and is cleaned up afterwards.
        with tempfile.TemporaryDirectory() as temp_dir:
            pruner.export_model(Path(temp_dir) / 'model')

try:
from tensorflow.keras import Model, Sequential
Expand Down

0 comments on commit b7062b5

Please sign in to comment.