@@ -148,6 +157,7 @@ def __init__(self, model, config_list, optimizer=None):
            types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
+       self.quant_grad = QATGrad
Where do we use `self.quant_grad`?
    0: "input",
    1: "weight",
    2: "output"
}

class QuantGrad(torch.autograd.Function):
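For context, a `QuantGrad`-style autograd function typically implements the straight-through estimator (STE): the non-differentiable quantization step is treated as identity during backpropagation, passing gradients through unchanged inside the quantizable range and zeroing them outside it. Below is a minimal torch-free sketch of that gradient rule (the function name and the `[qmin, qmax]` range are illustrative assumptions, not NNI's actual code):

```python
# Hypothetical sketch of the straight-through estimator (STE) gradient
# rule that QuantGrad-style classes implement. Not NNI's actual code.

def ste_backward(grad_output, tensor, qmin=0.0, qmax=1.0):
    """Pass each incoming gradient through where the corresponding
    input value was inside [qmin, qmax]; zero it elsewhere."""
    return [g if qmin <= x <= qmax else 0.0
            for g, x in zip(grad_output, tensor)]

# Only the middle element lies inside [0, 1], so only its gradient survives.
print(ste_backward([1.0, 2.0, 3.0], [-0.5, 0.5, 1.5]))  # -> [0.0, 2.0, 0.0]
```

In the real implementation this logic lives in the `backward` of a `torch.autograd.Function` subclass and operates on tensors rather than lists.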
I think the best way to accomplish the quantization gradient calculation is to add `quant_weight_grad`, `quant_input_grad`, and `quant_output_grad` methods to the `Quantizer` class, so that the `QuantGrad` class only calls these methods instead.
Agree. Discussing a new design will take time, and the quantizer may be refactored in another PR. In this PR we only fix the DoReFa and BNN incompatibility: because of the modification made to the compressor to solve the grad error, BNN and DoReFa cannot operate normally on the current code, and this PR fixes that problem.