fix momentum ops (#36452)
sneaxiy authored Oct 15, 2021 · 1 parent 8566cc9 · commit 4dda18a
Showing 2 changed files with 41 additions and 35 deletions.
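The change turns the regularization mode from a runtime member (regularization_flag_) of the dense momentum functors into a compile-time template parameter (kRegType), so the per-element ternary disappears from the kernel body, and the launch site fans out over the (Nesterov x regularization) combinations through a small macro. A minimal sketch of the before/after pattern, with invented names rather than the actual Paddle types:

// Sketch only; RegType, UpdateBefore and UpdateAfter are invented for
// illustration and are not the Paddle implementation.
enum class RegType { kNone, kL2Decay };

// Before: the mode is a runtime member, so every element evaluates the branch.
struct UpdateBefore {
  RegType flag;
  float coeff;
  float operator()(float g, float p) const {
    return flag == RegType::kL2Decay ? g + coeff * p : g;
  }
};

// After: the mode is a template parameter; the condition is a compile-time
// constant, so each instantiation compiles to a branch-free body.
template <RegType kReg>
struct UpdateAfter {
  float coeff;
  float operator()(float g, float p) const {
    if (kReg == RegType::kL2Decay) {
      return g + coeff * p;
    }
    return g;
  }
};

Each UpdateAfter<kReg> instantiation carries its decision in the type, so the hot loop does no per-element flag checks.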
67 changes: 35 additions & 32 deletions paddle/fluid/operators/optimizers/momentum_op.h
@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
   }
 };

-template <typename T, typename MT, typename UpdateMethod>
+template <typename T, typename MT, RegularizationType kRegType,
+          typename UpdateMethod>
 class DenseMomentumFunctor;

 // NOTE(dzh) for performance.
 // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
 // functor.
-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, UseNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
  private:
   const T* param_;
   const T* grad_;
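The NOTE(dzh) comment above records the existing trick this commit extends: rather than branching inside the device loop, each update method is a separate functor specialization selected by an empty tag type. A simplified sketch of that tag-dispatch pattern, with the two update rules taken from the specializations in this file and everything else invented:

// Sketch of the tag-dispatch pattern the NOTE describes; Update is a
// hypothetical name, the update formulas match the two functors below.
struct UseNesterov {};
struct NoNesterov {};

template <typename T, typename UpdateMethod>
struct Update;  // primary template: only the specializations are defined

template <typename T>
struct Update<T, UseNesterov> {
  T operator()(T p, T g, T v_out, T mu, T lr) const {
    return p - (g + v_out * mu) * lr;  // Nesterov lookahead step
  }
};

template <typename T>
struct Update<T, NoNesterov> {
  T operator()(T p, T g, T v_out, T mu, T lr) const {
    return p - lr * v_out;  // plain heavy-ball step
  }
};

Selecting the specialization happens once at the call site, so the GPU kernel itself never branches on the update method.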
@@ -193,15 +194,13 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;

  public:
   DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t num,
-                       const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
       : param_(param),
@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
         param_out_(param_out),
         velocity_out_(velocity_out),
         master_param_out_(master_param_out),
-        regularization_flag_(regularization_flag),
         regularization_coeff_(regularization_coeff) {}
   inline HOSTDEVICE void operator()(size_t i) const {
     // put memory access in register
@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];

-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }

     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - (grad + velocity_out * mu_) * lr;
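Because kRegType is a non-type template parameter, the condition above is a compile-time constant in every instantiation, and the compiler discards the untaken branch, so the device loop stays branch-free without the runtime ternary it replaces. In C++17 the same selection could be written explicitly with if constexpr; the file keeps a plain if, which folds just as well for a constant condition. A sketch, assuming an enum shaped like the file's RegularizationType:

// Sketch only: if constexpr makes the compile-time selection explicit.
// The enum here is assumed to match the one used by the real file.
enum class RegularizationType { kNONE, kL2DECAY };

template <RegularizationType kRegType, typename MT>
inline MT ApplyRegularization(MT grad, MT param, MT coeff) {
  if constexpr (kRegType == RegularizationType::kL2DECAY) {
    return grad + coeff * param;  // fold L2 weight decay into the gradient
  } else {
    return grad;  // kNONE: pass the gradient through
  }
}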
@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   }
 };

-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, NoNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
  private:
   const T* param_;
   const T* grad_;
@@ -254,15 +252,13 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;

  public:
   DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t num,
-                       const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
       : param_(param),
@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
         param_out_(param_out),
         velocity_out_(velocity_out),
         master_param_out_(master_param_out),
-        regularization_flag_(regularization_flag),
         regularization_coeff_(regularization_coeff) {}
   inline HOSTDEVICE void operator()(size_t i) const {
     // put memory access in register
@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];

-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }

     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - lr * velocity_out;
@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
           param->numel());
-      if (use_nesterov) {
-        DenseMomentumFunctor<T, MT, UseNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
-
+#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type)       \
+  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(            \
+      param->data<T>(), grad->data<T>(), velocity->data<MT>(),            \
+      learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,   \
+      param->numel(), regularization_coeff,                               \
+      param_out->mutable_data<T>(ctx.GetPlace()),                         \
+      velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);   \
+  for_range(functor);
+
+      if (use_nesterov) {
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kNONE);
+        }
       } else {
-        DenseMomentumFunctor<T, MT, NoNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kNONE);
+        }
       }
     }
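The macro writes the long launch sequence once, and the two runtime flags (use_nesterov and regularization_flag) fan out to four distinct template instantiations, each with a branch-free body. A self-contained sketch of that fan-out, with hypothetical names in place of the real functor and launch machinery:

// Hypothetical sketch: two runtime flags select among four compile-time
// instantiations, so each instantiated kernel body contains no branching.
#include <iostream>

enum class Reg { kNone, kL2 };
struct Nesterov {};
struct Plain {};

template <Reg kReg, typename Method>
void RunKernel() {
  // In the real operator this would run the element-wise functor.
  std::cout << "instantiation: reg=" << static_cast<int>(kReg) << '\n';
}

void Launch(bool use_nesterov, bool l2_decay) {
  if (use_nesterov) {
    if (l2_decay) RunKernel<Reg::kL2, Nesterov>();
    else          RunKernel<Reg::kNone, Nesterov>();
  } else {
    if (l2_decay) RunKernel<Reg::kL2, Plain>();
    else          RunKernel<Reg::kNone, Plain>();
  }
}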

9 changes: 6 additions & 3 deletions (Python unit test)
@@ -102,7 +102,7 @@ def run_momentum_op(params,
             'Param': p,
             'Grad': g,
             'Velocity': v,
-            'LearningRate': lr_var
+            'LearningRate': lr_var,
         }
         outputs = {'ParamOut': p, 'VelocityOut': v}
         if multi_precision:
@@ -115,7 +115,7 @@ def run_momentum_op(params,
             'Param': param_vars,
             'Grad': grad_vars,
             'Velocity': velocity_vars,
-            'LearningRate': lr_var
+            'LearningRate': lr_var,
         }
         outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
         if multi_precision:
Expand Down Expand Up @@ -176,7 +176,10 @@ def run_op(use_merged):
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out1, out2))
else:
self.assertTrue(np.allclose(out1, out2, atol=1e-7))

def get_places(self):
places = [paddle.CPUPlace()]
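The tightened test expects bitwise-identical outputs on CUDA, plausibly because the merged and unmerged paths now run the same per-element instantiations in the same order, while the CPU comparison keeps a small tolerance. The distinction matters because floating-point addition is not associative, so exact equality is only a reasonable expectation when both paths round identically; a minimal demonstration:

// Same operands, different grouping, different rounding:
// (a + b) + c evaluates to 1, but a + (b + c) evaluates to 0,
// because b + c rounds back to b in single precision.
#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));
  return 0;
}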
