fix quant grad function calculation error #3160
Changes from 5 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -589,8 +589,46 @@ class QuantGrad(torch.autograd.Function): | |
""" | ||
Base class for overriding backward function of quantization operation. | ||
""" | ||
@classmethod | ||
def _quantize(cls, x, scale, zero_point): | ||
""" | ||
Reference function for quantizing x -- non-clamped. | ||
Parameters | ||
---------- | ||
x : Tensor | ||
tensor to be quantized | ||
scale : Tensor | ||
scale for quantizing x | ||
zero_point : Tensor | ||
zero_point for quantizing x | ||
Returns | ||
------- | ||
tensor | ||
quantized x without clamped | ||
""" | ||
return ((x / scale) + zero_point).round() | ||
@classmethod | ||
def get_bits_length(cls, config, quant_type): | ||
""" | ||
Get bit for quantize config | ||
Parameters | ||
---------- | ||
config : Dict | ||
the configuration for quantization | ||
quant_type : str | ||
quant type | ||
Returns | ||
------- | ||
int | ||
n-bits for quantization configuration | ||
""" | ||
if isinstance(config["quant_bits"], int): | ||
return config["quant_bits"] | ||
else: | ||
return config["quant_bits"].get(quant_type) | ||
|
||
@staticmethod | ||
def quant_backward(tensor, grad_output, quant_type): | ||
def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax): | ||
""" | ||
This method should be overridden by a subclass to provide a customized backward function; | ||
the default implementation is the Straight-Through Estimator | ||
|
@@ -600,32 +638,48 @@ def quant_backward(tensor, grad_output, quant_type): | |
input of quantization operation | ||
grad_output : Tensor | ||
gradient of the output of quantization operation | ||
quant_type : QuantType | ||
scale : Tensor | ||
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, | ||
you can define different behavior for different types. | ||
zero_point : Tensor | ||
zero_point for quantizing tensor | ||
qmin : Tensor | ||
quant_min for quantizing tensor | ||
qmax : Tensor | ||
quant_max for quantizing tensor | ||
Returns | ||
------- | ||
tensor | ||
gradient of the input of quantization operation | ||
""" | ||
tensor_q = QuantGrad._quantize(tensor, scale, zero_point) | ||
No need to calculate
No, we need tensor_q because we need the unclamped quantized data. wrapper.quantizer.quantize_input is clamped to [qmin, qmax], so the mask will always be ones everywhere if we just use wrapper.quantizer.quantize_input.
I mean you can do this inside the
Therefore you don't do quantization twice, and the saved tensor will be 4 times smaller.
Yeah, you are right. I did write the code like that at first. To do this, we need to modify _quantize() to return the mask. But I gave it up because I don't think it's a good design. Having _quantize() return just the quantized tensor is a better design, I think.
True, but keeping this design will actually make many quantization algorithms with a special backward function harder and less efficient to implement. The actual problem is that the quant backward is inside the autograd function while the forward is inside the quantizer module. I think it will be much better if the forward and backward logic is moved into the autograd function together and the quantizer only maintains some state.
What I said is actually what the PyTorch implementation does.
Yes, you are right. And this calls for a HUGE design change in NNI's QAT implementation.
For now, I think I will just push the correct STE code, and what you mentioned is left for the next PR~ |
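A small sketch of the point debated in this thread, written in plain Python rather than PyTorch so it runs without extra dependencies. The helper names (`quantize_unclamped`, `clamp`, `out_of_range_mask`) and the sample values are assumptions made for illustration, not NNI code. With the `(tensor_q < qmin) | (tensor_q > qmax)` form used in this diff, a mask built from already clamped values never flags anything, while the unclamped values do. The last lines also illustrate the size argument: a one-byte mask element is a quarter the size of a 32-bit float element.

```python
from array import array


def quantize_unclamped(x, scale, zero_point):
    """Reference quantization without clamping (mirrors _quantize in the diff)."""
    return round(x / scale + zero_point)


def clamp(q, qmin, qmax):
    """Clamp a quantized value to [qmin, qmax]."""
    return max(qmin, min(qmax, q))


def out_of_range_mask(quantized, qmin, qmax):
    """True where the quantized value lies outside [qmin, qmax]."""
    return [(q < qmin) or (q > qmax) for q in quantized]


# Illustrative values only: 8-bit range, scale 0.5, zero_point 0.
values = [-1.0, 0.0, 10.0, 127.5, 200.0]
scale, zero_point, qmin, qmax = 0.5, 0, 0, 255

unclamped = [quantize_unclamped(v, scale, zero_point) for v in values]
clamped = [clamp(q, qmin, qmax) for q in unclamped]

mask_unclamped = out_of_range_mask(unclamped, qmin, qmax)
mask_from_clamped = out_of_range_mask(clamped, qmin, qmax)

# Straight-through estimator with clipping: zero the gradient where the
# unclamped quantized value left the representable range.
grad_output = [1.0] * len(values)
clipped_grads = [0.0 if m else g for g, m in zip(grad_output, mask_unclamped)]

# Per-element storage: 32-bit float versus one unsigned byte.
float_itemsize = array("f").itemsize
byte_itemsize = array("B").itemsize
```

In this sketch only the first and last inputs fall outside the 8-bit range, so only their gradients are zeroed; the mask derived from the clamped values is all False and would leave every gradient untouched.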
||
mask = (tensor_q < qmin) | (tensor_q > qmax) | ||
grad_output[mask] = 0 | ||
return grad_output | ||
|
||
@staticmethod | ||
def forward(ctx, tensor, quant_type, wrapper, **kwargs): | ||
ctx.save_for_backward(tensor, torch.Tensor([quant_type])) | ||
if quant_type == QuantType.QUANT_INPUT: | ||
Suggest changing the definition of QuantType to 'input', 'weight' and 'output', so we don't need to modify quant_type in the if-else clause.
Um, I don't think it is relevant to this PR. Maybe you need to update all definitions for quant type in another PR.
Could you help update lines 583-585 to change the values from 0, 1, 2 to 'input', 'weight', 'output'? We have checked the code; it is safe to make this modification :)
Ok~ 👌 |
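A minimal sketch of what the reviewer's suggestion could look like, assuming `QuantType` is a plain class of constants. The class body and the example configs below are hypothetical, not NNI's actual definitions. With string values, the quant type can be used directly as the key into a per-type `quant_bits` dict, so no remapping is needed before the lookup.

```python
class QuantType:
    """Hypothetical string-valued quant types (the thread says they were 0, 1, 2)."""
    QUANT_INPUT = "input"
    QUANT_WEIGHT = "weight"
    QUANT_OUTPUT = "output"


def get_bits_length(config, quant_type):
    """Return the bit width for quant_type, following get_bits_length in the diff."""
    bits = config["quant_bits"]
    if isinstance(bits, int):
        # A single int applies to every quant type.
        return bits
    # Otherwise quant_bits is a dict keyed by quant type; a missing key gives None.
    return bits.get(quant_type)


# Illustrative configs only.
shared_config = {"quant_bits": 8}
per_type_config = {"quant_bits": {"weight": 8, "output": 4}}

weight_bits = get_bits_length(shared_config, QuantType.QUANT_WEIGHT)
output_bits = get_bits_length(per_type_config, QuantType.QUANT_OUTPUT)
input_bits = get_bits_length(per_type_config, QuantType.QUANT_INPUT)

# Unsigned quantization range for the output bit width, as computed in forward().
qmin, qmax = 0, (1 << output_bits) - 1
```

Note that a quant type missing from the dict yields `None`, which the `(1 << bits) - 1` expression in `forward` would not accept, so a caller would need to guard against that case.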
||
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) | ||
quant_type = 'input' | ||
output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) | ||
elif quant_type == QuantType.QUANT_WEIGHT: | ||
return wrapper.quantizer.quantize_weight(wrapper, **kwargs) | ||
quant_type = 'weight' | ||
output = wrapper.quantizer.quantize_weight(wrapper, **kwargs) | ||
elif quant_type == QuantType.QUANT_OUTPUT: | ||
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) | ||
quant_type = 'output' | ||
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) | ||
else: | ||
raise ValueError("unrecognized QuantType.") | ||
|
||
bits = QuantGrad.get_bits_length(wrapper.config, quant_type) | ||
qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device) | ||
ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax) | ||
You only need to save the mask here if you want to clip the gradient; it is a ByteTensor and therefore 4 times smaller than a saved FloatTensor.
I think the reason we used to use a pass-through estimator is that we always set rmin, rmax to min(tensor), max(tensor). But your proposal is correct if rmin is larger than min(tensor) and rmax is smaller than max(tensor).
Yes, for weight, rmin and rmax are always the min and max of the tensor, but think about output. Because we use EMA for updating rmin and rmax, not all items in output are in [rmin, rmax].
Yes, your proposal is necessary for quantizing output, thanks for submitting this PR! |
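To illustrate why outputs can land outside the tracked range, here is a minimal sketch, assuming a simple exponential moving average with a hypothetical decay of 0.9. The function name `ema_update` and the batch values are illustrative, and NNI's actual update rule may differ.

```python
def ema_update(tracked, observed, decay=0.9):
    """Blend the tracked statistic with a newly observed one."""
    return decay * tracked + (1.0 - decay) * observed


# Two illustrative batches of output activations.
batches = [
    [-1.0, 0.2, 1.0],
    [-0.8, 0.5, 3.0],
]

# Initialise the range from the first batch, then update with EMA.
rmin, rmax = min(batches[0]), max(batches[0])
for batch in batches[1:]:
    rmin = ema_update(rmin, min(batch))
    rmax = ema_update(rmax, max(batch))

# Elements of the latest batch that fall outside the tracked range.
outside = [v for v in batches[-1] if v < rmin or v > rmax]
```

Because the average moves only a fraction of the way toward each new extreme, the largest value in the second batch lies above the tracked `rmax`. Values like this are the ones whose gradients the clipped STE sets to zero.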
||
return output | ||
|
||
@classmethod | ||
def backward(cls, ctx, grad_output): | ||
tensor, quant_type = ctx.saved_variables | ||
output = cls.quant_backward(tensor, grad_output, quant_type) | ||
tensor, scale, zero_point, qmin, qmax = ctx.saved_variables | ||
output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax) | ||
return output, None, None, None | ||
|
||
def _check_weight(module): | ||
|
The comment of this function needs to be updated or deleted.
Oh yeah, forgot to update the comment.