Should this work with Mixed precision training (AMP)? #31
Comments
Hi, it's likely to cause problems with low precision. Note that we have (g_t - m_t)^2 in the denominator; it's likely to be 0 because of the low precision, and then we are just dividing by 0, which will cause an explosion. In Adam the denominator is g_t^2, so as long as one of the g_t in the history is not 0, the denominator is not 0. It's possible to fix this by using a larger eps (or choosing eps according to the precision). Could you provide the code and script to reproduce? I'll consider updating this in the next release of the package. Thanks a lot. BTW, what is AMP?
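For a concrete picture of the failure mode described here, a minimal sketch (made-up values, an AdaBelief-style second moment assumed, not the package's actual code) of how the denominator can collapse to exactly zero in float16 while an eps below float16's smallest subnormal silently underflows:

```python
import torch

# Made-up values; illustrates the failure mode, not the package's actual code.
g = torch.full((4,), 1e-3, dtype=torch.float16)  # gradient
m = g.clone()                                    # EMA of gradients; equals g when gradients are constant

belief = (g - m) ** 2                            # (g_t - m_t)^2 -> exactly 0
adam = g ** 2                                    # g_t^2 -> tiny but non-zero

eps = torch.tensor(1e-16, dtype=torch.float16)   # underflows: prints tensor(0., dtype=torch.float16)
print(g / (belief.sqrt() + eps))                 # division by zero -> inf, then NaN losses downstream
print(g / (adam.sqrt() + eps))                   # finite, because g_t^2 keeps the denominator non-zero
```

With eps underflowed to 0, the (g_t - m_t)^2 denominator is exactly zero, which matches the NaN / division-by-zero behaviour reported in this issue.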
@FilipAndersson245 @Mut1nyJD Thanks for the help. A question regarding AMP: if I understand it correctly from the documentation, AMP first scales up the loss by a factor, say 65535, then runs backward to get the gradient, then divides by 65535 to recover the gradient, and performs the update (all operations in float16). Is this true? Or do they get the scaled gradient in float16, convert to float32, update the parameter in float32, then convert back to float16? The first case is slightly tricky with eps: I quickly tested that if eps < 1e-8, it underflows to 0 in float16 in numpy, and the same happens in pytorch, which means eps is effectively 0 (though set as 1e-16). This might also cause the difference (g_t - m_t) to be 0.
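For reference, a minimal sketch of the usual torch.cuda.amp recipe (standard PyTorch API, but the model, data, and hyperparameters below are made up; requires a CUDA device). In this pattern the parameters and optimizer state stay in float32, and the gradients are unscaled before the optimizer step, which is closer to the second case described above:

```python
import torch

# Standard torch.cuda.amp recipe; the model and data are made up for illustration.
device = "cuda"
model = torch.nn.Linear(16, 1).to(device)
opt = torch.optim.Adam(model.parameters(), eps=1e-8)  # the AdaBelief optimizer could be swapped in here
scaler = torch.cuda.amp.GradScaler()

for _ in range(100):
    x = torch.randn(32, 16, device=device)
    y = torch.randn(32, 1, device=device)
    opt.zero_grad()
    with torch.cuda.amp.autocast():                    # forward/loss in mixed precision
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()                      # backward on the scaled loss
    scaler.step(opt)                                    # unscales grads; skips the step on inf/NaN
    scaler.update()                                     # adjusts the loss-scale factor
```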
Not 100% sure, but in my naive view I would have thought it is the latter one. To keep precision high, the gradient would initially be computed in float32, then scaled to float16 and backpropagated through the network.
@Mut1nyJD |
@Mut1nyJD Hi, I just tried a workaround for the mixed-precision issue: cast the weight and gradient to float32, perform the update, then cast back to float16. This way the float32 cost is only paid for the weight update, not the backward pass, so the computational overhead should not be too much. See the sketch below. Please let me know if you have other suggestions.
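(The original snippet isn't preserved in this copy of the thread; below is a hypothetical sketch of the workaround described above, with a plain SGD step standing in for the actual update math.)

```python
import torch

# Hypothetical sketch, not the author's exact code: run the update through a
# float32 copy of the weight and gradient, then write the result back to float16.
@torch.no_grad()
def half_precision_step(p: torch.nn.Parameter, lr: float = 1e-3) -> None:
    if p.grad is None:
        return
    grad32 = p.grad.float()          # cast gradient to float32
    p32 = p.detach().float()         # cast weight to float32
    # ... the real (AdaBelief) update math would run here in float32 ...
    p32.add_(grad32, alpha=-lr)      # placeholder: plain SGD step for illustration
    p.copy_(p32.to(p.dtype))         # cast back to the parameter's original dtype (e.g. float16)
```

Only the per-parameter update runs in float32; the forward and backward passes stay in half precision, so the extra cost should be small.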
Hi, just a question: is this optimizer compatible with mixed precision training / AMP? I tried to use it in combination with lucidrains' lightweight-gan implementation, which uses the PyTorch version of this optimizer, but after a few hundred iterations my losses go to NaN and eventually cause a division-by-zero error. I don't see the same problem when using the standard Adam optimizer.