From 5241b9a0f6c066a349c436ace3f7e310c5945395 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 10 Nov 2021 15:37:45 +0000 Subject: [PATCH] update --- .../operators/fused/cudnn_bn_add_relu_test.cc | 5 +- .../fused/cudnn_bn_stats_finalize.cu.h | 5 +- .../operators/fused/cudnn_fusion_helper.h | 5 +- .../fused/cudnn_scale_bias_add_relu.cu.h | 5 +- .../operators/optimizers/lars_momentum_op.cc | 144 +++++++++++++++--- .../operators/optimizers/lars_momentum_op.h | 76 +++++---- python/paddle/fluid/optimizer.py | 2 +- .../test_fleet_lars_meta_optimizer.py | 2 +- .../fluid/tests/unittests/test_momentum_op.py | 133 ++++++++++------ 9 files changed, 265 insertions(+), 112 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/optimizers/lars_momentum_op.h diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc index fb6079c2a55d6..36477bb09a2ad 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc +++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc @@ -1,8 +1,11 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h index b5e7a33b731a0..1b995e1313f47 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -1,8 +1,11 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/operators/fused/cudnn_fusion_helper.h b/paddle/fluid/operators/fused/cudnn_fusion_helper.h index 0f3d659057407..78655416ffbc2 100644 --- a/paddle/fluid/operators/fused/cudnn_fusion_helper.h +++ b/paddle/fluid/operators/fused/cudnn_fusion_helper.h @@ -1,8 +1,11 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h index 4268bd8278a12..d60a9174e1317 100644 --- a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -1,8 +1,11 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index 8f30dd5b2e68a..b0a85b7d5b92f 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -13,46 +13,158 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/optimizers/lars_momentum_op.h" -#include "paddle/fluid/operators/optimizers/momentum_op.h" namespace paddle { namespace operators { +class LarsMomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut", + "LarsMomentum"); + OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut", + "LarsMomentum"); + PADDLE_ENFORCE_EQ( + ctx->GetInputsVarType("Param").front(), + framework::proto::VarType::LOD_TENSOR, + platform::errors::InvalidArgument( + "The input var's type should be LoDTensor, but the received is %s", + ctx->GetInputsVarType("Param").front())); + + auto lr_dims = ctx->GetInputsDim("LearningRate"); + auto grad_dim = ctx->GetInputsDim("Grad"); + auto param_dim = ctx->GetInputsDim("Param"); + auto velocity_dim = ctx->GetInputsDim("Velocity"); + auto lars_weight_decays = + ctx->Attrs().Get>("lars_weight_decay"); + auto multi_precision = ctx->Attrs().Get("multi_precision"); + + PADDLE_ENFORCE_EQ( + param_dim.size(), grad_dim.size(), + platform::errors::InvalidArgument( + "Input(Param) and Input(Grad) of LarsMomentumOp should have " + "same quantity. But number of Param is [%d] and Grad is [%d].", + param_dim.size(), grad_dim.size())); + PADDLE_ENFORCE_EQ( + param_dim.size(), velocity_dim.size(), + platform::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp should " + "have same quantity. But number of Param is [%d] and Velocity " + "is [%d].", + param_dim.size(), velocity_dim.size())); + PADDLE_ENFORCE_EQ( + lars_weight_decays.size(), grad_dim.size(), + platform::errors::InvalidArgument( + "Attr(Lars_weight_decay) and " + "Input(Grad) of LarsMomentumOp should have same quantity. " + "But number of Lars_weight_decay is [%d] and Grad is [%d].", + lars_weight_decays.size(), grad_dim.size())); + + if (multi_precision) { + OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam", + "LarsMomentumMultiPrecision"); + OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output", + "MasterParamOut", "LarsMomentumMultiPrecision"); + } + for (size_t i = 0; i < lr_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1, + platform::errors::InvalidArgument( + "Learning_rate should be a scalar. But Received " + "LearningRate's dim [%s]", + framework::product(lr_dims[i]))); + } + + for (size_t i = 0; i < param_dim.size(); ++i) { + PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i], + framework::proto::VarType::LOD_TENSOR, + platform::errors::InvalidArgument( + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx->Inputs("Grad")[i].front(), + ctx->GetInputsVarType("Grad")[i])); + PADDLE_ENFORCE_EQ( + param_dim[i], grad_dim[i], + platform::errors::InvalidArgument( + "Input(Param) and Input(Grad) input of LarsMomentumOp shall " + "have same dimension. But Param`s dim is [%s] and Grad's dim " + "is [%s].", + param_dim[i], grad_dim[i])); + PADDLE_ENFORCE_EQ( + param_dim[i], velocity_dim[i], + platform::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp shall have " + "same dimension. But Param dim [%s] differs with Velocity dim " + "[%s].", + param_dim[i], velocity_dim[i])); + } + ctx->SetOutputsDim("ParamOut", param_dim); + ctx->SetOutputsDim("VelocityOut", param_dim); + if (ctx->HasOutputs("MasterParamOut")) { + ctx->SetOutputsDim("MasterParamOut", param_dim); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Param", "(LoDTensor, default LoDTensor) " - "Input parameter that has to be updated"); + "Input parameter that has to be updated") + .AsDuplicable(); AddInput("Grad", "(LoDTensor, default LoDTensor) " - "Input gradient of the parameter"); + "Input gradient of the parameter") + .AsDuplicable(); AddInput("Velocity", "(LoDTensor, default LoDTensor) " "Input velocity (corresponding to the parameter) " - "that has to be updated"); + "that has to be updated") + .AsDuplicable(); AddInput("LearningRate", "(LoDTensor, default LoDTensor) " - "Input learning rate"); - AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); - + "Input learning rate") + .AsDuplicable(); + AddInput("MasterParam", "FP32 master weight for AMP.") + .AsDuplicable() + .AsDispensable(); AddOutput("ParamOut", "(LoDTensor) This output is updated parameter. " - "It shared memory with Input(Param)."); + "It shared memory with Input(Param).") + .AsDuplicable(); AddOutput("VelocityOut", "(LoDTensor) This output is updated velocity. " - "It shared memory with Input(Velocity)."); + "It shared memory with Input(Velocity).") + .AsDuplicable(); AddOutput("MasterParamOut", "The updated FP32 master weight for AMP. " "It shared memory with Input(MasterParam).") + .AsDuplicable() .AsDispensable(); - AddAttr("mu", "(float) Momentum coefficient"); AddAttr("lars_coeff", "(float, default 0.001) LARS coefficient.") .SetDefault(0.001); - AddAttr("lars_weight_decay", - "(float, default 0.0005) LARS weight decay") - .SetDefault(0.0005); + AddAttr>( + "lars_weight_decay", + "(std::vector, default 0.0005) LARS weight decay params") + .SetDefault({0.0005}); AddAttr("epsilon", "(float, default 0.0) epsilon to avoid Division by Zero.") .SetDefault(0.0); @@ -68,10 +180,8 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Lars Momentum Optimizer. - This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each weight using a local learning rate: - $$ local\_lr = \eta * \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ @@ -79,10 +189,8 @@ velocity = mu * velocity + local\_lr * (grad + \beta * param) \\ param = param - velocity. \\ $$ - Note that we use lars_weight_decay here to decay weights, you may need not to use L2 regularizers in case of using LARS. - )DOC"); } }; @@ -96,7 +204,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { namespace ops = paddle::operators; REGISTER_OPERATOR( - lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker, + lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::LarsMomentumOpVarTypeInference); diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h old mode 100755 new mode 100644 index 55775bc08fb5e..ff836fdca1cc4 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,54 +23,48 @@ template class LarsMomentumOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto velocity_out = ctx.Output("VelocityOut"); - auto param = ctx.Input("Param"); - auto velocity = ctx.Input("Velocity"); - auto learning_rate = ctx.Input("LearningRate"); - auto* grad_var = ctx.InputVar("Grad"); - // only support dense for now. - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Grad").front(), - framework::ToTypeName(grad_var->Type()))); - auto grad = ctx.Input("Grad"); - - param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); - + auto param_out = ctx.MultiOutput("ParamOut"); + auto velocity_out = ctx.MultiOutput("VelocityOut"); + auto param = ctx.MultiInput("Param"); + auto velocity = ctx.MultiInput("Velocity"); + auto learning_rate = ctx.MultiInput("LearningRate"); + auto grad = ctx.MultiInput("Grad"); + auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); T mu = static_cast(ctx.Attr("mu")); T lars_coeff = ctx.Attr("lars_coeff"); - T lars_weight_decay = ctx.Attr("lars_weight_decay"); T epsilon = ctx.Attr("epsilon"); - auto p_out = framework::EigenVector::Flatten(*param_out); - auto v_out = framework::EigenVector::Flatten(*velocity_out); + int op_num = param.size(); + for (int i = 0; i < op_num; ++i) { + auto* lr = learning_rate[i]->data(); + T lars_weight_decay = weight_decay_arr[i]; + param_out[i]->mutable_data(ctx.GetPlace()); + velocity_out[i]->mutable_data(ctx.GetPlace()); - auto p = framework::EigenVector::Flatten(*param); - auto v = framework::EigenVector::Flatten(*velocity); - auto g = framework::EigenVector::Flatten(*grad); - auto* lr = learning_rate->data(); + auto p_out = framework::EigenVector::Flatten(*(param_out[i])); + auto v_out = framework::EigenVector::Flatten(*(velocity_out[i])); + auto p = framework::EigenVector::Flatten(*(param[i])); + auto v = framework::EigenVector::Flatten(*(velocity[i])); + auto g = framework::EigenVector::Flatten(*(grad[i])); - framework::Tensor p_norm_t, g_norm_t; - p_norm_t.Resize({1}); - g_norm_t.Resize({1}); - p_norm_t.mutable_data(ctx.GetPlace()); - g_norm_t.mutable_data(ctx.GetPlace()); - auto ep_norm = framework::EigenScalar::From(p_norm_t); - auto eg_norm = framework::EigenScalar::From(g_norm_t); + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + p_norm_t.mutable_data(ctx.GetPlace()); + g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + ep_norm = p.square().sum().sqrt(); + eg_norm = g.square().sum().sqrt(); - ep_norm = p.square().sum().sqrt(); - eg_norm = g.square().sum().sqrt(); - T local_lr = lr[0]; - if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { - local_lr = lr[0] * lars_coeff * ep_norm(0) / - (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); + T local_lr = lr[0]; + if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); + } + v_out = v * mu + local_lr * (g + lars_weight_decay * p); + p_out = p - v_out; } - v_out = v * mu + local_lr * (g + lars_weight_decay * p); - p_out = p - v_out; } }; diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 7aea94f6c523e..719bc609785e3 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2064,7 +2064,7 @@ def _append_optimize_op(self, block, param_and_grad): attrs = { "mu": self._momentum, "lars_coeff": self._lars_coeff, - "lars_weight_decay": _lars_weight_decay, + "lars_weight_decay": [_lars_weight_decay], "multi_precision": find_master, "epsilon": self._epsilon, "rescale_grad": self._rescale_grad diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py index e4cc3682d1a24..bee6acf732460 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py @@ -103,7 +103,7 @@ def test_lars_exclude_fn(self): 'op_role_var')[0] or ".b" in op.attr('op_role_var')[0]) ] for op in ops_without_wd: - self.assertEqual(op.attr('lars_weight_decay'), 0) + self.assertEqual(op.attr('lars_weight_decay')[0], 0) def test_lars_apply_with_amp(self): role = role_maker.PaddleCloudRoleMaker(is_collective=True) diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index b42de853c00d5..34e057a5a8a61 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -138,50 +138,70 @@ def test_check_output(self): "core is not compiled with CUDA") class TestLarsMomentumOpWithMP(OpTest): def setUp(self): + self.config() self.op_type = "lars_momentum" - - master_param = np.random.random((123, 321)).astype("float32") - param = master_param.astype("float16") - grad = np.random.random((123, 321)).astype("float16") - velocity = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.001]).astype("float32") mu = 0.0001 lars_coeff = 0.001 lars_weight_decay = 0.0005 rescale_grad = 1.0 + params = [] + grads = [] + velocitys = [] + learning_rates = [] + master_params = [] + param_outs = [] + velocity_outs = [] + master_param_outs = [] + for i in range(self.params_num): + master_param = np.random.random((123, 321)).astype("float32") + param = master_param.astype("float16") + grad = np.random.random((123, 321)).astype("float16") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + + fp32_grad = grad.astype("float32") + pnorm = np.sqrt(np.square(master_param).sum()) + gnorm = np.sqrt(np.square(fp32_grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * pnorm) + fp32_grad = fp32_grad * rescale_grad + velocity_out = mu * velocity + local_lr * ( + fp32_grad + lars_weight_decay * master_param) + p_new = master_param - velocity_out + param_out = p_new.astype("float16") + master_param_out = p_new + + params.append(("SubParam_" + str(i), param)) + grads.append(("SubGrad_" + str(i), grad)) + velocitys.append(("SubVelocity_" + str(i), velocity)) + learning_rates.append(("SubLearning_rate_" + str(i), learning_rate)) + velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out)) + param_outs.append(("SubParam_out_" + str(i), param_out)) + master_params.append(("SubMasterParam_" + str(i), master_param)) + master_param_outs.append( + ("SubMasterParamOut_" + str(i), master_param_out)) + self.inputs = { - 'Param': param, - 'Grad': grad, - 'Velocity': velocity, - 'LearningRate': learning_rate, - 'MasterParam': master_param, + 'Param': params, + 'Grad': grads, + 'Velocity': velocitys, + 'LearningRate': learning_rates, + 'MasterParam': master_params, } self.attrs = { 'mu': mu, 'lars_coeff': lars_coeff, - 'lars_weight_decay': lars_weight_decay, + 'lars_weight_decay': [lars_weight_decay], 'multi_precision': True, 'rescale_grad': rescale_grad } - fp32_grad = grad.astype("float32") - pnorm = np.sqrt(np.square(master_param).sum()) - gnorm = np.sqrt(np.square(fp32_grad).sum()) - local_lr = learning_rate * lars_coeff * pnorm / ( - gnorm + lars_weight_decay * pnorm) - fp32_grad = fp32_grad * rescale_grad - velocity_out = mu * velocity + local_lr * (fp32_grad + lars_weight_decay - * master_param) - p_new = master_param - velocity_out - param_out = p_new.astype("float16") - master_param_out = p_new - self.outputs = { - 'ParamOut': param_out, - 'VelocityOut': velocity_out, - 'MasterParamOut': master_param_out + 'ParamOut': param_outs, + 'VelocityOut': velocity_outs, + 'MasterParamOut': master_param_outs } def test_check_output(self): @@ -191,46 +211,65 @@ def test_check_output(self): if core.is_float16_supported(place): self.check_output_with_place(place) + def config(self): + self.params_num = 1 + class TestLarsMomentumOp(OpTest): def setUp(self): + self.config() self.op_type = "lars_momentum" - - param = np.random.random((123, 321)).astype("float32") - grad = np.random.random((123, 321)).astype("float32") - velocity = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.001]).astype("float32") mu = 0.0001 lars_coeff = 0.001 lars_weight_decay = 0.0005 + params = [] + grads = [] + velocitys = [] + param_outs = [] + velocity_outs = [] + learning_rates = [] + for i in range(self.params_num): + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + pnorm = np.sqrt(np.square(param).sum()) + gnorm = np.sqrt(np.square(grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * param) + velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay + * param) + param_out = param - velocity_out + + params.append(("SubParam_" + str(i), param)) + grads.append(("SubGrad_" + str(i), grad)) + velocitys.append(("SubVelocity_" + str(i), velocity)) + learning_rates.append(("SubLearning_rate_" + str(i), learning_rate)) + velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out)) + param_outs.append(("SubParam_out_" + str(i), param_out)) + self.inputs = { - 'Param': param, - 'Grad': grad, - 'Velocity': velocity, - 'LearningRate': learning_rate + 'Param': params, + 'Grad': grads, + 'Velocity': velocitys, + 'LearningRate': learning_rates } self.attrs = { 'mu': mu, 'lars_coeff': lars_coeff, - 'lars_weight_decay': lars_weight_decay + 'lars_weight_decay': [lars_weight_decay] } - - pnorm = np.sqrt(np.square(param).sum()) - gnorm = np.sqrt(np.square(grad).sum()) - local_lr = learning_rate * lars_coeff * pnorm / ( - gnorm + lars_weight_decay * param) - velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * - param) - param_out = param - velocity_out - - self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs} def test_check_output(self): paddle.enable_static() self.check_output() + def config(self): + self.params_num = 1 + class TestSparseMomentumOp(unittest.TestCase): def setUp(self):