Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

FTML optimizer implementation #9262

Merged
merged 3 commits into from
Jan 3, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,59 @@ def update_multi_precision(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)

@register
class FTML(Optimizer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we already have this

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have FTRL (Follow the regularized leader). This PR adds FTML (Follow the moving leader).

"""The FTML optimizer.

This class implements the optimizer described in
*FTML - Follow the Moving Leader in Deep Learning*,
available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.

Parameters
----------
beta1: float, 0 < beta1 < 1. Generally close to 0.5.
beta2: float, 0 < beta2 < 1. Generally close to 1.
epsilon: float >= 0. Small constant added to the square-root term to avoid division by zero.
"""
def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs):
    """Store the FTML hyper-parameters.

    Parameters
    ----------
    beta1 : float
        Decay rate used for the ``z`` accumulator and its bias correction.
    beta2 : float
        Decay rate used for the squared-gradient accumulator ``v``.
    epsilon : float
        Constant added to the square-root term when computing ``d``.
    **kwargs
        Forwarded unchanged to the base ``Optimizer`` constructor.
    """
    # Base-class setup runs first, exactly as before; the FTML-specific
    # settings are attached afterwards.
    super(FTML, self).__init__(**kwargs)
    self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon

def create_state(self, index, weight):
    """Allocate the per-parameter FTML accumulators.

    Returns a 3-tuple ``(d_0, v_0, z_0)`` of zero-filled arrays, each with
    the same shape, context and dtype as ``weight``. ``index`` is unused.
    """
    # Three independent buffers (not one shared array): d_0, v_0, z_0.
    return tuple(
        zeros(weight.shape, weight.context, dtype=weight.dtype)
        for _ in range(3)
    )

def update(self, index, weight, grad, state):
    """Apply one FTML step to ``weight`` in place.

    Parameters
    ----------
    index
        Key identifying the parameter; used to look up its learning rate,
        weight decay and update count.
    weight : NDArray
        Parameter to update. Overwritten in place.
    grad : NDArray
        Gradient with respect to ``weight``.
    state : tuple of NDArray
        The ``(d, v, z)`` accumulators produced by ``create_state``.
        All three are overwritten in place.
    """
    assert isinstance(weight, NDArray)
    assert isinstance(grad, NDArray)
    self._update_count(index)
    lr = self._get_lr(index)
    wd = self._get_wd(index)
    t = self._index_update_count[index]

    # Weight decay is folded into the gradient (as requested in review of
    # this change) so that it goes through the same adaptive scaling as the
    # loss gradient, instead of being applied as a separate
    # ``- lr * wd * weight`` term on the final weight assignment.
    grad = grad * self.rescale_grad + wd * weight
    if self.clip_gradient is not None:
        grad = clip(grad, -self.clip_gradient, self.clip_gradient)

    # get previous states
    prev_d, prev_v, prev_z = state

    # compute states
    # v_t: exponential moving average of the squared gradient.
    v_t = self.beta2 * prev_v + (1 - self.beta2) * square(grad)
    # d_t: bias-corrected per-coordinate scaling; epsilon keeps it non-zero.
    d_t = (1 - pow(self.beta1, t)) / lr * (sqrt(v_t / (1 - pow(self.beta2, t))) + self.epsilon)
    sigma_t = d_t - self.beta1 * prev_d
    z_t = self.beta1 * prev_z + (1 - self.beta1) * grad - sigma_t * weight

    # update weight
    weight[:] = - z_t / d_t

    # update states (written after the weight so that sigma_t and z_t above
    # are computed from the previous step's values)
    prev_d[:] = d_t
    prev_v[:] = v_t
    prev_z[:] = z_t

# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
Expand Down