Enable the setting of global gradient clipping threshold #2216

Merged

kuke merged 5 commits into PaddlePaddle:develop on May 23, 2017

Conversation

@kuke (Contributor) commented May 19, 2017:

Resolves #1894.

@kuke kuke requested review from lcy-seso and qingqing01 May 19, 2017 09:15
@qingqing01 (Contributor) left a comment:

Right now there are both a global and a local clipping threshold, and the two are combined by taking the min. But the learning rate is global_lr * local_lr. Should clipping take the min, or the product? @pengli09 @lcy-seso

@@ -201,6 +227,13 @@ class ParameterOptimizer {
* so, if lr change in StartBatch, please assign to learningRate_
*/
real learningRate_;

/**
* global threshold for grdient clipping,
Contributor:

typo: grdient

Contributor Author:

Corrected.

<< " max grad=" << maxAbsGrad << " avg grad=" << avgAbsGrad;
LOG(INFO) << "parameter=" << config.name() << " need clipping by local threshold="
<< config.gradient_clipping_threshold()
<< ", max grad=" << maxAbsGrad << ", avg grad=" << avgAbsGrad;
Contributor:

The global and local cases here could be merged into one; there is no need to do them in two separate passes.

Contributor Author:

This design mainly allows the global and local thresholds to be set independently, giving users flexibility, and keeps the logic clear. Most importantly, it avoids changing the interface of the update function, so fewer places need to be modified.

Contributor:

But then this function does the computation twice: once for the global threshold and once for the local one.

Contributor Author:

Indeed it does, in the case where the local threshold is smaller than the global threshold.

@@ -290,6 +302,8 @@ void AdamaxParameterOptimizer::update(const VectorPtr vecs[],
const ParameterConfig& config,
size_t sparseId) const {
CHECK(sparseId == -1UL) << "Sparse update is not supported";
globalGradientClipping(vecs, config, FLAGS_log_clipping);
Contributor (@qingqing01) commented May 19, 2017:

After clipping is configured, don't Adam/Adagrad/AdaDelta/RMSProp go through OptimizerWithGradientClipping? Why does every optimizer except momentum need clipping added separately inside its update?

Contributor Author (@kuke) commented May 19, 2017:

They don't. OptimizerWithGradientClipping is only invoked when a layer sets its own gradient clipping; see lines 134-137 of paddle/parameter/OptimizerWithRegularizer.cpp:

if (paraConfig.gradient_clipping_threshold() > 0.0f &&
    !dynamic_cast<AddOptimizer*>(optimizer)) {
  optimizer = new OptimizerWithGradientClipping(optConfig, optimizer);
}

Contributor:

Why, when gradient clipping is configured, don't optimizers like Adamax reach https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/OptimizerWithRegularizer.cpp#L136 ?

Contributor Author:

Because the paraConfig object only holds a layer's local settings; the global gradient_clipping_threshold does not take effect there.

Contributor Author:

The global learning rate and the global grad clipping threshold are both obtained from optConfig, the argument passed to the constructor of the ParameterOptimizer base class.

Contributor (@qingqing01) commented May 19, 2017:

Right, my point is: since the OptimizerWithGradientClipping class exists, all optimizers (Adam, Adagrad, AdaDelta, Momentum, etc.) should in principle use OptimizerWithGradientClipping uniformly, rather than each optimizer doing its own clipping inside update.
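A minimal, self-contained sketch of the wrapping pattern being suggested here. The update() signature and the OptimizerWithGradientClipping constructor follow the snippets quoted in this thread; the stand-in types and the clipGradient() helper are assumptions for illustration, not the actual Paddle code.

#include <cstddef>
#include <memory>

// Stand-ins for Paddle's real types so the pattern can be shown in isolation.
struct ParameterConfig {};
struct OptimizationConfig {};
using VectorPtr = void*;

class ParameterOptimizer {
public:
  // Global settings (learning rate, global clipping threshold, ...) enter
  // through the base-class constructor and are kept in optConfig_.
  explicit ParameterOptimizer(const OptimizationConfig& optConfig)
      : optConfig_(optConfig) {}
  virtual ~ParameterOptimizer() {}
  virtual void update(const VectorPtr vecs[], const ParameterConfig& config,
                      size_t sparseId) const = 0;

protected:
  OptimizationConfig optConfig_;
};

// Decorator that wraps any concrete optimizer (Adam, Adagrad, AdaDelta,
// RMSProp, Momentum, ...): clip once here, then delegate, so no optimizer
// needs its own clipping code inside update().
class OptimizerWithGradientClipping : public ParameterOptimizer {
public:
  OptimizerWithGradientClipping(const OptimizationConfig& optConfig,
                                ParameterOptimizer* optimizer)
      : ParameterOptimizer(optConfig), optimizer_(optimizer) {}

  void update(const VectorPtr vecs[], const ParameterConfig& config,
              size_t sparseId) const override {
    clipGradient(vecs, config);                  // hypothetical helper
    optimizer_->update(vecs, config, sparseId);  // delegate to the wrapped optimizer
  }

private:
  void clipGradient(const VectorPtr /*vecs*/[],
                    const ParameterConfig& /*config*/) const {
    // The real code scales the gradient when its magnitude exceeds the
    // effective threshold; omitted in this stand-in.
  }
  std::unique_ptr<ParameterOptimizer> optimizer_;
};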

Contributor Author:

OK, I think this part can be improved; I'll push another commit shortly.

@qingqing01 qingqing01 requested a review from pengli09 May 19, 2017 09:27
@kuke (Contributor Author) commented May 19, 2017:

Optimized the logic. Please continue to review.

@kuke kuke force-pushed the enable_grad_clipping_dev branch 3 times, most recently from 8682e28 to d951d0f on May 22, 2017 10:25
@@ -303,18 +305,35 @@ void AdamaxParameterOptimizer::update(const VectorPtr vecs[],
void OptimizerWithGradientClipping::update(const VectorPtr vecs[],
                                           const ParameterConfig& config,
                                           size_t sparseId) const {
  real global_thres_ = optConfig_.gradient_clipping_threshold();
  real local_thres_ = config.gradient_clipping_threshold();
Contributor:

global_thres_ and local_thres_ don't follow the Paddle naming convention.

Yibing Liu added 2 commits May 23, 2017 14:20
correct a typo

optimize code

fix a bug
@kuke kuke force-pushed the enable_grad_clipping_dev branch from d951d0f to c042413 on May 23, 2017 06:23
@kuke kuke force-pushed the enable_grad_clipping_dev branch from c042413 to 8c9ab5f on May 23, 2017 07:07
@kuke (Contributor Author) commented May 23, 2017:

@qingqing01 has no more comments. @lcy-seso @pengli09, could you please have a review?

@lcy-seso (Contributor) left a comment:

LGTM.

@lcy-seso (Contributor):

@qingqing01 I support taking the min for gradient clipping, while keeping the current behavior for the learning rate, i.e. the product of the global and local learning rates.

@lcy-seso lcy-seso removed the request for review from pengli09 May 23, 2017 08:16
@pengli09 (Contributor):

@qingqing01 @lcy-seso I'm against taking the min. If we take the min, a layer's gradient clipping threshold can only ever be set smaller than the global one, never larger, which is a functional limitation.

@pengli09 (Contributor):

@lcy-seso @qingqing01 Also, regarding how global/local settings are handled, I strongly suggest staying consistent with the current behavior. Users picking up a new version of Paddle usually won't expect the handling of these settings to have changed, so the same configuration could easily produce different results on old and new versions, which would be a huge obstacle when debugging models.

@qingqing01 (Contributor):

Before this change there was only the local setting, no global one. @pengli09 has a point; shall we use the product then?

@lcy-seso (Contributor):

Then add an info message when the Python config is parsed. This setting used to be local-only, and users are used to setting it very large; if we multiply the two, the result will be so large that the threshold effectively stops working.

@kuke (Contributor Author) commented May 23, 2017:

@pengli09 Given that these are two thresholds, taking their product probably doesn't make sense. Taking the min may also be inappropriate. How about this: when the local threshold is set, use the local threshold; otherwise use the global value?

@pengli09 (Contributor):

@qingqing01 How about we carefully go through the handling rules for Paddle's existing global/local settings and adopt the most common one? For example, learning_rate uses the product, while decay_rate uses the layer's own value if the layer sets one and the global value otherwise. A newly added setting should behave like the majority of existing settings of the same kind; otherwise users will find the configuration confusing and will have to memorize a lot of special cases to use it correctly, which is bad.
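The two existing combination rules mentioned here, sketched side by side for clarity; the function and parameter names below are made up for this illustration and are not Paddle APIs.

// learning_rate: the effective rate is the product of the global and
// local values.
float effectiveLearningRate(float globalLr, float localLr) {
  return globalLr * localLr;
}

// decay_rate: the layer-local value wins whenever the layer sets one;
// otherwise the global value is used.
float effectiveDecayRate(float globalDecay, float localDecay, bool localIsSet) {
  return localIsSet ? localDecay : globalDecay;
}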

@lcy-seso (Contributor) left a comment:

Needs some discussion; please do not merge this PR.

@lcy-seso (Contributor):

The argument against taking the min makes sense. I support using the locally specified threshold when a local one is set, and the globally specified threshold otherwise.

@qingqing01 (Contributor):

"When the local threshold is set, use the local threshold; otherwise use the global threshold": agreed, +1.

@lcy-seso (Contributor):

"When the local threshold is set, use the local threshold; otherwise use the global threshold."
This logic is consistent with the existing configuration logic and doesn't change how users are used to setting things.
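A minimal, self-contained sketch of the agreed rule, assuming (as in the OptimizerWithRegularizer.cpp snippet earlier in this thread) that a threshold of 0 means "not set". The structs stand in for Paddle's protobuf-generated configs; only the gradient_clipping_threshold() accessor name is taken from the diff.

struct OptimizationConfig {
  float global_threshold = 0.0f;
  float gradient_clipping_threshold() const { return global_threshold; }
};

struct ParameterConfig {
  float local_threshold = 0.0f;
  float gradient_clipping_threshold() const { return local_threshold; }
};

// A layer-local threshold, when set, overrides the global one; otherwise
// the global threshold from the optimization config applies.
float effectiveClippingThreshold(const OptimizationConfig& optConfig,
                                 const ParameterConfig& config) {
  float local = config.gradient_clipping_threshold();
  return local > 0.0f ? local : optConfig.gradient_clipping_threshold();
}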

@kuke (Contributor Author) commented May 23, 2017:

Updated. @lcy-seso @qingqing01 could you take a look?

@lcy-seso (Contributor) left a comment:

LGTM

@kuke kuke merged commit 8b6f374 into PaddlePaddle:develop May 23, 2017
@kuke kuke deleted the enable_grad_clipping_dev branch May 23, 2017 15:31