
Fix learning rate scheduler being unexpectedly overwritten by optimizer's default value #16487

Merged
merged 7 commits into apache:master on Oct 18, 2019

Conversation

hetong007
Contributor

@hetong007 hetong007 commented Oct 15, 2019

Description

Right now the following example (similar to the one in dmlc/gluon-cv#750) yields surprising and unexpected output:

import mxnet as mx
lr_scheduler = mx.lr_scheduler.FactorScheduler(1000, base_lr=10)
optim = mx.optimizer.create('sgd', lr_scheduler=lr_scheduler)

print(optim.learning_rate)  # prints 0.01 instead of the scheduler's base_lr of 10

In the current implementation of the optimizer, MXNet overwrites the learning rate scheduler's learning rate with the optimizer's learning rate:

if lr_scheduler is not None:
    self.lr_scheduler.base_lr = learning_rate

Therefore, when users define a learning rate scheduler and pass it to the optimizer, they are surprised to find that its learning rate has been overwritten by the default value, 0.01.

My patch sets the default learning_rate to None and assigns the original default of 0.01 only in the absence of an lr_scheduler.
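
As an illustration of the intended behaviour, here is a minimal, hypothetical helper (not the actual patch; the real change lives in python/mxnet/optimizer/optimizer.py):

import mxnet as mx

def resolve_lr(learning_rate=None, lr_scheduler=None):
    # Hypothetical sketch of the idea above: learning_rate now defaults to
    # None, and the old default of 0.01 is applied only when no
    # lr_scheduler is supplied.
    if lr_scheduler is not None:
        return lr_scheduler.base_lr if learning_rate is None else learning_rate
    return 0.01 if learning_rate is None else learning_rate

scheduler = mx.lr_scheduler.FactorScheduler(1000, base_lr=10)
print(resolve_lr(lr_scheduler=scheduler))  # 10: the scheduler's base_lr is respected
print(resolve_lr())                        # 0.01: the historical default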

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

self.lr = learning_rate
if self.lr_scheduler is not None and learning_rate is not None:
    if self.lr_scheduler.base_lr != learning_rate:
        raise UserWarning("learning rate from ``lr_scheduler`` has been "
Contributor

Wouldn't this terminate the execution and start the exception handling? If so, the statement below wouldn't be executed.
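
For illustration (not part of the PR), this is the difference the comment refers to: raise UserWarning(...) aborts the enclosing call, while warnings.warn(..., UserWarning) would emit the message and let execution continue.

import warnings

def init_with_raise():
    # As written in the diff above: raising aborts immediately, so any
    # statement placed after the raise would never run.
    raise UserWarning("learning rate from ``lr_scheduler`` has been "
                      "overwritten by ``learning_rate`` in optimizer.")

def init_with_warn():
    # Alternative: emit the message but keep executing.
    warnings.warn("learning rate from ``lr_scheduler`` has been "
                  "overwritten by ``learning_rate`` in optimizer.", UserWarning)
    return "still executed"

print(init_with_warn())    # prints the warning, then "still executed"
# init_with_raise()        # would terminate with an uncaught UserWarning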

Member

@wkcn wkcn left a comment

Thank you for the fix : )
The default lr should still be present.

Member

@wkcn wkcn left a comment

LGTM. Thank you : )

@hetong007 hetong007 merged commit 27f7082 into apache:master Oct 18, 2019
@MrRaghav

MrRaghav commented Jul 4, 2020

Hello,
I am still getting this error while using mxnet with sockeye. Since this was fixed in the new release, I didn't open a new bug.
Please find the details in the following points:

  1. I'm using the following versions of mxnet and sockeye (2.1.7), on CUDA 10.1:
    [username]@[server]:/username/sockeye/dir1$ pip3 list | grep mxnet
    mxnet 1.6.0
    mxnet-cu101mkl 1.6.0
    mxnet-mkl 1.6.0
    [username]@[server]:/username/sockeye/dir1$ pip3 list | grep sockeye
    sockeye 2.1.7

  2. When I run the sockeye.train command with arguments, I get the following log:

    [username]@[server]:~/username/sockeye$ tail -30 77233.out
    File "/home/username/.local/lib/python3.7/site-packages/sockeye/train.py", line 997, in
    main()
    File "/home/username/.local/lib/python3.7/site-packages/sockeye/train.py", line 764, in main
    train(args)
    File "/home/username/.local/lib/python3.7/site-packages/sockeye/train.py", line 992, in train
    training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter, checkpoint_decoder=cp_decoder)
    File "/home/username/.local/lib/python3.7/site-packages/sockeye/training.py", line 242, in fit

    self._step(batch=train_iter.next())
    File "/home/username/.local/lib/python3.7/site-packages/sockeye/training.py", line 346, in step
    loss_func.metric.update(loss_value.asscalar(), num_samples.asscalar())
    File "/home/username/.local/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 2553, in asscalar
    return self.asnumpy()[0]
    File "/home/username/.local/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
    ctypes.c_size_t(data.size)))
    File "/home/username/.local/lib/python3.7/site-packages/mxnet/base.py", line 255, in check_call

    raise MXNetError(py_str(LIB.MXGetLastError()))
    mxnet.base.MXNetError: [09:58:26] src/storage/./pooled_storage_manager.h:161: cudaMalloc retry failed: out of memory
    Stack trace:
    [bt] (0) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x6d554b) [0x7f6c5b3d054b]
    [bt] (1) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x41a0c72) [0x7f6c5ee9bc72]
    [bt] (2) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x41a694f) [0x7f6c5eea194f]
    [bt] (3) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x3972e10) [0x7f6c5e66de10]
    [bt] (4) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x39730c7) [0x7f6c5e66e0c7]
    [bt] (5) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x281) [0x7f6c5e66e4d1]
    [bt] (6) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x3896f19) [0x7f6c5e591f19]
    [bt] (7) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x38a3c31) [0x7f6c5e59ec31]
    [bt] (8) /home/username/.local/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x38a7170) [0x7f6c5e5a2170]

learning rate from lr_scheduler has been overwritten by learning_rate in optimizer.

  3. Also, I can see another error, "cudaMalloc retry failed: out of memory", in the above log, and I checked deepinsight/insightface#257 (MXNetError: cudaMalloc failed: out of memory) for a fix. They mention that reducing the batch size solves the issue, but I am not passing any such argument to sockeye.train.

  4. The arguments used with sockeye are as follows:
    python3 -m sockeye.train -d training_data
    -vs dev.BPE.de
    -vt dev.BPE.en
    --shared-vocab
    -o parallel/wmt_model

  5. I found the code at https://mxnet.apache.org/api/python/docs/_modules/mxnet/optimizer/optimizer.html, which says learning_rate should be assigned to self.lr_scheduler.base_lr along with the above warning. But I am getting it as an error, and the run fails (see the sketch after this list).

  6. Moreover, I checked the release notes of mxnet 1.6.0 at the link below and can see that this issue has been fixed:
    https://cwiki.apache.org/confluence/display/MXNET/1.6.0+Release+notes#id-1.6.0Releasenotes-Bugfixes
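
For reference, a minimal sketch (my assumption of how the MXNet 1.6.x optimizer behaves, based on the diff above; not taken from sockeye) of how passing both a learning_rate and an lr_scheduler with a different base_lr surfaces as a hard error rather than a warning:

import mxnet as mx

scheduler = mx.lr_scheduler.FactorScheduler(step=1000, base_lr=0.0002)
try:
    # Both an explicit learning_rate and a scheduler with a different base_lr
    # are given, which (per the diff above) raises UserWarning.
    optim = mx.optimizer.create('adam', learning_rate=0.01, lr_scheduler=scheduler)
except UserWarning as err:
    print("surfaced as an error:", err)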

I hope I didn't miss anything before coming to you. Can you please suggest what should be done in this scenario?

@szha
Member

szha commented Jul 5, 2020

@MrRaghav it looks like an out-of-memory issue that is different from what this PR is about. It would be best if you could create a new issue for it, and I will try to help there. Thanks!

@MrRaghav

MrRaghav commented Jul 5, 2020

Thank you. I have created one: #18662
