learning rate issue in LRScheduler #750
Comments
@eric-haibin-lin do you see it as a common problem for the lr scheduler in MXNet as well?
@ljk628 sorry for the late response. Could you provide a simple script to reproduce the issue? The learning rate looks fine if I try the following example:
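(The exact snippet tried here is not preserved in this copy of the thread; a minimal sketch of that kind of check, with an assumed toy model and values, might be:)

```python
import mxnet as mx
from mxnet import gluon

# Hypothetical reconstruction of the kind of check described above:
# a scheduler plus an explicit learning_rate that matches its base_lr.
net = gluon.nn.Dense(1)
net.initialize()
sched = mx.lr_scheduler.FactorScheduler(step=100, factor=0.9, base_lr=0.1)
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1, 'lr_scheduler': sched})
print(trainer.learning_rate)  # 0.1, as expected
```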
Thanks!
Hi @eric-haibin-lin, I think the problem comes with gluoncv 0.4.0. I am using EC2 Deep Learning AMI 22.0 and the pre-installed MXNet; the script you provided produces an AssertionError with gluoncv 0.4. As described in the issue, the problem is that the trainer reports the optimizer's default learning rate (0.01) rather than the scheduler's `base_lr`. Ideally, the `base_lr` given to the `lr_scheduler` should be respected without passing the learning rate again.
I can also confirm the issue on MXNet 1.5.0 and GluonCV 0.5.0.
@hetong007 Sounds like a bug in the new LRScheduler?
It is from these two lines: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/optimizer/optimizer.py#L106-L107. The optimizer overwrites the scheduler's `base_lr` with its own `learning_rate` argument. @zhreshold @eric-haibin-lin any better idea?
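For context, the behavior at those two lines amounts to roughly the following (a paraphrase of the linked MXNet source, not a verbatim copy):

```python
# Paraphrased sketch of Optimizer.__init__ around the linked lines:
# the scheduler's base_lr is unconditionally replaced by the optimizer's
# learning_rate argument, whose default is 0.01.
class Optimizer(object):
    def __init__(self, learning_rate=0.01, lr_scheduler=None):
        self.lr = learning_rate
        self.lr_scheduler = lr_scheduler
        if lr_scheduler is not None:
            self.lr_scheduler.base_lr = learning_rate  # clobbers base_lr
```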
I think this is awful; at the very least we need a warning for it.
@hetong007 How about adding something to respect the existing learning rate from the `lr_scheduler`'s `base_lr`? (One possible shape of such a guard is sketched below.)
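An illustrative sketch only, not the actual patch; the helper name and the `None` default for `learning_rate` are hypothetical:

```python
import warnings

def resolve_base_lr(learning_rate, lr_scheduler):
    """Hypothetical helper: keep the scheduler's base_lr unless the user
    explicitly passed a conflicting learning_rate, warning in that case."""
    if lr_scheduler is None:
        return 0.01 if learning_rate is None else learning_rate
    if learning_rate is not None and learning_rate != lr_scheduler.base_lr:
        warnings.warn('learning_rate conflicts with lr_scheduler.base_lr; '
                      'using the explicitly passed learning_rate')
        lr_scheduler.base_lr = learning_rate
    return lr_scheduler.base_lr
```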
It has been fixed in MXNet with apache/mxnet#16487.
There seems to be a bug in LRScheduler for the initial learning rate assignment.
If we replace the line at https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L141 with a simple step-based scheduler and add `print(trainer.learning_rate)` after https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L350, you will find it prints 0.01 rather than the true value of `opt.lr`.
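For illustration, a minimal self-contained version of that experiment might look like the following (the choice of `MultiFactorScheduler`, the step values, and the toy network are assumptions, not the reporter's exact code):

```python
import mxnet as mx
from mxnet import gluon

# Stand-ins for the script's opt.lr / opt.lr_decay (values are arbitrary):
base_lr = 0.4
lr_decay = 0.1

# A simple step-based scheduler in place of the script's LR setup:
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
    step=[3000, 6000, 9000], factor=lr_decay, base_lr=base_lr)

# Toy model so the Trainer has parameters to manage:
net = gluon.nn.Dense(1)
net.initialize()

# As in the script, no explicit learning_rate is passed to the optimizer:
trainer = gluon.Trainer(net.collect_params(), 'nag',
                        {'wd': 1e-4, 'momentum': 0.9,
                         'lr_scheduler': lr_scheduler})
print(trainer.learning_rate)  # prints 0.01 (the optimizer default), not 0.4
```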
This is because, when the optimizer is initialized in the trainer, it is given the default `learning_rate=0.01` even though the `lr_scheduler` already has a `base_lr` (https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/optimizer/optimizer.py#L102).

One way to fix it is to add `learning_rate` explicitly in https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L166, as sketched below. However, this is not elegant, as `opt.lr` is already provided to the `lr_scheduler` and should not have to be provided again.
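Continuing the sketch above, the explicit workaround might look like this (the `optimizer_params` keys are assumptions about the script's code, not a quote of it):

```python
# Workaround sketch: pass learning_rate explicitly so the optimizer does
# not fall back to its 0.01 default and clobber the scheduler's base_lr.
# lr_scheduler and base_lr are reused from the sketch above; the other
# keys are assumed to match the script's optimizer_params.
optimizer_params = {'wd': 1e-4, 'momentum': 0.9,
                    'lr_scheduler': lr_scheduler,
                    'learning_rate': base_lr}  # duplicates the scheduler's base_lr
```

The redundancy is visible here: the same value travels to the optimizer twice, once as `base_lr` and once as `learning_rate`.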