Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix quant grad function calculation error #3160

Merged
merged 6 commits into from
Dec 8, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions nni/compression/pytorch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,26 @@ class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@classmethod
def _quantize(cls, x, scale, zp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest changing zp to zero_point to keep consensus.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, changed already

r"""Reference function for quantizing x -- non-clamped.
"""
return ((x / scale) + zp).round()
@classmethod
def get_bits_length(cls, config, quant_type):
"""
get bit for quantize config
:param config:
:param quant_type:
:return:
Copy link
Contributor

@QuanluZhang QuanluZhang Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update docstring format to be consistent with others

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool

"""
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)

@staticmethod
def quant_backward(tensor, grad_output, quant_type):
def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
"""
This method should be overrided by subclass to provide customized backward function,
default implementation is Straight-Through Estimator
Expand All @@ -608,24 +626,34 @@ def quant_backward(tensor, grad_output, quant_type):
tensor
gradient of the input of quantization operation
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment of this function need to be updated or delete.

Copy link
Contributor Author

@eedalong eedalong Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah, forgot to update the comment

tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to calculate tensor_q here, return a mask from wrapper.quantizer.quantize_input for example.

Copy link
Contributor Author

@eedalong eedalong Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we need tensor_q because we need unclampped quantized data, wrapper.quantizer.quantize_input is clampped to [qmin, qmax], so the mask will always be ones everywhere if we just use wrapper.quantizer.quantize_input

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can do this inside the wrapper.quantizer.quantize_input and return the mask which can be used in backward.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Therefore you don't do quantization twice and tensor saved will be smaller by 4 times.

Copy link
Contributor Author

@eedalong eedalong Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, u are right. I did write the code like that at first.. For doing this ,we need to modify _quantize() to return mask. But i gave it up because i dont think it's a good design. Just let _quantize() to return quantized tensor is a better design i think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but keeping this design will actually make many quantization algorithms with special backward function harder and less efficient to implement. The actual problem is quant backward is inside the autograd while the forward is inside the quantizer module. I think it will be much better if forward and backward logic is moved into autograd function together and quantizer only maintain some states.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I said is actually what Pytorch implementation does.

Copy link
Contributor Author

@eedalong eedalong Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes u r right. And this calls for a HUGE design change on NNI's QAT implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for now, i think i will just push the correct STE code and what u mentioned is left for next pr~

mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output

@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
if quant_type == QuantType.QUANT_INPUT:
Copy link
Contributor

@linbinskn linbinskn Dec 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest changing defination of QuantType to 'input', 'weight' and 'output' and we don't need to modify quant_type in the if-else clause.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Um.., I dont think it is relavant with this PR. May be you need to update all definition for quant type in another pr

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you help update line 583-585 to update the values from 0, 1, 2 to 'input', 'weight', 'output'? we have checked the code, it is safe to make this modification :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok~ 👌

return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
quant_type = 'input'
output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
quant_type = 'weight'
output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
quant_type = 'output'
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")

bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you only need save mask here if you want to do clip gradient, which is a bytetensor and therefore 4 times smaller than saving a floattensor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason why we used to use a pass-through-estimator is that we always set rmin, rmax to be the min(tensor), max(tensor). But your propose is correct if rmin is larger than the min(tensor) and rmax is smaller than the max(tensor).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for weight, rmin and rmax is always the min and max of the tensor, but think about output. Because we use ema for updating rmin and rmax, so not all item in output is in [rmin, rmax]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, your propose is necessary in quantizing ouput, thanks for submitting this PR!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eedalong @Cjkkkk , thanks for the discussion. The design of this pr looks reasonable. let's keep it in this pr. We will do some refactor later to put forward logic and backward logic together within the same class, which should be much easier for users to customize new quantization algorithms.

return output

@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
return output, None, None, None

def _check_weight(module):
Expand Down