Adadelta Optimizer #26590

Merged 47 commits on Aug 29, 2020
Changes shown from 2 commits

Commits
e45dcff
add doc; notest
MRXLT Aug 14, 2020
85b3f92
fix doc; notest
MRXLT Aug 14, 2020
cbcd950
update doc; notest
MRXLT Aug 14, 2020
9661a54
refine optimizer && adam
MRXLT Aug 14, 2020
f542d77
fix conflict
MRXLT Aug 17, 2020
73baac0
refine optimizer; notest
MRXLT Aug 18, 2020
5a55869
add adam
MRXLT Aug 18, 2020
fd34fbd
fix doc
MRXLT Aug 18, 2020
f5e6881
Merge remote-tracking branch 'upstream/develop' into 2.0-op
MRXLT Aug 18, 2020
a715c46
Merge remote-tracking branch 'upstream/develop' into 2.0-op
MRXLT Aug 19, 2020
e67cd86
fix doc && add adamw; notest
MRXLT Aug 19, 2020
da4025d
add error message
MRXLT Aug 19, 2020
f3699cb
bug fix
MRXLT Aug 19, 2020
6f00384
refine rmsprop && adamax
MRXLT Aug 19, 2020
654377d
fix ci
MRXLT Aug 19, 2020
fa7ccb1
bug fix
MRXLT Aug 19, 2020
9aaf899
update comment
MRXLT Aug 19, 2020
b727dad
unify arguments place; notest
MRXLT Aug 20, 2020
9cf4c3b
fix ut, test=develop
mapingshuo Aug 20, 2020
2e8d253
bug fix
MRXLT Aug 20, 2020
00c38fc
fix conflicts, test=develop
mapingshuo Aug 20, 2020
b75ab16
add examples code
MRXLT Aug 20, 2020
84205ce
Merge remote-tracking branch 'origin/2.0-op' into 2.0-op
MRXLT Aug 20, 2020
b6fa771
bug fix
MRXLT Aug 20, 2020
9cd1838
fix comments
MRXLT Aug 20, 2020
95310f5
fix sample code
MRXLT Aug 20, 2020
ce31795
add sample code for Optimizer
MRXLT Aug 20, 2020
0780b9c
add adamax ut, test=develop
mapingshuo Aug 21, 2020
87a7f56
fix rmsprop ut, test=develop
mapingshuo Aug 21, 2020
06f3c73
add ut for optimizer.py and adamw.py
MRXLT Aug 21, 2020
fd67080
Merge branch '2.0-op' of https://github.com/MRXLT/Paddle into 2.0-op
MRXLT Aug 21, 2020
27d498d
first commit of adadelta optimizer
bjjwwang Aug 21, 2020
e758d2d
fix learning rate
bjjwwang Aug 21, 2020
5252ba7
merge with develop
bjjwwang Aug 24, 2020
0674403
fix adadelta doc and add sgd momentum
bjjwwang Aug 24, 2020
fecb57e
Merge branch 'develop' into 2.0-adadelta
bjjwwang Aug 25, 2020
a140fa7
remove unused fluid
bjjwwang Aug 25, 2020
093abc9
Merge branch '2.0-adadelta' of https://github.com/wangjiawei04/Paddle…
bjjwwang Aug 25, 2020
47d0af1
fix codestyle
bjjwwang Aug 25, 2020
55877dd
Update test_adam_op.py
bjjwwang Aug 25, 2020
bf7d4a0
Update test_adam_op.py
bjjwwang Aug 25, 2020
d1e4ce4
fix SGD in 2 unittests
bjjwwang Aug 25, 2020
c154a07
Merge branch '2.0-adadelta' of https://github.com/wangjiawei04/Paddle…
bjjwwang Aug 25, 2020
aedaf12
fix SGD in 2 unittests
bjjwwang Aug 25, 2020
1de492d
merge with develop zhouwei new lr 0830
bjjwwang Aug 28, 2020
2f01874
fix ci
MRXLT Aug 28, 2020
a44ef94
fix ut
MRXLT Aug 28, 2020
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_adadelta_op.py
@@ -110,7 +110,7 @@ def test_check_output(self):
self.check_output()


class TestAdadeltaPropV2(unittest.TestCase):
class TestAdadeltaV2(unittest.TestCase):
def test_adadelta_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
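For readers without the full file expanded, the following is roughly what the renamed TestAdadeltaV2 dygraph test exercises: one forward/backward/step cycle with the new paddle.optimizer.Adadelta API from this PR. It is a minimal sketch only; constructor defaults are taken from the adadelta.py diff further down, and the real test's exact tensors and checks may differ.

    import numpy as np
    import paddle

    # Minimal dygraph step with the 2.0-style Adadelta API (sketch, not the actual test body).
    paddle.disable_static()
    value = np.arange(26).reshape(2, 13).astype("float32")
    a = paddle.to_tensor(value)
    linear = paddle.nn.Linear(13, 5, dtype="float32")
    adadelta = paddle.optimizer.Adadelta(
        learning_rate=0.001, epsilon=1.0e-6, rho=0.95,
        parameters=linear.parameters())
    out = linear(a)
    out.backward()
    adadelta.step()
    adadelta.clear_grad()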
45 changes: 45 additions & 0 deletions python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -19,6 +19,8 @@
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
import paddle
import paddle.fluid as fluid


class TestMomentumOp1(OpTest):
@@ -234,5 +236,48 @@ def init_kernel(self):
self.use_nesterov = True


class TestMomentumV2(unittest.TestCase):
def test_momentum_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5, dtype="float32")
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Momentum(
learning_rate=0.01, momentum=0.9, parameters=linear.parameters())
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()

def test_momentum(self):
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

rms_optimizer = paddle.optimizer.Momentum(
learning_rate=0.1, momentum=0.9)
rms_optimizer.minimize(avg_cost)

fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

def test_raise_error(self):
self.assertRaises(
ValueError, paddle.optimizer.Momentum, learning_rate=None)
self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)


if __name__ == "__main__":
unittest.main()
45 changes: 44 additions & 1 deletion python/paddle/fluid/tests/unittests/test_sgd_op.py
@@ -20,6 +20,7 @@
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
import paddle


class TestSGDOp(OpTest):
@@ -197,7 +198,8 @@ def runTest(self):

cost = fluid.layers.square_error_cost(input=out, label=label)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001)
#sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)

place = fluid.CPUPlace()
@@ -208,5 +210,46 @@ def runTest(self):
result = exe.run(compiled_prog, fetch_list=[avg_cost])


class TestSGDV2(unittest.TestCase):
def test_sgd_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5, dtype="float32")
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.SGD(learning_rate=0.01,
parameters=linear.parameters(),
weight_decay=0.01)
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()

def test_sgd(self):
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

rms_optimizer = paddle.optimizer.SGD(learning_rate=0.1)
rms_optimizer.minimize(avg_cost)

fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

def test_raise_error(self):
self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None)


if __name__ == "__main__":
unittest.main()
12 changes: 9 additions & 3 deletions python/paddle/optimizer/__init__.py
@@ -26,9 +26,8 @@
]


from ..fluid.optimizer import SGD, Momentum, Adagrad, Dpsgd, DecayedAdagrad, \
Ftrl, \
SGDOptimizer, MomentumOptimizer, AdagradOptimizer,DpsgdOptimizer,\
from ..fluid.optimizer import Momentum, Adagrad, Dpsgd, DecayedAdagrad, Ftrl,\
AdagradOptimizer,DpsgdOptimizer,\
DecayedAdagradOptimizer,FtrlOptimizer,AdadeltaOptimizer, \
ModelAverage, LarsMomentum, DGCMomentumOptimizer, LambOptimizer,\
ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, \
@@ -40,3 +39,10 @@
from .adamax import Adamax
from .rmsprop import RMSProp
from .adadelta import Adadelta
from .sgd import SGD
from .momentum import Momentum

from . import lr_scheduler
from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \
LinearLrWarmup, ExponentialLR, MultiStepLR, StepLR, LambdaLR, ReduceLROnPlateau, CosineAnnealingLR
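As a quick sanity check of the import re-wiring above, the sketch below simply touches the symbols this diff exposes under paddle.optimizer. It assumes a Paddle build that includes this PR; nothing here is part of the diff itself.

    import paddle

    # 2.0-style optimizer classes now come from python/paddle/optimizer/*.py
    # instead of being re-exported from fluid.optimizer.
    print(paddle.optimizer.SGD)
    print(paddle.optimizer.Momentum)
    print(paddle.optimizer.Adadelta)
    print(paddle.optimizer.Adamax)
    print(paddle.optimizer.RMSProp)
    # The new scheduler module is re-exported here as well.
    print(paddle.optimizer.lr_scheduler.CosineAnnealingLR)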

10 changes: 5 additions & 5 deletions python/paddle/optimizer/adadelta.py
@@ -42,7 +42,7 @@ class Adadelta(Optimizer):
It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
Contributor: There is no default value in def __init__.

Contributor (author): Thanks, fixed.

epsilon (float): a small float number for numeric stability. Default 1.0e-6.
Contributor: float, optional

rho (float): a floating point value indicating the decay rate. Default 0.95.
Contributor: float, optional

parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \
parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
The default value is None in static mode, at this time all parameters will be updated.
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
@@ -72,18 +72,18 @@ class Adadelta(Optimizer):
loss = paddle.mean(out)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.99], dtype="float32")
adam = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
back = out.backward()
adam.step()
adam.clear_grad()
adadelta.step()
adadelta.clear_grad()

"""

_avg_squared_grad_acc_str = "_avg_squared_grad"
_avg_squared_update_acc_str = "_avg_squared_update"

def __init__(self,
learning_rate,
learning_rate=0.001,
epsilon=1.0e-6,
rho=0.95,
parameters=None,
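For reference, here is a NumPy sketch of the standard Adadelta update that the two accumulators named above (_avg_squared_grad and _avg_squared_update) correspond to. It follows Zeiler's formulation; the way learning_rate scales the step below is an assumption for illustration, since the C++ adadelta op itself is not shown in this diff.

    import numpy as np

    def adadelta_update(param, grad, avg_sq_grad, avg_sq_update,
                        learning_rate=0.001, rho=0.95, epsilon=1.0e-6):
        # Decayed average of squared gradients.
        avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad ** 2
        # RMS-scaled step using the running average of past squared updates.
        step = -np.sqrt(avg_sq_update + epsilon) / np.sqrt(avg_sq_grad + epsilon) * grad
        # Decayed average of squared updates.
        avg_sq_update = rho * avg_sq_update + (1 - rho) * step ** 2
        # Apply the step (learning_rate scaling is an illustration, see note above).
        param = param + learning_rate * step
        return param, avg_sq_grad, avg_sq_update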
150 changes: 150 additions & 0 deletions python/paddle/optimizer/momentum.py
@@ -0,0 +1,150 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable, name_scope

__all__ = ["Momentum"]


class Momentum(Optimizer):
"""

Simple Momentum optimizer with velocity state

This optimizer has a flag for Nesterov momentum.

The update equations are as follows:

.. math::

& velocity = mu * velocity + gradient

& if (use\_nesterov):

&\quad param = param - (gradient + mu * velocity) * learning\_rate

& else:

&\quad param = param - learning\_rate * velocity

Parameters:

learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
momentum (float): Momentum factor. The default value is 0.9.
Contributor: float, optional

parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
The default value is None in static mode, at this time all parameters will be updated.
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
It can be a float value as the coefficient of L2 regularization or \
:ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
the regularization setting here in optimizer will be ignored for this parameter. \
Otherwise, the regularization setting here in optimizer will take effect. \
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name` .

Examples:
.. code-block:: python

import paddle
import paddle.fluid as fluid
Contributor: no need to use fluid here

Contributor (author): Thanks for pointing it out, removed.

import numpy as np
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(inp)
out = linear(inp)
loss = paddle.mean(out)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.99], dtype="float32")
momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
back = out.backward()
momentum.step()
momentum.clear_grad()
"""
_velocity_acc_str = "velocity"

def __init__(self,
learning_rate=0.001,
momentum=0.9,
parameters=None,
use_nesterov=False,
weight_decay=None,
grad_clip=None,
name=None):
if learning_rate is None:
raise ValueError("learning_rate is not set")
if momentum is None:
raise ValueError("momentum is not set")
super(Momentum, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
weight_decay=weight_decay,
grad_clip=grad_clip,
name=name)
self.type = "momentum"
self._momentum = momentum
self._use_nesterov = bool(use_nesterov)

def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)

for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)

def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)

velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
lr = self._create_param_lr(param_and_grad)

if framework.in_dygraph_mode():
_, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1],
velocity_acc, lr, param_and_grad[0],
velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov)
return None

attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
inputs = {
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"Velocity": [velocity_acc],
"LearningRate": [lr]
}

outputs = {
"ParamOut": [param_and_grad[0]],
"VelocityOut": [velocity_acc]
}
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)

return momentum_op
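As a plain-Python restatement of the update equations in the Momentum docstring above (the real computation is done by the C++ momentum op emitted in _append_optimize_op), here is a small NumPy sketch plus one example step:

    import numpy as np

    def momentum_update(param, grad, velocity, lr, mu=0.9, use_nesterov=False):
        # velocity = mu * velocity + gradient
        velocity = mu * velocity + grad
        if use_nesterov:
            # param = param - (gradient + mu * velocity) * learning_rate
            param = param - (grad + mu * velocity) * lr
        else:
            # param = param - learning_rate * velocity
            param = param - lr * velocity
        return param, velocity

    # One step on a small array.
    p = np.zeros(3, dtype="float32")
    v = np.zeros(3, dtype="float32")
    g = np.array([0.1, 0.2, 0.3], dtype="float32")
    p, v = momentum_update(p, g, v, lr=0.01)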