diff --git a/docs/en_US/Compressor/Overview.md b/docs/en_US/Compressor/Overview.md
index 5fc8e45c5d..b078d748a6 100644
--- a/docs/en_US/Compressor/Overview.md
+++ b/docs/en_US/Compressor/Overview.md
@@ -180,12 +180,54 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
     def quantize_weight(self, weight, config, **kwargs):
         """
-        weight is the target weight tensor
-        config is the selected dict object in config_list for this layer
-        kwargs contains op, op_types, and op_name
-        design your quantizer and return new weight
+        Quantizers should overload this method to quantize weight tensors.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
+
+        # Put your code to generate `new_weight` here
+
+        return new_weight
+
+    def quantize_output(self, output, config, **kwargs):
+        """
+        Quantizers should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+
+        # Put your code to generate `new_output` here
+
+        return new_output
+
+    def quantize_input(self, *inputs, config, **kwargs):
+        """
+        Quantizers should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that need to be quantized
+        config : dict
+            the configuration for input quantization
+        """
+
+        # Put your code to generate `new_input` here
+
+        return new_input

     # note for pytorch version, there is no sess in input arguments
     def update_epoch(self, epoch_num, sess):
@@ -200,8 +242,6 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
     pass
 ```

-__[TODO]__ Will add another member function `quantize_layer_output`, as some quantization algorithms also quantize layers' output.
-
 ### Usage of user customized compression algorithm

 __[TODO]__ ...
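To ground the customization guide above, here is a minimal sketch of a concrete quantizer that only overrides `quantize_weight`. It is written against the PyTorch flavor of the same API, since the SDK diff below touches `nni.compression.torch`; the class name `Symmetric8BitQuantizer`, its rounding scheme, and the assumption that `Quantizer` is importable from `nni.compression.torch` are all illustrative rather than part of this change.

```python
import torch
from nni.compression.torch import Quantizer

class Symmetric8BitQuantizer(Quantizer):
    """Illustrative subclass: symmetric 8-bit fake-quantization of weights."""

    def quantize_weight(self, weight, config, **kwargs):
        # kwargs carries op, op_type and op_name, which this sketch ignores
        scale = weight.abs().max() / 127  # map the largest magnitude to the int8 max
        if scale == 0:
            return weight  # an all-zero tensor has nothing to quantize
        # round to the nearest representable level, then scale back to float
        # so the instrumented forward keeps working with an ordinary tensor
        return torch.round(weight / scale) * scale
```

A hook that is not overridden, such as `quantize_output` here, raises `NotImplementedError` in the base class, so a config entry should only list the quant types the subclass actually implements.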
diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py
index 6a60a29cf0..cc92938386 100644
--- a/src/sdk/pynni/nni/compression/torch/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/compressor.py
@@ -32,7 +32,23 @@ def __init__(self, model, config_list):
         """
         self.bound_model = model
         self.config_list = config_list
-        self.modules_to_compress = []
+        self.modules_to_compress = None
+
+    def detect_modules_to_compress(self):
+        """
+        detect all modules that should be compressed, and save the result
+        in `self.modules_to_compress`.
+        """
+        if self.modules_to_compress is None:
+            self.modules_to_compress = []
+            for name, module in self.bound_model.named_modules():
+                layer = LayerInfo(name, module)
+                config = self.select_config(layer)
+                if config is not None:
+                    self.modules_to_compress.append((layer, config))
+        return self.modules_to_compress
+
     def compress(self):
         """
@@ -41,12 +57,9 @@ def compress(self):
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers
         """
-        for name, module in self.bound_model.named_modules():
-            layer = LayerInfo(name, module)
-            config = self.select_config(layer)
-            if config is not None:
-                self._instrument_layer(layer, config)
-                self.modules_to_compress.append((layer, config))
+        modules_to_compress = self.detect_modules_to_compress()
+        for layer, config in modules_to_compress:
+            self._instrument_layer(layer, config)
         return self.bound_model

     def get_modules_to_compress(self):
@@ -55,7 +68,7 @@

         Returns
         -------
-        self.modules_to_compress : list
+        list
             a list of the layers, each of which is a tuple (`layer`, `config`),
             `layer` is `LayerInfo`, `config` is a `dict`
         """
@@ -72,7 +85,7 @@ def select_config(self, layer):

         Returns
         -------
-        ret : config or None
+        config or None
            the retrieved configuration for this layer, if None, this layer should
            not be compressed
         """
@@ -238,26 +251,87 @@ class Quantizer(Compressor):
     """

     def quantize_weight(self, weight, config, op, op_type, op_name):
-        """user should know where dequantize goes and implement it in quantize method
-        we now do not provide dequantize method
+        """
+        Quantizers should overload this method to quantize weight.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
         raise NotImplementedError("Quantizer must overload quantize_weight()")

+    def quantize_output(self, output, config, op, op_type, op_name):
+        """
+        Quantizers should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_output()")
+
+    def quantize_input(self, *inputs, config, op, op_type, op_name):
+        """
+        Quantizers should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that need to be quantized
+        config : dict
+            the configuration for input quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_input()")
+
     def _instrument_layer(self, layer, config):
+        """
+        Create a wrapper forward function to replace the original one.
+
+        Parameters
+        ----------
+        layer : LayerInfo
+            the layer to instrument with quantization operations
+        config : dict
+            the configuration for quantization
+        """
         assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
+        assert "quant_types" in config, 'must provide quant_types in config'
+        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
+
+        if 'weight' in config["quant_types"]:
+            if not _check_weight(layer.module):
+                _logger.warning('Module %s does not have parameter "weight"', layer.name)

         layer._forward = layer.module.forward

         def new_forward(*inputs):
-            weight = layer.module.weight.data
-            new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-            layer.module.weight.data = new_weight
-            return layer._forward(*inputs)
+            if 'input' in config["quant_types"]:
+                inputs = self.quantize_input(inputs, config=config, op=layer.module, op_type=layer.type, op_name=layer.name)
+
+            if 'weight' in config["quant_types"] and _check_weight(layer.module):
+                weight = layer.module.weight.data
+                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+                layer.module.weight.data = new_weight
+                result = layer._forward(*inputs)
+                layer.module.weight.data = weight
+            else:
+                result = layer._forward(*inputs)

-        layer.module.forward = new_forward
+            if 'output' in config["quant_types"]:
+                result = self.quantize_output(result, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+            return result
+
+        layer.module.forward = new_forward

 def _check_weight(module):
     try:
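The core of the new `_instrument_layer` is a wrap-and-restore pattern: the wrapper optionally quantizes the inputs, temporarily swaps a quantized weight in, runs the original forward, restores the full-precision weight, and optionally quantizes the output. Below is a self-contained sketch of that pattern on a plain `nn.Linear`, with a hypothetical `fake_quant` standing in for `quantize_weight`; the names are illustrative, not part of this change.

```python
import torch
from torch import nn

def fake_quant(t, bits=8):
    # symmetric rounding stand-in for quantize_weight
    scale = t.abs().max() / (2 ** (bits - 1) - 1)
    return torch.round(t / scale) * scale if scale > 0 else t

layer = nn.Linear(4, 2)
original_forward = layer.forward  # plays the role of layer._forward

def new_forward(*inputs):
    saved = layer.weight.data              # keep the full-precision copy
    layer.weight.data = fake_quant(saved)  # forward sees quantized weights
    result = original_forward(*inputs)
    layer.weight.data = saved              # restore, as the wrapper above does
    return result

layer.forward = new_forward  # same monkey-patch as layer.module.forward = new_forward
out = layer(torch.randn(1, 4))  # runs with quantized weights transparently
```

Restoring `weight.data` after the call is what keeps the full-precision tensor as the parameter the optimizer updates, while every forward pass computes with the quantized values.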
diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py
index ca8b628640..e4eb0bbe46 100644
--- a/src/sdk/pynni/tests/test_compressor.py
+++ b/src/sdk/pynni/tests/test_compressor.py
@@ -114,7 +114,14 @@ def test_torch_pruner(self):

     def test_torch_quantizer(self):
         model = TorchMnist()
-        torch_compressor.NaiveQuantizer(model, [{'op_types': ['default']}]).compress()
+        configure_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': {
+                'weight': 8,
+            },
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        torch_compressor.NaiveQuantizer(model, configure_list).compress()

 if __name__ == '__main__':
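The updated test doubles as a usage example for the new config schema. Here is a sketch of the same flow outside the test harness; the `nn.Sequential` model is a hypothetical stand-in for the `TorchMnist` fixture, and only `'weight'` is listed in `quant_types` because the test exercises `NaiveQuantizer` for weight quantization only.

```python
import torch
import nni.compression.torch as torch_compressor

# Hypothetical stand-in for the TorchMnist fixture used by the test.
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

configure_list = [{
    'quant_types': ['weight'],         # which of the three hooks to enable
    'quant_bits': {'weight': 8},       # bit width per quantization type
    'op_types': ['Conv2d', 'Linear'],  # layer types this entry applies to
}]

quantizer = torch_compressor.NaiveQuantizer(model, configure_list)
quantizer.compress()

# After compress(), the layers matched by `op_types` can be inspected;
# this reflects the detection result cached by detect_modules_to_compress().
for layer, config in quantizer.get_modules_to_compress():
    print(layer.name, layer.type, config['quant_types'])
```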