This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add LANS optimizer #18620

Merged 3 commits on Jun 27, 2020
78 changes: 78 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -680,3 +680,81 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count,
learning_rates=lrs,
wds=wds,
**kwargs)


def multi_lans_update(weights, grads, mean, var, step_count,
lrs, wds, out=None, num_tensors=0, **kwargs):
"""Given a list of gradients, update weights, mean and variance of multiple tensors
following LANS Optimizer implementation.

Parameters
----------
weights : List of NDArrays containing the input weights of multiple tensors

grads : List of NDArrays containing input gradients

mean : List of NDArrays containing mean of multiple tensors to be updated

var : List of NDArrays containing variance of multiple tensors to be updated

step_count : List of scalars with the number of update steps for each tensor

lrs : List of learning rates (one for each tensor)

wds : List of weight decays (one for each tensor)

out : List of NDArrays where the updated weights will be stored

num_tensors : Number of NDArrays/tensors in the list
"""

if not num_tensors:
num_tensors = len(weights)
temp_list = _flatten_list(zip(weights, grads, mean, var))
return ndarray._internal._multi_lans_update(*temp_list,
out=out,
num_tensors=num_tensors,
step_count=step_count,
learning_rates=lrs,
wds=wds,
**kwargs)


def multi_mp_lans_update(weights, grads, mean, var, weights32, step_count,
lrs, wds, out=None, num_tensors=0, **kwargs):
"""Given a list of gradients, update weights, mean and variance of multiple tensors
following LANS Optimizer implementation, and using Mixed-Precision.

Parameters
----------
weights : List of NDArrays containing the input weights of multiple tensors

grads : List of NDArrays containing input gradients

mean : List of NDArrays containing mean of multiple tensors to be updated

var : List of NDArrays containing variance of multiple tensors to be updated

weights32 : List of NDArrays containing the FP32 master copies of the weights

step_count : List of scalars with the number of update steps for each tensor

lrs : List of learning rates (one for each tensor)

wds : List of weight decays (one for each tensor)

out : List of NDArrays where the updated weights will be stored

num_tensors : Number of NDArrays/tensors in the list
"""

if not num_tensors:
num_tensors = len(weights)
temp_list = _flatten_list(zip(weights, grads, mean, var, weights32))
return ndarray._internal._multi_mp_lans_update(*temp_list,
out=out,
num_tensors=num_tensors,
step_count=step_count,
learning_rates=lrs,
wds=wds,
**kwargs)
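
As a usage sketch (not part of the diff): these helpers can be called directly on lists of NDArrays. The two-tensor setup, shapes, and hyperparameter values below are illustrative assumptions, and the call only works where the fused _multi_lans_update operator added by this PR is available in the backend.

import mxnet as mx
from mxnet.ndarray.contrib import multi_lans_update

# two toy tensors with matching optimizer state
weights = [mx.nd.random.uniform(shape=(4, 4)) for _ in range(2)]
grads = [mx.nd.random.uniform(shape=(4, 4)) for _ in range(2)]
mean = [mx.nd.zeros((4, 4)) for _ in range(2)]  # first moments
var = [mx.nd.zeros((4, 4)) for _ in range(2)]   # second moments

multi_lans_update(weights, grads, mean, var,
                  out=weights,                  # update weights in place
                  step_count=[1, 1],            # per-tensor step numbers
                  lrs=[1e-3, 1e-3], wds=[0.01, 0.01],
                  beta1=0.9, beta2=0.999, epsilon=1e-6,
                  rescale_grad=1.0)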
6 changes: 4 additions & 2 deletions python/mxnet/optimizer/__init__.py
@@ -19,7 +19,7 @@
from . import (optimizer, contrib, updater, utils, sgd,
sgld, signum, dcasgd, nag, adagrad,
adadelta, adam, adamax, nadam, ftrl,
-               ftml, lars, lamb, rmsprop)
+               ftml, lars, lamb, rmsprop, lans)
# pylint: disable=wildcard-import
from .optimizer import *

@@ -57,7 +57,9 @@

from .rmsprop import *

from .lans import *

__all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \
+ signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \
+ adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \
-    + lars.__all__ + lamb.__all__ + rmsprop.__all__
+    + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__
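
Re-exporting lans here means the @register decorator in lans.py (below) has run by import time, so the optimizer can also be constructed by its lowercase registered name. A minimal sketch, assuming the standard mx.optimizer.create factory:

import mxnet as mx

# 'lans' is the name registered for the LANS class defined below
opt = mx.optimizer.create('lans', learning_rate=0.001)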
220 changes: 220 additions & 0 deletions python/mxnet/optimizer/lans.py
@@ -0,0 +1,220 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""LANS optimizer."""
from __future__ import absolute_import
import numpy
from ..ndarray import (zeros, clip, sqrt, where, square, ones_like,
maximum, minimum)
from ..ndarray.contrib import (multi_lans_update, multi_mp_lans_update)
from .optimizer import Optimizer, register

__all__ = ['LANS']


@register
class LANS(Optimizer):
"""LANS Optimizer.

Referenced from 'Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes'
(http://arxiv.org/abs/2006.13484)

Parameters
----------
learning_rate : float, default 0.001
The initial learning rate. If None, the optimization will use the
learning rate from ``lr_scheduler``. If not None, it will overwrite
the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
is also None, then it will be set to 0.01 by default.
beta1 : float, default 0.9
Exponential decay rate for the first moment estimates.
beta2 : float, default 0.999
Exponential decay rate for the second moment estimates.
epsilon : float, default 1e-6
Small value to avoid division by 0.
lower_bound : float, default None
Lower limit of the weight norm.
upper_bound : float, default None
Upper limit of the weight norm.
aggregate_num : int, default 4
Number of weights to be aggregated in a list.
They are passed to the optimizer for a single optimization step.
By default, all the weights are aggregated.
use_fused_step : bool, default True
Whether or not to use fused kernels for the optimizer.
When use_fused_step=False, step is called;
otherwise, fused_step is called.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True,
**kwargs):
assert aggregate_num <= 45,\
'When use_fused_step is True, LANS only supports aggregate_num <= 45,' \
' but got {}'.format(aggregate_num)
super(LANS, self).__init__(learning_rate=learning_rate,
aggregate_num=aggregate_num,
use_fused_step=use_fused_step,
**kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound

def create_state(self, index, weight):
stype = weight.stype
return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype), # mean
zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype)) # var

def step(self, indices, weights, grads, states):
"""Perform a fused optimization step using gradients and states.
Fused kernel is used for update.

Parameters
----------
indices : list of int
List of unique indices of the parameters into the individual learning rates
and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
and `set_wd_mult()`, respectively.
weights : list of NDArray
List of parameters to be updated.
grads : list of NDArray
List of gradients of the objective with respect to this parameter.
states : List of any obj
List of state returned by `create_state()`.
"""
for index, weight, grad, state in zip(indices, weights, grads, states):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

# preprocess grad
grad *= self.rescale_grad
grad /= grad.norm()
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)

# update mean, var
mean, var = state
mean[:] *= self.beta1
mean[:] += (1. - self.beta1) * grad
var[:] *= self.beta2
var[:] += (1. - self.beta2) * square(grad)

r1 = weight.norm()
if self.lower_bound is not None:
r1 = maximum(r1, self.lower_bound)
if self.upper_bound is not None:
r1 = minimum(r1, self.upper_bound)

# apply bias correction
coef1 = 1. - self.beta1 ** t
coef2 = 1. - self.beta2 ** t
mean_hat = mean / coef1
var_hat = var / coef2
sqrt(var_hat, out=var_hat)
var_hat += self.epsilon
mean_hat /= var_hat
mean_hat += wd * weight

g = mean_hat
r2 = g.norm()

# calculate lans_trust_ratio for first part
ratio_m = r1 / r2
# nan_or_zero is NaN (nonzero) when ratio_m is 0 or NaN, else 0;
# where() then falls back to a trust ratio of 1
nan_or_zero = 1 - ratio_m / ratio_m
r_m = where(nan_or_zero, ones_like(ratio_m), ratio_m)

# update weight using first part of the estimator
g *= lr * r_m * self.beta1
weight[:] -= g

# calculate the second part of the estimator
mean_hat = grad / var_hat
mean_hat += wd * weight

g = mean_hat
r2 = g.norm()

# calculate lans_trust_ratio for second part
ratio_g = r1 / r2
# nan_or_zero is NaN (nonzero) when ratio_g is 0 or NaN, else 0;
# where() then falls back to a trust ratio of 1
nan_or_zero = 1 - ratio_g / ratio_g
r_g = where(nan_or_zero, ones_like(ratio_g), ratio_g)

# update weight using second part of the estimator
g *= lr * r_g * (1 - self.beta1)
weight[:] -= g
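# Summary of the step above in equation form (cf. arXiv:2006.13484),
# with normalized gradient g = grad / ||grad||,
# m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t):
#   r = m_hat / (sqrt(v_hat) + eps),  c = g / (sqrt(v_hat) + eps)
#   w <- w - lr * beta1 * (||w|| / ||r + wd*w||) * (r + wd*w)
#   w <- w - lr * (1 - beta1) * (||w|| / ||c + wd*w||) * (c + wd*w)
# The two parts are applied sequentially, and ||w|| is optionally
# clipped to [lower_bound, upper_bound] before forming the trust ratios.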

def fused_step(self, indices, weights, grads, states):
"""Perform a fused optimization step using gradients and states.
Fused kernel is used for update.

Parameters
----------
indices : list of int
List of unique indices of the parameters into the individual learning rates
and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
and `set_wd_mult()`, respectively.
weights : list of NDArray
List of parameters to be updated.
grads : list of NDArray
List of gradients of the objective with respect to this parameter.
states : List of any obj
List of state returned by `create_state()`.
"""
self._update_count(indices)
lrs = self._get_lrs(indices)
wds = self._get_wds(indices)

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'rescale_grad': self.rescale_grad}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound

step_counts = []
for index in indices:
step_counts.append(self._index_update_count[index])

multi_precision = self.multi_precision and weights[0].dtype == numpy.float16

if not multi_precision:
mean, var = list(zip(*states))
multi_lans_update(weights, grads, mean, var,
out=weights, step_count=step_counts,
lrs=lrs, wds=wds, **kwargs)
else:
weights32, mean_var = list(zip(*states))
mean, var = list(zip(*mean_var))
multi_mp_lans_update(weights, grads,
mean, var, weights32,
out=weights, step_count=step_counts,
lrs=lrs, wds=wds, **kwargs)

def update_multi_precision(self, indices, weights, grads, states):
"""Override update_multi_precision.
"""
if self.use_fused_step:
self.update(indices, weights, grads, states)
else:
super(LANS, self).update_multi_precision(indices, weights, grads, states)
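
As a usage sketch (not part of the diff): once registered, the optimizer can be passed to a Gluon Trainer by name. The toy network, batch size, and hyperparameter values below are illustrative assumptions.

import mxnet as mx
from mxnet.gluon import nn

# toy network; layer and input sizes are arbitrary
net = nn.Dense(10)
net.initialize()

# 'lans' is the name registered by the @register decorator above
trainer = mx.gluon.Trainer(net.collect_params(), 'lans',
                           {'learning_rate': 0.001, 'wd': 0.01,
                            'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-6})

# one illustrative update
with mx.autograd.record():
    loss = net(mx.nd.random.uniform(shape=(8, 16))).sum()
loss.backward()
trainer.step(batch_size=8)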