This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix quant grad function calculation error #3160

Merged
merged 6 commits into microsoft:master on Dec 8, 2020
Conversation

eedalong (Contributor) commented Dec 6, 2020

#3159

The quant grad calculation in NNI's current implementation is not correct according to Google's whitepaper; it should be:

@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
tensor
gradient of the input of quantization operation
"""
return grad_output
tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
mask = (tensor_q >= qmin) * (tensor_q <= qmax)
Cjkkkk (Contributor) commented Dec 7, 2020:

mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output

eedalong (Contributor Author) replied:

Good point.
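
For readers following the thread, here is a minimal sketch of the clipped straight-through estimator being discussed; the function name and signature are illustrative, not NNI's actual API:

    import torch

    # A minimal sketch (not NNI's exact code) of the clipped STE: gradients flow
    # only where the *unclamped* quantized value lies inside [qmin, qmax].
    def clipped_ste_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
        tensor_q = torch.round(tensor / scale) + zero_point   # unclamped affine quantization
        mask = (tensor_q < qmin) | (tensor_q > qmax)          # elements that would be clamped
        grad_input = grad_output.clone()
        grad_input[mask] = 0                                  # no gradient through clamped values
        return grad_input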

else:
raise ValueError("unrecognized QuantType.")

bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
Contributor:

You only need to save the mask here if you want to clip the gradient; it is a ByteTensor and therefore 4 times smaller than saving a FloatTensor.
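
A hedged sketch of that suggestion, with forward and backward kept together in one autograd function; all names here are illustrative, not NNI's actual code:

    import torch

    class FakeQuantSTE(torch.autograd.Function):
        """Illustrative only: save just the out-of-range mask (a bool tensor, 1 byte
        per element) instead of tensor, scale, zero_point, qmin and qmax."""

        @staticmethod
        def forward(ctx, tensor, scale, zero_point, qmin, qmax):
            tensor_q = torch.round(tensor / scale) + zero_point          # unclamped
            ctx.save_for_backward((tensor_q < qmin) | (tensor_q > qmax))
            return (torch.clamp(tensor_q, qmin, qmax) - zero_point) * scale

        @staticmethod
        def backward(ctx, grad_output):
            mask, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[mask] = 0
            # One gradient per forward input; only the tensor needs one.
            return grad_input, None, None, None, None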

Contributor:

I think the reason we used to use a pass-through estimator is that we always set rmin, rmax to min(tensor), max(tensor). But your proposal is correct if rmin is larger than min(tensor) and rmax is smaller than max(tensor).

eedalong (Contributor Author) replied:

Yes, for weights rmin and rmax are always the min and max of the tensor, but think about the output. Because we use an EMA to update rmin and rmax, not every item in the output is in [rmin, rmax].
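
For context, a rough sketch of EMA-style range tracking for activations (the function name and decay value are illustrative, not NNI's exact code); because rmin/rmax lag behind the current batch, some output values can fall outside [rmin, rmax] and get clamped during fake quantization:

    import torch

    def update_ema_range(rmin, rmax, output, decay=0.99):
        # Exponential moving average of the observed min/max of the layer output.
        batch_min, batch_max = output.min(), output.max()
        rmin = decay * rmin + (1 - decay) * batch_min
        rmax = decay * rmax + (1 - decay) * batch_max
        return rmin, rmax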

Contributor:

Yes, your proposal is necessary when quantizing output. Thanks for submitting this PR!

Contributor:

@eedalong @Cjkkkk, thanks for the discussion. The design of this PR looks reasonable; let's keep it in this PR. We will do some refactoring later to put the forward logic and backward logic together within the same class, which should make it much easier for users to customize new quantization algorithms.

@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
tensor
gradient of the input of quantization operation
"""
return grad_output
tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
Contributor:

No need to calculate tensor_q here; return a mask from wrapper.quantizer.quantize_input, for example.

eedalong (Contributor Author) commented Dec 7, 2020:

No, we need tensor_q because we need the unclamped quantized data. wrapper.quantizer.quantize_input is clamped to [qmin, qmax], so the mask would always be ones everywhere if we just used wrapper.quantizer.quantize_input.

Contributor:

I mean you can do this inside wrapper.quantizer.quantize_input and return the mask, which can then be used in backward.

Contributor:

That way you don't do the quantization twice, and the saved tensor is 4 times smaller.

eedalong (Contributor Author) commented Dec 7, 2020:

Yeah, you are right. I did write the code like that at first. To do that, we need to modify _quantize() to return the mask, but I gave that up because I don't think it's a good design. Having _quantize() just return the quantized tensor is a better design, I think.

Contributor:

True, but keeping this design will actually make many quantization algorithms with a special backward function harder and less efficient to implement. The real problem is that the quant backward lives inside the autograd function while the forward lives inside the quantizer module. I think it would be much better if the forward and backward logic were moved into the autograd function together, with the quantizer only maintaining some state.

Contributor:

What I described is actually what the PyTorch implementation does.

eedalong (Contributor Author) commented Dec 7, 2020:

Yes, you are right. And this calls for a huge design change to NNI's QAT implementation.

eedalong (Contributor Author) replied:

For now, I think I will just push the correct STE code, and what you mentioned can be left for the next PR.

eedalong (Contributor Author) commented Dec 7, 2020:

Check PyTorch's implementation, guys:

  1. The backward (gradient) of fake_quantize as implemented in CUDA in PyTorch: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
  2. The backward of fake_quantize as implemented in Python in PyTorch: https://github.com/pytorch/pytorch/blob/master/torch/quantization/_learnable_fake_quantize.py

The key point is this picture (attached image not reproduced here).
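
A quick way to see the behaviour those links implement, assuming a reasonably recent PyTorch build: the gradient of the built-in fake-quantize op is zeroed for elements whose quantized value falls outside [quant_min, quant_max].

    import torch

    x = torch.tensor([-3.0, -0.5, 0.5, 3.0], requires_grad=True)
    # arguments: input, scale=0.1, zero_point=0, quant_min=-10, quant_max=10
    y = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -10, 10)
    y.sum().backward()
    print(x.grad)  # tensor([0., 1., 1., 0.]) -- out-of-range elements get no gradient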

eedalong requested a review from Cjkkkk, December 7, 2020 09:19.
@@ -608,24 +626,35 @@ def quant_backward(tensor, grad_output, quant_type):
tensor
gradient of the input of quantization operation
"""
Contributor:

The comment of this function needs to be updated or deleted.

eedalong (Contributor Author) commented Dec 7, 2020:

Oh yeah, I forgot to update the comment.

return grad_output

@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
if quant_type == QuantType.QUANT_INPUT:
linbinskn (Contributor) commented Dec 7, 2020:

Suggest changing the definition of QuantType to 'input', 'weight' and 'output', so we don't need to modify quant_type in the if-else clause.

eedalong (Contributor Author) replied:

Um, I don't think that is relevant to this PR. Maybe you could update all the definitions for quant type in another PR.

Contributor:

Could you help update lines 583-585 to change the values from 0, 1, 2 to 'input', 'weight', 'output'? We have checked the code; it is safe to make this modification :)

eedalong (Contributor Author) replied:

Ok~ 👌
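
For clarity, the requested change would look roughly like this; a sketch, not the exact NNI source, and only QUANT_INPUT appears verbatim in the diff above, so the other member names are assumptions:

    class QuantType:
        QUANT_INPUT = 'input'
        QUANT_WEIGHT = 'weight'
        QUANT_OUTPUT = 'output'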

get bit for quantize config
:param config:
:param quant_type:
:return:
QuanluZhang (Contributor) commented Dec 7, 2020:

Please update the docstring format to be consistent with the others.

eedalong (Contributor Author) replied:

Cool.

@@ -589,8 +589,26 @@ class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@classmethod
def _quantize(cls, x, scale, zp):
Contributor:

Suggest changing zp to zero_point for consistency.

eedalong (Contributor Author) replied:

Yeah, already changed.
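
For reference, a hedged sketch of what an affine _quantize helper of this kind typically computes (the actual body is not shown in this hunk); note the result is intentionally unclamped, which is what the backward mask above relies on:

    import torch

    def _quantize(x, scale, zero_point):
        # Unclamped affine quantization; clamping to [qmin, qmax] happens elsewhere.
        return torch.round(x / scale) + zero_point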

QuanluZhang merged commit 2f6a74f into microsoft:master on Dec 8, 2020.
liuzhe-lz mentioned this pull request on Dec 25, 2020.