
learning rate issue in LRScheduler #750

Closed
ljk628 opened this issue Apr 21, 2019 · 10 comments

Comments

@ljk628

ljk628 commented Apr 21, 2019

There seems to be a bug in LRScheduler for the initial learning rate assignment.

If we replace the line at https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L141 with a simple step-based scheduler

lr_scheduler = LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)

and add print(trainer.learning_rate) after https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L350, you will find it prints 0.01 rather than the true value of opt.lr.

This is because when the optimizer is initialized in the trainer, it is given the default learning_rate=0.01 even though the lr_scheduler already has a base_lr: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/optimizer/optimizer.py#L102.

One way to fix it is to add learning_rate explicitly in https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L166, i.e.,

optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler, 'learning_rate': opt.lr}

However, this is not elegant, since opt.lr is already provided to the lr_scheduler and should not have to be provided again.
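
A slightly cleaner workaround (just a sketch; it assumes the GluonCV LRScheduler keeps the value passed as base_lr in a base_lr attribute) would be to read the value back from the scheduler instead of repeating opt.lr:

# Hypothetical variant of the workaround above: reuse the scheduler's own
# base_lr so the learning rate is only written down once.
optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum,
                    'lr_scheduler': lr_scheduler,
                    'learning_rate': lr_scheduler.base_lr}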

@zhreshold
Member

@hetong007

@hetong007
Member

@eric-haibin-lin do you see it as a common problem for the lr scheduler in mxnet as well?

@eric-haibin-lin
Member

@ljk628 sorry for the late response. Could you provide a simple script to reproduce the issue?
The optimizer is supposed to query the lr_scheduler when lr_scheduler is set. See https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/optimizer/optimizer.py#L190-L195

The learning rate looks fine if I try the following example:

>>> import mxnet as mx; import gluoncv; lr_scheduler = gluoncv.utils.LRScheduler('step', baselr=10, targetlr=0, niters=10, nepochs=2); optim = mx.optimizer.create('sgd', lr_scheduler=lr_scheduler)
>>> optim.learning_rate
10
>>> mx.__version__
'1.4.0'
>>> gluoncv.__version__
'0.3.0'

Thanks!

@ljk628
Author

ljk628 commented May 7, 2019

Hi @eric-haibin-lin, I think the problem was introduced in gluoncv 0.4.0. Here are the commands to reproduce the bug:

>>> import mxnet as mx; import gluoncv; 
>>> lr_scheduler = gluoncv.utils.LRScheduler('step', base_lr=10, target_lr=0, niters=10, nepochs=2, step_epoch=[1, 2]); optim = mx.optimizer.create('sgd', lr_scheduler=lr_scheduler)
>>> optim.learning_rate
0.01
>>> mx.__version__
'1.4.0'
>>> gluoncv.__version__
'0.4.0'

I am using EC2 Deep Learning AMI 22.0 and the pre-installed MXNet (source activate mxnet_p27 or source activate mxnet_p36).

The script you provided produces an AssertionError with gluoncv 0.4:

File "/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/gluoncv/utils/lr_scheduler.py", line 92, in __init__
    assert(step_iter is not None or step_epoch is not None)

As described, the issue is that base_lr is already set in the lr_scheduler, and the optimizer is initialized with that lr_scheduler. However, the optimizer has a default learning_rate parameter (0.01), so if we don't explicitly set learning_rate on the optimizer, it uses the default value (0.01) instead of the scheduler's base_lr.

Ideally, base_lr should be set once when the lr_scheduler is created and not again when initializing the optimizer. This issue also exists in the current ImageNet finetuning script: https://github.com/dmlc/gluon-cv/blob/master/scripts/classification/imagenet/train_imagenet.py#L166
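
For reference, the same workaround applied to the minimal reproduction above (a sketch only; it assumes the scheduler object exposes the base_lr it was constructed with):

import mxnet as mx
import gluoncv

lr_scheduler = gluoncv.utils.LRScheduler('step', base_lr=10, target_lr=0,
                                         niters=10, nepochs=2, step_epoch=[1, 2])
# Passing learning_rate explicitly keeps the optimizer consistent with the
# scheduler, at the cost of effectively stating the value twice.
optim = mx.optimizer.create('sgd', lr_scheduler=lr_scheduler,
                            learning_rate=lr_scheduler.base_lr)
print(optim.learning_rate)  # should report the scheduler's base_lr (10), not 0.01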

@chenliu0831

I can also confirm the issue on MXNet 1.5.0 and GluonCV 0.5.0:

import mxnet as mx 
import gluoncv 
lr_scheduler = gluoncv.utils.LRScheduler('step', base_lr=10, target_lr=0, niters=10, nepochs=2, step_epoch=[1, 2]) 
optim = mx.optimizer.create('sgd', lr_scheduler=lr_scheduler)

print(optim.learning_rate) # 0.01

@zhreshold
Member

@hetong007 Sounds like a bug in the new LRScheduler?

@hetong007
Member

It is from these two lines: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/optimizer/optimizer.py#L106-L107

The optimizer overwrites base_lr in the lr_scheduler without checking anything. The easiest fix would be to work around it in the script. A more systematic way is to avoid that behavior in the optimizer implementation, but that may be an API-breaking change.
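
In other words, the linked lines behave roughly like this (an illustration only, not the actual MXNet source):

# Simplified picture of the behavior described above: the optimizer's
# learning_rate (default 0.01) is pushed into the scheduler unconditionally,
# clobbering whatever base_lr the scheduler was created with.
class SimplifiedOptimizer:
    def __init__(self, learning_rate=0.01, lr_scheduler=None):
        self.lr = learning_rate
        self.lr_scheduler = lr_scheduler
        if lr_scheduler is not None:
            # no check whether the caller actually passed learning_rate
            lr_scheduler.base_lr = learning_rate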

@zhreshold @eric-haibin-lin any better idea?

@zhreshold
Member

I think this is awful; at the very least we need to emit a warning for it.

@chenliu0831

chenliu0831 commented Oct 15, 2019

@hetong007 How about adding something in mx.optimizer to respect the existing learning rate from the lr_scheduler, and fall back to the learning_rate provided through the optimizer interface? The current override behavior is probably quite surprising to many users who provided the lr to the scheduler object. I wonder whether there are actual users relying on the current overwrite, which would make changing it a breaking change.
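
A sketch of that idea (hypothetical code, not the patch that was eventually merged): use a sentinel default so the optimizer can tell whether learning_rate was passed explicitly, and respect the scheduler's base_lr otherwise:

class PatchedOptimizer:
    def __init__(self, learning_rate=None, lr_scheduler=None):
        self.lr_scheduler = lr_scheduler
        if lr_scheduler is not None and learning_rate is None:
            # respect the learning rate already configured on the scheduler
            learning_rate = lr_scheduler.base_lr
        elif lr_scheduler is not None:
            # an explicit learning_rate still wins, matching today's behavior
            lr_scheduler.base_lr = learning_rate
        # fall back to the old default only when neither source provides a value
        self.lr = 0.01 if learning_rate is None else learning_rate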

@hetong007
Member

It has been fixed in MXNet with apache/mxnet#16487.
