This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix quant grad function calculation error #3160

Merged 6 commits on Dec 8, 2020
Changes from 5 commits
70 changes: 62 additions & 8 deletions nni/compression/pytorch/compressor.py
@@ -589,8 +589,46 @@ class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@classmethod
def _quantize(cls, x, scale, zero_point):
"""
Reference function for quantizing x -- non-clamped.
Parameters
----------
x : Tensor
tensor to be quantized
scale : Tensor
scale for quantizing x
zero_point : Tensor
zero_point for quantizing x
Returns
-------
tensor
quantized x, without clamping
"""
return ((x / scale) + zero_point).round()
@classmethod
def get_bits_length(cls, config, quant_type):
"""
Get bit for quantize config
Parameters
----------
config : Dict
the configuration for quantization
quant_type : str
quant type
Returns
-------
int
n-bits for quantization configuration
"""
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
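A minimal usage sketch of get_bits_length: the "quant_bits" entry of a layer's quantization config can be a single int applied to every quant type, or a dict keyed by quant type. The config dicts below are hypothetical; only the "quant_bits" key is read here.
# "quant_bits" as one bit width for all quant types, or as per-type bit widths
config_uniform = {'quant_bits': 8}
config_per_type = {'quant_bits': {'weight': 8, 'output': 6}}
QuantGrad.get_bits_length(config_uniform, 'weight')   # 8
QuantGrad.get_bits_length(config_per_type, 'output')  # 6
QuantGrad.get_bits_length(config_per_type, 'input')   # None -- dict.get falls back to None for a missing type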

@staticmethod
- def quant_backward(tensor, grad_output, quant_type):
+ def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
"""
This method should be overridden by the subclass to provide a customized backward function;
the default implementation is the Straight-Through Estimator
@@ -600,32 +638,48 @@ def quant_backward(tensor, grad_output, quant_type):
input of quantization operation
grad_output : Tensor
gradient of the output of quantization operation
- quant_type : QuantType
- the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
- you can define different behavior for different types.
+ scale : Tensor
+ scale for quantizing tensor
+ zero_point : Tensor
+ zero_point for quantizing tensor
+ qmin : Tensor
+ quant_min for quantizing tensor
+ qmax : Tensor
+ quant_max for quantizing tensor
Returns
-------
tensor
gradient of the input of quantization operation
"""
Contributor: The comment of this function needs to be updated or deleted.

Contributor Author (@eedalong, Dec 7, 2020): Oh yeah, forgot to update the comment.

tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
Contributor: No need to calculate tensor_q here; return a mask from wrapper.quantizer.quantize_input, for example.

Contributor Author (@eedalong, Dec 7, 2020): No, we need tensor_q because we need the unclamped quantized data. wrapper.quantizer.quantize_input is clamped to [qmin, qmax], so the mask would be ones everywhere if we just used wrapper.quantizer.quantize_input.

Contributor: I mean you can do this inside wrapper.quantizer.quantize_input and return the mask, which can then be used in backward.

Contributor: That way you don't do the quantization twice, and the saved tensor is four times smaller.

Contributor Author (@eedalong, Dec 7, 2020): Yeah, you are right. I did write the code like that at first. To do that we would need to modify _quantize() to return the mask, but I gave it up because I don't think it's a good design; having _quantize() just return the quantized tensor is a better design, I think.

Contributor: True, but keeping this design will actually make many quantization algorithms with special backward functions harder and less efficient to implement. The real problem is that the quant backward is inside the autograd function while the forward is inside the quantizer module. I think it would be much better if the forward and backward logic were moved into the autograd function together, with the quantizer only maintaining some state.

Contributor: What I described is actually what the PyTorch implementation does.

Contributor Author (@eedalong, Dec 7, 2020): Yes, you are right. And this calls for a HUGE design change in NNI's QAT implementation.

Contributor Author: For now, I think I will just push the correct STE code; what you mentioned is left for the next PR.

mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output
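A small numeric sketch of the clipped straight-through estimator above, assuming a hypothetical 8-bit setting (scale, zero_point and the input values are made up): elements whose unclamped quantized value falls outside [qmin, qmax] receive zero gradient, everything else passes the gradient through unchanged.
import torch

scale, zero_point = torch.tensor(0.5), torch.tensor(0.)
qmin, qmax = torch.tensor(0.), torch.tensor(255.)      # 8-bit range: 0 .. (1 << 8) - 1

x = torch.tensor([-1.0, 3.0, 200.0])                   # quantizes (unclamped) to -2, 6, 400
grad_output = torch.ones_like(x)

x_q = QuantGrad._quantize(x, scale, zero_point)        # tensor([ -2.,   6., 400.])
mask = (x_q < qmin) | (x_q > qmax)                     # tensor([ True, False,  True])
grad_output[mask] = 0                                  # gradient flows only for in-range values
# grad_output is now tensor([0., 1., 0.])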

@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
- ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
if quant_type == QuantType.QUANT_INPUT:
Contributor (@linbinskn, Dec 7, 2020): Suggest changing the definition of QuantType to 'input', 'weight' and 'output', so we don't need to modify quant_type in the if-else clause.

Contributor Author: Um, I don't think that is relevant to this PR. Maybe you should update all the quant type definitions in another PR.

Contributor: Could you help update lines 583-585 to change the values from 0, 1, 2 to 'input', 'weight', 'output'? We have checked the code; it is safe to make this modification :)

Contributor Author: Ok~ 👌

- return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
+ quant_type = 'input'
+ output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
- return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+ quant_type = 'weight'
+ output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
- return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
+ quant_type = 'output'
+ output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")

bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
Contributor: You only need to save the mask here if you want to clip the gradient; it is a byte tensor and therefore four times smaller than saving a float tensor.

Contributor: I think the reason why we used to use a pass-through estimator is that we always set rmin and rmax to min(tensor) and max(tensor). But your proposal is correct when rmin is larger than min(tensor) and rmax is smaller than max(tensor).

Contributor Author: Yes, for weight, rmin and rmax are always the min and max of the tensor, but think about output: because we use EMA to update rmin and rmax, not every item of the output is in [rmin, rmax].

Contributor: Yes, your proposal is necessary for quantizing output, thanks for submitting this PR!

Contributor: @eedalong @Cjkkkk, thanks for the discussion. The design of this PR looks reasonable; let's keep it in this PR. We will do some refactoring later to put the forward logic and backward logic together within the same class, which should make it much easier for users to customize new quantization algorithms.

return output

@classmethod
def backward(cls, ctx, grad_output):
- tensor, quant_type = ctx.saved_variables
- output = cls.quant_backward(tensor, grad_output, quant_type)
+ tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
+ output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
return output, None, None, None
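As a rough sketch of the alternative raised in the review comments -- saving only a boolean clip mask in forward so that backward neither re-quantizes nor keeps the float tensor -- one could fold the fake-quantization into the autograd function itself. This is illustrative only (the class name and argument layout are made up) and is not what this PR implements.
import torch

class ClipMaskQuantGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, scale, zero_point, qmin, qmax):
        # qmin/qmax are plain Python numbers in this sketch
        tensor_q = ((tensor / scale) + zero_point).round()              # unclamped, as in _quantize
        ctx.save_for_backward((tensor_q < qmin) | (tensor_q > qmax))    # save only the bool mask
        return (tensor_q.clamp(qmin, qmax) - zero_point) * scale        # fake-quantized output

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        grad_output = grad_output.clone()
        grad_output[mask] = 0                                           # straight-through estimator with clipping
        return grad_output, None, None, None, None

# usage: y = ClipMaskQuantGrad.apply(x, 0.5, 0.0, 0.0, 255.0)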

def _check_weight(module):