FTML optimizer implementation (apache#9262)
* ftml implementation

* c++ version and test

* merge WD into gradients
ZiyueHuang authored and zheng-da committed Jun 28, 2018
1 parent 7b5f26e commit f2b280b
Showing 5 changed files with 243 additions and 2 deletions.
51 changes: 50 additions & 1 deletion python/mxnet/optimizer.py
@@ -25,7 +25,7 @@
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update)
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
from .ndarray import _internal
from .ndarray import op
from .ndarray import sparse
@@ -529,6 +529,55 @@ def update_multi_precision(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)


@register
class FTML(Optimizer):
    """The FTML optimizer.

    This class implements the optimizer described in
    *FTML - Follow the Moving Leader in Deep Learning*,
    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    beta1 : float, optional
        0 < beta1 < 1. Generally close to 0.5.
    beta2 : float, optional
        0 < beta2 < 1. Generally close to 1.
    epsilon : float, optional
        Small value to avoid division by 0.
    """
    def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs):
        super(FTML, self).__init__(**kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

    def create_state(self, index, weight):
        return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # d_0
                zeros(weight.shape, weight.context, dtype=weight.dtype),  # v_0
                zeros(weight.shape, weight.context, dtype=weight.dtype))  # z_0

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]

        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                  'rescale_grad': self.rescale_grad, 't': t}
        if self.clip_gradient:
            kwargs['clip_grad'] = self.clip_gradient

        prev_d, prev_v, prev_z = state
        ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                    lr=lr, wd=wd, **kwargs)

# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
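As a usage sketch (not part of this commit): once the patch is applied, the @register decorator exposes the class under the name 'ftml', so it can be used like any other MXNet optimizer. The gluon.Trainer call and hyperparameter values below are illustrative only.

import mxnet as mx
from mxnet import autograd, gluon

# Sketch only: assumes a post-patch MXNet build; values are illustrative.
net = gluon.nn.Dense(1)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'ftml',
                        {'learning_rate': 0.0025, 'beta1': 0.6, 'beta2': 0.999})

x = mx.nd.random.uniform(shape=(8, 4))
y = mx.nd.random.uniform(shape=(8, 1))
with autograd.record():
    loss = ((net(x) - y) ** 2).mean()
loss.backward()
trainer.step(batch_size=8)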
89 changes: 89 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -485,6 +485,95 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
}
}


struct FTMLParam : public dmlc::Parameter<FTMLParam> {
  float lr;
  float beta1;
  float beta2;
  double epsilon;
  int t;
  float wd;
  float rescale_grad;
  float clip_grad;
  DMLC_DECLARE_PARAMETER(FTMLParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate.");
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.6f)
    .set_range(0.0f, 1.0f)
    .describe("Generally close to 0.5.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .set_range(0.0f, 1.0f)
    .describe("Generally close to 1.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-8f)
    .describe("Small constant to prevent division by zero.");
    DMLC_DECLARE_FIELD(t)
    .describe("Number of updates performed so far.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_grad)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_grad, clip_grad]. "
              "If clip_grad <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_grad), -clip_grad).");
  }
};


struct FTMLKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out, DType* weight, DType* grad,
                                  DType* d, DType* v, DType* z, const DType lr, const DType beta1,
                                  const DType beta2, const DType epsilon, const DType t,
                                  const DType wd, const DType rescale_grad, const DType clip_grad,
                                  const OpReqType req) {
    using namespace mshadow_op;
    // Rescale the gradient, fold in weight decay, then optionally clip.
    const DType grad_i = clip_grad >= 0.0f
        ? clip::Map(rescale_grad * grad[i] + wd * weight[i], clip_grad)
        : (rescale_grad * grad[i] + wd * weight[i]);
    // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    v[i] = beta2 * v[i] + (1 - beta2) * square::Map(grad_i);
    // d_t = (1 - beta1^t) / lr * (sqrt(v_t / (1 - beta2^t)) + epsilon)
    const DType d_t = (1 - power::Map(beta1, t)) / lr *
        (square_root::Map(v[i] / (1 - power::Map(beta2, t))) + epsilon);
    // z_t = beta1 * z_{t-1} + (1 - beta1) * g_t - (d_t - beta1 * d_{t-1}) * W_{t-1}
    z[i] = beta1 * z[i] + (1 - beta1) * grad_i - (d_t - beta1 * d[i]) * weight[i];
    d[i] = d_t;
    // W_t = -z_t / d_t
    KERNEL_ASSIGN(out[i], req, - z[i] / d_t);
  }
};


template<typename xpu>
inline void FTMLUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  FTMLParam param = nnvm::get<FTMLParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    DType* weight_data = inputs[0].dptr<DType>();
    DType* grad_data = inputs[1].dptr<DType>();
    DType* d_data = inputs[2].dptr<DType>();
    DType* v_data = inputs[3].dptr<DType>();
    DType* z_data = inputs[4].dptr<DType>();
    DType* out_data = outputs[0].dptr<DType>();
    Kernel<FTMLKernel, xpu>::Launch(s, inputs[0].shape_.Size(), out_data,
      weight_data, grad_data, d_data, v_data, z_data, static_cast<DType>(param.lr),
      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
      static_cast<DType>(param.epsilon), static_cast<DType>(param.t), static_cast<DType>(param.wd),
      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.clip_grad),
      req[0]);
  });
}

struct AdamParam : public dmlc::Parameter<AdamParam> {
float lr;
float beta1;
33 changes: 33 additions & 0 deletions src/operator/optimizer_op.cc
@@ -31,6 +31,7 @@ namespace op {

DMLC_REGISTER_PARAMETER(SGDParam);
DMLC_REGISTER_PARAMETER(SGDMomParam);
DMLC_REGISTER_PARAMETER(FTMLParam);
DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
@@ -143,6 +144,38 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(SGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(ftml_update)
.describe(R"code(The FTML optimizer described in
*FTML - Follow the Moving Leader in Deep Learning*,
available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

.. math::

 g_t = \nabla J(W_{t-1})\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 d_t = \frac{ 1 - \beta_1^t }{ \eta_t } (\sqrt{ \frac{ v_t }{ 1 - \beta_2^t } } + \epsilon)\\
 \sigma_t = d_t - \beta_1 d_{t-1}\\
 z_t = \beta_1 z_{t-1} + (1 - \beta_1) g_t - \sigma_t W_{t-1}\\
 W_t = - \frac{ z_t }{ d_t }

)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FTMLParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", FTMLUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
.add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
.add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
.add_arguments(FTMLParam::__FIELDS__());

NNVM_REGISTER_OP(adam_update)
MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
.describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization
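For reference, a standalone NumPy sketch (not part of this commit) of a single FTML step following the equations in the ftml_update description above. The function and variable names are hypothetical; weight decay and clipping are folded into the gradient exactly as FTMLKernel does.

import numpy as np

def ftml_step(weight, grad, d, v, z, t, lr=0.0025, beta1=0.6, beta2=0.999,
              eps=1e-8, wd=0.0, rescale_grad=1.0, clip_grad=-1.0):
    """One FTML update; returns the new (weight, d, v, z)."""
    g = rescale_grad * grad + wd * weight   # weight decay merged into the gradient
    if clip_grad >= 0:
        g = np.clip(g, -clip_grad, clip_grad)
    v = beta2 * v + (1 - beta2) * g ** 2
    d_t = (1 - beta1 ** t) / lr * (np.sqrt(v / (1 - beta2 ** t)) + eps)
    sigma = d_t - beta1 * d                 # sigma_t = d_t - beta1 * d_{t-1}
    z = beta1 * z + (1 - beta1) * g - sigma * weight
    return -z / d_t, d_t, v, z

# Toy example: two steps on a constant pseudo-gradient.
w, d, v, z = np.ones(3), np.zeros(3), np.zeros(3), np.zeros(3)
for t in (1, 2):
    w, d, v, z = ftml_step(w, 0.1 * np.ones(3), d, v, z, t)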
3 changes: 3 additions & 0 deletions src/operator/optimizer_op.cu
@@ -42,6 +42,9 @@ NNVM_REGISTER_OP(mp_sgd_update)
NNVM_REGISTER_OP(mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_SGDMomUpdate<gpu>);

NNVM_REGISTER_OP(ftml_update)
.set_attr<FCompute>("FCompute<gpu>", FTMLUpdate<gpu>);

NNVM_REGISTER_OP(adam_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdamUpdateEx<gpu>);
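A quick imperative smoke test (not part of this commit): assuming a CUDA build with this patch applied, the fused op should be callable through mx.nd on a GPU context. lr and t are the only required keyword arguments; the others take the defaults declared in FTMLParam.

import mxnet as mx

ctx = mx.gpu(0)  # assumes a CUDA-enabled build with at least one GPU
w = mx.nd.ones((3,), ctx=ctx)
g = 0.1 * mx.nd.ones((3,), ctx=ctx)
d = mx.nd.zeros((3,), ctx=ctx)
v = mx.nd.zeros((3,), ctx=ctx)
z = mx.nd.zeros((3,), ctx=ctx)
out = mx.nd.ftml_update(w, g, d, v, z, lr=0.1, t=1)
print(out.asnumpy())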
69 changes: 68 additions & 1 deletion tests/python/unittest/test_optimizer.py
@@ -333,6 +333,74 @@ def test_sparse_sgd():
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
w_stype='row_sparse', g_stype='row_sparse')


# FTML

class PyFTML(mx.optimizer.Optimizer):
    """python reference implementation of FTML"""
    def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs):
        super(PyFTML, self).__init__(**kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

    def create_state(self, index, weight):
        return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # d_0
                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype),  # v_0
                mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype))  # z_0

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, mx.nd.NDArray))
        assert(isinstance(grad, mx.nd.NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]

        grad = grad * self.rescale_grad + wd * weight
        if self.clip_gradient is not None:
            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
        # get previous states
        prev_d, prev_v, prev_z = state
        # compute states
        v_t = self.beta2 * prev_v + (1 - self.beta2) * mx.nd.square(grad)
        d_t = (1 - pow(self.beta1, t)) / lr * \
            (mx.nd.sqrt(v_t / (1 - pow(self.beta2, t))) + self.epsilon)
        sigma_t = d_t - self.beta1 * prev_d
        z_t = self.beta1 * prev_z + (1 - self.beta1) * grad - sigma_t * weight
        # update weight
        weight[:] = - z_t / d_t
        # update states
        prev_d[:] = d_t
        prev_v[:] = v_t
        prev_z[:] = z_t


def test_ftml():
    mx.random.seed(0)
    opt1 = PyFTML
    opt2 = mx.optimizer.FTML
    shape = (3, 4, 5)
    beta1_options = [{}, {'beta1': 0.5}, {'beta1': 0.7}]
    beta2_options = [{}, {'beta2': 0.8}, {'beta2': 0.9}]
    cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    for dtype in [np.float32]:
        for beta1_option in beta1_options:
            for beta2_option in beta2_options:
                for cg_option in cg_options:
                    for rg_option in rg_options:
                        for wd_option in wd_options:
                            kwarg = {}
                            kwarg.update(beta1_option)
                            kwarg.update(beta2_option)
                            kwarg.update(cg_option)
                            kwarg.update(rg_option)
                            kwarg.update(wd_option)
                            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)



# ADAM

class PyAdam(mx.optimizer.Optimizer):
@@ -675,4 +743,3 @@ def get_net(num_hidden, flatten=True):
if __name__ == '__main__':
import nose
nose.runmodule()
