Adadelta Optimizer #26590
@@ -42,7 +42,7 @@ class Adadelta(Optimizer):
      It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
      epsilon (float): a small float number for numeric stability. Default 1.0e-6.
Review comment: float, optional
      rho (float): a floating point value indicating the decay rate. Default 0.95.
Review comment: float, optional
-     parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \
+     parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
      This parameter is required in dygraph mode. \
      The default value is None in static mode, at this time all parameters will be updated.
      weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
@@ -72,18 +72,18 @@ class Adadelta(Optimizer):
      loss = paddle.mean(out)
      beta1 = paddle.to_tensor([0.9], dtype="float32")
      beta2 = paddle.to_tensor([0.99], dtype="float32")
-     adam = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
+     adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
      back = out.backward()
-     adam.step()
-     adam.clear_grad()
+     adadelta.step()
+     adadelta.clear_grad()

  """

  _avg_squared_grad_acc_str = "_avg_squared_grad"
  _avg_squared_update_acc_str = "_avg_squared_update"

  def __init__(self,
-              learning_rate,
+              learning_rate=0.001,
               epsilon=1.0e-6,
               rho=0.95,
               parameters=None,
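For context on the two accumulators registered above (`_avg_squared_grad` and `_avg_squared_update`), the following is a minimal NumPy sketch of the classic Adadelta rule they correspond to. It is illustrative only, not Paddle's kernel; the function and variable names are made up for this example.

```python
import numpy as np

def adadelta_update(param, grad, avg_sq_grad, avg_sq_update, rho=0.95, epsilon=1.0e-6):
    # Decaying average of squared gradients.
    avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad ** 2
    # Step scaled by the ratio RMS(previous updates) / RMS(gradients).
    update = -np.sqrt(avg_sq_update + epsilon) / np.sqrt(avg_sq_grad + epsilon) * grad
    # Decaying average of squared updates.
    avg_sq_update = rho * avg_sq_update + (1 - rho) * update ** 2
    return param + update, avg_sq_grad, avg_sq_update
```

Whether and how `learning_rate` additionally scales this step is determined by the underlying op; the sketch only shows the accumulator logic.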
@@ -0,0 +1,150 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, name_scope

__all__ = ["Momentum"]


class Momentum(Optimizer):
    """

    Simple Momentum optimizer with velocity state

    This optimizer has a flag for Nesterov Momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad param = param - learning\_rate * velocity
    Parameters:

        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        momentum (float): Momentum factor. The default value is 0.9.
Review comment: float, optional
        parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): The default value is None. Normally there is no need for user
            to set this property. For more information, please refer to
            :ref:`api_guide_Name` .
    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
Review comment: no need to use fluid here
Reply: Thanks for pointing this out; removed.
            import numpy as np
            paddle.disable_static()
            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
            inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)
            beta1 = paddle.to_tensor([0.9], dtype="float32")
            beta2 = paddle.to_tensor([0.99], dtype="float32")
            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            back = out.backward()
            momentum.step()
            momentum.clear_grad()
    """
    _velocity_acc_str = "velocity"

    def __init__(self,
                 learning_rate=0.001,
                 momentum=0.9,
                 parameters=None,
                 use_nesterov=False,
                 weight_decay=None,
                 grad_clip=None,
                 name=None):
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        if momentum is None:
            raise ValueError("momentum is not set")
        super(Momentum, self).__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name)
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

        # One velocity accumulator per parameter.
        for p in parameters:
            self._add_accumulator(self._velocity_acc_str, p)

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, framework.Block)

        velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                             param_and_grad[0])
        lr = self._create_param_lr(param_and_grad)

        if framework.in_dygraph_mode():
            # In dygraph mode, call the C++ momentum op directly.
            _, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1],
                                     velocity_acc, lr, param_and_grad[0],
                                     velocity_acc, 'mu', self._momentum,
                                     'use_nesterov', self._use_nesterov)
            return None

        attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
        inputs = {
            "Param": [param_and_grad[0]],
            "Grad": [param_and_grad[1]],
            "Velocity": [velocity_acc],
            "LearningRate": [lr]
        }

        outputs = {
            "ParamOut": [param_and_grad[0]],
            "VelocityOut": [velocity_acc]
        }
        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True)

        return momentum_op
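For reference, the update performed by the momentum op appended above matches the equations in the class docstring. A minimal NumPy sketch (illustrative only; the function and argument names are not part of the PR):

```python
import numpy as np

def momentum_step(param, grad, velocity, lr=0.001, mu=0.9, use_nesterov=False):
    # Velocity accumulates past gradients, decayed by mu.
    velocity = mu * velocity + grad
    if use_nesterov:
        # Nesterov variant: step along the gradient plus the scaled velocity.
        param = param - (grad + mu * velocity) * lr
    else:
        param = param - lr * velocity
    return param, velocity
```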
Review comment: def __init__ has no default values.
Reply: Thanks, fixed.
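With the defaults added in this change (learning_rate=0.001 for Adadelta; learning_rate=0.001 and momentum=0.9 for Momentum), the optimizers can be constructed in dygraph mode with only a parameter list. A minimal sketch, assuming a `linear` layer like the one in the docstring examples:

```python
import paddle

paddle.disable_static()
linear = paddle.nn.Linear(10, 10)

# Both constructors now fall back to their documented defaults.
adadelta = paddle.optimizer.Adadelta(parameters=linear.parameters())
momentum = paddle.optimizer.Momentum(parameters=linear.parameters())
```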