Add LANS optimizer (apache#18620)
* add lans optimizer

* fix

* fix

Co-authored-by: Zheng <shzheng@a483e789dd93.ant.amazon.com>
2 people authored and ys2843 committed Jun 29, 2020
1 parent 2c16502 commit a988119
Showing 10 changed files with 1,302 additions and 20 deletions.
78 changes: 78 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -680,3 +680,81 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count,
learning_rates=lrs,
wds=wds,
**kwargs)


def multi_lans_update(weights, grads, mean, var, step_count,
lrs, wds, out=None, num_tensors=0, **kwargs):
"""Given a list of gradients, update weights, mean and variance of multiple tensors
following LANS Optimizer implementation.
Parameters
----------
weights : List of NDArrays containing the input weights of multiple tensors
grads : List of NDArrays containing input gradients
mean : List of NDArrays containing mean of multiple tensors to be updated
var : List of NDArrays containing variance of multiple tensors to be updated
step_count : List of scalars with the number of update step for each tensor
lrs : List of learning rates (one for each tensor)
wds : List of weight decays (one for each tensor)
    out : List of NDArrays where the updated weights will be stored
num_tensors : Number of NDArrays/tensors in the list
"""

if not num_tensors:
num_tensors = len(weights)
temp_list = _flatten_list(zip(weights, grads, mean, var))
return ndarray._internal._multi_lans_update(*temp_list,
out=out,
num_tensors=num_tensors,
step_count=step_count,
learning_rates=lrs,
wds=wds,
**kwargs)


def multi_mp_lans_update(weights, grads, mean, var, weights32, step_count,
lrs, wds, out=None, num_tensors=0, **kwargs):
"""Given a list of gradients, update weights, mean and variance of multiple tensors
following LANS Optimizer implementation, and using Mixed-Precision.
Parameters
----------
weights : List of NDArrays containing the input weights of multiple tensors
grads : List of NDArrays containing input gradients
mean : List of NDArrays containing mean of multiple tensors to be updated
var : List of NDArrays containing variance of multiple tensors to be updated
weights32 : Master copy of weights in FP32
step_count : List of scalars with the number of update step for each tensor
lrs : List of learning rates (one for each tensor)
wds : List of weight decays (one for each tensor)
    out : List of NDArrays where the updated weights will be stored
num_tensors : Number of NDArrays/tensors in the list
"""

if not num_tensors:
num_tensors = len(weights)
temp_list = _flatten_list(zip(weights, grads, mean, var, weights32))
return ndarray._internal._multi_mp_lans_update(*temp_list,
out=out,
num_tensors=num_tensors,
step_count=step_count,
learning_rates=lrs,
wds=wds,
**kwargs)
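
As an illustration of how these new contrib helpers are wired, here is a minimal usage sketch. It assumes an MXNet build that includes this commit; in practice both functions are invoked from `LANS.fused_step` in `python/mxnet/optimizer/lans.py` below, so the shapes, context choice, and hyperparameter values here are placeholders, and depending on the build the underlying fused kernel may only be registered for GPU contexts.

```python
# Minimal sketch, assuming an MXNet build that includes this commit.
# These helpers are normally called from LANS.fused_step; the fused kernel
# may only be registered for GPU contexts in some builds.
import mxnet as mx
from mxnet.ndarray.contrib import multi_lans_update

ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
num_tensors = 2
weights = [mx.nd.random.uniform(shape=(4, 4), ctx=ctx) for _ in range(num_tensors)]
grads = [mx.nd.random.uniform(shape=(4, 4), ctx=ctx) for _ in range(num_tensors)]
mean = [mx.nd.zeros((4, 4), ctx=ctx) for _ in range(num_tensors)]  # first-moment state
var = [mx.nd.zeros((4, 4), ctx=ctx) for _ in range(num_tensors)]   # second-moment state

# One LANS step over both tensors; updated weights are written into `out`.
multi_lans_update(weights, grads, mean, var,
                  out=weights,
                  step_count=[1] * num_tensors,  # per-tensor update counts
                  lrs=[1e-3] * num_tensors,      # per-tensor learning rates
                  wds=[1e-2] * num_tensors,      # per-tensor weight decays
                  beta1=0.9, beta2=0.999,
                  epsilon=1e-6, rescale_grad=1.0)
```

`multi_mp_lans_update` follows the same pattern with an additional `weights32` list holding the FP32 master copies, mirroring the LAMB counterparts above.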
6 changes: 4 additions & 2 deletions python/mxnet/optimizer/__init__.py
@@ -19,7 +19,7 @@
from . import (optimizer, contrib, updater, utils, sgd,
sgld, signum, dcasgd, nag, adagrad,
adadelta, adam, adamax, nadam, ftrl,
-                ftml, lars, lamb, rmsprop)
+                ftml, lars, lamb, rmsprop, lans)
# pylint: disable=wildcard-import
from .optimizer import *

@@ -57,7 +57,9 @@

from .rmsprop import *

from .lans import *

__all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \
+ signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \
+ adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \
-    + lars.__all__ + lamb.__all__ + rmsprop.__all__
+    + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__
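
With this import wiring in place (together with the `@register` decorator in `python/mxnet/optimizer/lans.py` below), the new optimizer becomes reachable by name through the optimizer registry. A small sketch, assuming a build that includes this commit; the hyperparameter values are arbitrary:

```python
# Sketch: create the newly registered optimizer by name.
import mxnet as mx

opt = mx.optimizer.create('lans', learning_rate=1e-3,
                          beta1=0.9, beta2=0.999, epsilon=1e-6)
print(type(opt).__name__)  # expected to print 'LANS'
```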
220 changes: 220 additions & 0 deletions python/mxnet/optimizer/lans.py
@@ -0,0 +1,220 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""LANS optimizer."""
from __future__ import absolute_import
import numpy
from ..ndarray import (zeros, clip, sqrt, where, square, ones_like,
maximum, minimum)
from ..ndarray.contrib import (multi_lans_update, multi_mp_lans_update)
from .optimizer import Optimizer, register

__all__ = ['LANS']


@register
class LANS(Optimizer):
"""LANS Optimizer.
Referenced from 'Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes'
(http://arxiv.org/abs/2006.13484)
Parameters
----------
learning_rate : float, default 0.001
The initial learning rate. If None, the optimization will use the
learning rate from ``lr_scheduler``. If not None, it will overwrite
the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
is also None, then it will be set to 0.01 by default.
beta1 : float, default 0.9
Exponential decay rate for the first moment estimates.
beta2 : float, default 0.999
Exponential decay rate for the second moment estimates.
epsilon : float, default 1e-6
Small value to avoid division by 0.
lower_bound : float, default None
Lower limit of norm of weight
upper_bound : float, default None
Upper limit of norm of weight
aggregate_num : int, default 4
Number of weights to be aggregated in a list.
They are passed to the optimizer for a single optimization step.
        By default, all the weights are aggregated.
use_fused_step : bool, default True
Whether or not to use fused kernels for optimizer.
When use_fused_step=False, step is called,
otherwise, fused_step is called.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True,
**kwargs):
assert aggregate_num <= 45,\
            'When use_fused_step is True, LANS only supports aggregate_num <= 45,' \
' and receives {}'.format(aggregate_num)
super(LANS, self).__init__(learning_rate=learning_rate,
aggregate_num=aggregate_num,
use_fused_step=use_fused_step,
**kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound

def create_state(self, index, weight):
stype = weight.stype
return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype), # mean
zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype)) # var

def step(self, indices, weights, grads, states):
"""Perform a fused optimization step using gradients and states.
Fused kernel is used for update.
Parameters
----------
indices : list of int
List of unique indices of the parameters into the individual learning rates
and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
and `set_wd_mult()`, respectively.
weights : list of NDArray
List of parameters to be updated.
grads : list of NDArray
List of gradients of the objective with respect to this parameter.
states : List of any obj
List of state returned by `create_state()`.
"""
for index, weight, grad, state in zip(indices, weights, grads, states):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

# preprocess grad
grad *= self.rescale_grad
grad /= grad.norm()
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)

# update mean, var
mean, var = state
mean[:] *= self.beta1
mean[:] += (1. - self.beta1) * grad
var[:] *= self.beta2
var[:] += (1. - self.beta2) * square(grad)

r1 = weight.norm()
if self.lower_bound is not None:
r1 = maximum(r1, self.lower_bound)
if self.upper_bound is not None:
r1 = minimum(r1, self.upper_bound)

# apply bias correction
coef1 = 1. - self.beta1 ** t
coef2 = 1. - self.beta2 ** t
mean_hat = mean / coef1
var_hat = var / coef2
sqrt(var_hat, out=var_hat)
var_hat += self.epsilon
mean_hat /= var_hat
mean_hat += wd * weight

g = mean_hat
r2 = g.norm()

# calculate lans_trust_ratio for first part
ratio_m = r1 / r2
            # nan_or_zero is NaN when ratio_m is NaN or zero, and 0 otherwise
            nan_or_zero = 1 - ratio_m / ratio_m
            # fall back to a trust ratio of 1 in those cases
            r_m = where(nan_or_zero, ones_like(ratio_m), ratio_m)

# update weight using first part of the estimator
g *= lr * r_m * self.beta1
weight[:] -= g

# calculate the second part of the estimator
mean_hat = grad / var_hat
mean_hat += wd * weight

g = mean_hat
r2 = g.norm()

# calculate lans_trust_ratio for second part
ratio_g = r1 / r2
            # nan_or_zero is NaN when ratio_g is NaN or zero, and 0 otherwise
            nan_or_zero = 1 - ratio_g / ratio_g
            # fall back to a trust ratio of 1 in those cases
            r_g = where(nan_or_zero, ones_like(ratio_g), ratio_g)

# update weight using second part of the estimator
g *= lr * r_g * (1 - self.beta1)
weight[:] -= g

def fused_step(self, indices, weights, grads, states):
"""Perform a fused optimization step using gradients and states.
Fused kernel is used for update.
Parameters
----------
indices : list of int
List of unique indices of the parameters into the individual learning rates
and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
and `set_wd_mult()`, respectively.
weights : list of NDArray
List of parameters to be updated.
grads : list of NDArray
List of gradients of the objective with respect to this parameter.
states : List of any obj
List of state returned by `create_state()`.
"""
self._update_count(indices)
lrs = self._get_lrs(indices)
wds = self._get_wds(indices)

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'rescale_grad': self.rescale_grad}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound

step_counts = []
for index in indices:
step_counts.append(self._index_update_count[index])

multi_precision = self.multi_precision and weights[0].dtype == numpy.float16

if not multi_precision:
mean, var = list(zip(*states))
multi_lans_update(weights, grads, mean, var,
out=weights, step_count=step_counts,
lrs=lrs, wds=wds, **kwargs)
else:
weights32, mean_var = list(zip(*states))
mean, var = list(zip(*mean_var))
multi_mp_lans_update(weights, grads,
mean, var, weights32,
out=weights, step_count=step_counts,
lrs=lrs, wds=wds, **kwargs)

def update_multi_precision(self, indices, weights, grads, states):
"""Override update_multi_precision.
"""
if self.use_fused_step:
self.update(indices, weights, grads, states)
else:
super(LANS, self).update_multi_precision(indices, weights, grads, states)
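
For context, a minimal end-to-end sketch of using the new optimizer through Gluon. It assumes a build that includes this commit; the model, data, and hyperparameters are placeholders and are not part of the change. `use_fused_step=False` keeps the sketch on the plain NDArray `step()` path in case the fused multi-tensor kernels are not registered for the current context.

```python
# Minimal Gluon sketch, assuming a build that includes this commit;
# the model, data, and hyperparameters are placeholders.
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(10)
net.initialize()

# 'lans' resolves to the LANS optimizer registered above.
trainer = gluon.Trainer(net.collect_params(), 'lans',
                        {'learning_rate': 1e-3, 'beta1': 0.9, 'beta2': 0.999,
                         'epsilon': 1e-6, 'use_fused_step': False})

x = mx.nd.random.uniform(shape=(32, 64))
y = mx.nd.random.uniform(shape=(32, 10))
loss_fn = gluon.loss.L2Loss()

with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=32)  # one LANS update across all parameters
```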