fix quant grad function calculation error #3160
Conversation
@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
    tensor
        gradient of the input of quantization operation
    """
    return grad_output
    tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
    mask = (tensor_q >= qmin) * (tensor_q <= qmax)
mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output
good point
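(For reference, a minimal sketch of the clipped straight-through estimator this thread converges on, written as a standalone helper. The affine scheme q = zero_point + round(x / scale) and the explicit scale/zero_point/qmin/qmax arguments are assumptions for illustration, not NNI's actual signature.)

```python
import torch

def clipped_ste_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
    # Re-quantize WITHOUT clamping so out-of-range values are still visible.
    tensor_q = zero_point + torch.round(tensor / scale)
    # Zero the gradient wherever the forward pass clamped the quantized value.
    mask = (tensor_q < qmin) | (tensor_q > qmax)
    grad = grad_output.clone()   # avoid modifying the incoming gradient in place
    grad[mask] = 0
    return grad
```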
    else:
        raise ValueError("unrecognized QuantType.")

    bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
    qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
    ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
You only need to save the mask here if you want to clip the gradient; the mask is a ByteTensor and is therefore 4 times smaller than saving a FloatTensor.
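(A rough sketch of that alternative, assuming the mask is computed in forward and stashed for backward; the class and argument names are hypothetical, not the PR's code.)

```python
import torch

class ClippedQuantSketch(torch.autograd.Function):
    """Illustrative only: save a bool mask for backward instead of the float tensor."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = zero_point + torch.round(x / scale)
        mask = (q >= qmin) & (q <= qmax)          # bool tensor: one byte per element
        ctx.save_for_backward(mask)
        return scale * (q.clamp(qmin, qmax) - zero_point)   # fake-quantized output

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        # Pass the gradient through only where quantization did not clamp.
        return grad_output * mask, None, None, None, None
```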
I think the reason we used a pass-through estimator before is that we always set rmin and rmax to min(tensor) and max(tensor). But your proposal is correct when rmin is larger than min(tensor) and rmax is smaller than max(tensor).
Yes, for weights rmin and rmax are always the min and max of the tensor, but think about the output: because we use EMA to update rmin and rmax, not every item in the output lies in [rmin, rmax].
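(To illustrate the point, a toy EMA running-range update; the decay value and names are made up, not NNI's code.)

```python
import torch

decay = 0.99
rmin, rmax = torch.tensor(0.0), torch.tensor(1.0)   # running estimates from earlier batches
output = torch.randn(8)                             # current batch of activations

rmin = decay * rmin + (1 - decay) * output.min()
rmax = decay * rmax + (1 - decay) * output.max()

# The running range lags the current batch, so some entries of `output`
# can fall outside [rmin, rmax] and get clamped during quantization.
print(((output < rmin) | (output > rmax)).any())
```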
Yes, your proposal is necessary when quantizing the output. Thanks for submitting this PR!
@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
    tensor
        gradient of the input of quantization operation
    """
    return grad_output
    tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
No need to calculate tensor_q here; you could, for example, return a mask from wrapper.quantizer.quantize_input.
No, we need tensor_q because we need the unclamped quantized data. wrapper.quantizer.quantize_input is clamped to [qmin, qmax], so the mask would be all ones everywhere if we just used wrapper.quantizer.quantize_input.
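(A tiny check of this point: once the quantized values are clamped, the in-range mask is trivially all ones.)

```python
import torch

qmin, qmax = 0.0, 255.0
q_unclamped = torch.tensor([-3.0, 10.0, 300.0])
q_clamped = q_unclamped.clamp(qmin, qmax)

print((q_clamped >= qmin) & (q_clamped <= qmax))      # tensor([True, True, True])   -- useless mask
print((q_unclamped >= qmin) & (q_unclamped <= qmax))  # tensor([False, True, False]) -- the mask we actually need
```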
I mean you can do this inside wrapper.quantizer.quantize_input and return the mask, which can then be used in backward.
That way you don't do the quantization twice, and the saved tensor is 4 times smaller.
Yeah, you are right. I did write the code like that at first. To do this, we would need to modify _quantize() to return the mask, but I gave that up because I don't think it's a good design. Having _quantize() return only the quantized tensor is a better design, I think.
True, but keeping this design will actually make many quantization algorithms with special backward functions harder and less efficient to implement. The real problem is that the quant backward lives inside the autograd function while the forward lives inside the quantizer module. I think it would be much better if the forward and backward logic were moved into the autograd function together, with the quantizer only maintaining state.
What I said is actually what the PyTorch implementation does.
Yes, you are right, and that calls for a huge design change in NNI's QAT implementation.
For now, I think I will just push the correct STE code and leave what you mentioned for the next PR. Check PyTorch's implementation, guys:
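(For what it's worth, recent PyTorch releases expose a fake-quantize op whose backward is this same clipped STE; a small check, assuming torch.fake_quantize_per_tensor_affine is available in your version:)

```python
import torch

x = torch.randn(4, requires_grad=True)
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255

y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)
y.sum().backward()

# Gradient is 1 where round(x / scale) + zero_point lies in [quant_min, quant_max], 0 elsewhere.
print(x.grad)
```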
@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
    tensor
        gradient of the input of quantization operation
    """
The docstring of this function needs to be updated or deleted.
Oh yeah, I forgot to update the comment.
    return grad_output

    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
        if quant_type == QuantType.QUANT_INPUT:
Suggest changing the definitions in QuantType to 'input', 'weight' and 'output', so we don't need to modify quant_type in the if-else clause.
Um, I don't think that is relevant to this PR. Maybe all the QuantType definitions should be updated in another PR.
Could you help update lines 583-585 to change the values from 0, 1, 2 to 'input', 'weight', 'output'? We have checked the code, and it is safe to make this modification :)
Ok~ 👌
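(The suggested change might look roughly like this; the surrounding class layout is an assumption, not copied from NNI.)

```python
class QuantType:
    QUANT_INPUT = 'input'
    QUANT_WEIGHT = 'weight'
    QUANT_OUTPUT = 'output'
```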
    get bit for quantize config
    :param config:
    :param quant_type:
    :return:
Please update the docstring format to be consistent with the others.
cool
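(A sketch of how the docstring could be rewritten in the NumPy style used elsewhere in the file; the parameter descriptions are guesses from context, not the final wording.)

```python
def get_bits_length(config, quant_type):
    """
    Get the quantization bit width for `quant_type` from the layer's config.

    Parameters
    ----------
    config : dict
        quantization configuration of the wrapped layer
    quant_type : str
        which tensor is being quantized: 'input', 'weight' or 'output'

    Returns
    -------
    int
        number of bits used for quantization
    """
    ...   # body unchanged
```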
@@ -589,8 +589,26 @@ class QuantGrad(torch.autograd.Function):
    """
    Base class for overriding backward function of quantization operation.
    """
    @classmethod
    def _quantize(cls, x, scale, zp):
Suggest changing zp to zero_point to keep it consistent.
yeah, changed already
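(_quantize presumably performs plain affine quantization without clamping; a guess at its body from the surrounding discussion, not the PR's exact code.)

```python
import torch

class QuantGradSketch:
    """Illustrative stand-in for QuantGrad, not NNI's actual class."""

    @classmethod
    def _quantize(cls, x, scale, zero_point):
        # Affine quantization WITHOUT clamping, so values outside [qmin, qmax]
        # stay visible to the backward pass.
        return zero_point + torch.round(x / scale)
```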
#3159
Quant grad calculation is not correct in NNI's current implementation; according to Google's whitepaper, it should be:
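(The formula referenced here appears to be the clipped straight-through estimator; reconstructed from the diff above, with s the scale, z the zero point and y the fake-quantized output:)

$$
\frac{\partial L}{\partial x} =
\begin{cases}
\dfrac{\partial L}{\partial y}, & q_{\min} \le z + \operatorname{round}\!\left(\frac{x}{s}\right) \le q_{\max},\\
0, & \text{otherwise,}
\end{cases}
$$

i.e. the gradient is passed through only where the quantized value was not clipped to [q_min, q_max].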