From ea8a692725ac227df6bdc393feb6f0ed4eb32ed9 Mon Sep 17 00:00:00 2001 From: Zheng Date: Wed, 24 Jun 2020 18:31:08 -0700 Subject: [PATCH 1/3] add lans optimizer --- python/mxnet/ndarray/contrib.py | 78 +++++ python/mxnet/optimizer/__init__.py | 6 +- python/mxnet/optimizer/lans.py | 220 ++++++++++++++ src/operator/contrib/multi_lans-inl.h | 385 ++++++++++++++++++++++++ src/operator/contrib/multi_lans.cc | 267 ++++++++++++++++ src/operator/contrib/multi_lans.cu | 287 ++++++++++++++++++ src/operator/contrib/multi_sum_sq-inl.h | 13 +- src/operator/contrib/multi_sum_sq.cc | 20 +- src/operator/contrib/multi_sum_sq.cu | 16 +- tests/python/unittest/test_optimizer.py | 30 ++ 10 files changed, 1302 insertions(+), 20 deletions(-) create mode 100644 python/mxnet/optimizer/lans.py create mode 100644 src/operator/contrib/multi_lans-inl.h create mode 100644 src/operator/contrib/multi_lans.cc create mode 100644 src/operator/contrib/multi_lans.cu diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 2ff422f29497..b1700bbe8c52 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -680,3 +680,81 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count, learning_rates=lrs, wds=wds, **kwargs) + + +def multi_lans_update(weights, grads, mean, var, step_count, + lrs, wds, out=None, num_tensors=0, **kwargs): + """Given a list of gradients, update weights, mean and variance of multiple tensors + following LANS Optimizer implementation. + + Parameters + ---------- + weights : List of NDArrays containing the input weights of multiple tensors + + grads : List of NDArrays containing input gradients + + mean : List of NDArrays containing mean of multiple tensors to be updated + + var : List of NDArrays containing variance of multiple tensors to be updated + + step_count : List of scalars with the number of update step for each tensor + + lrs : List of learning rates (one for each tensor) + + wds : List of weight decays (one for each tensor) + + out: List of NDArrays where the updated weights will be stored + + num_tensors : Number of NDArrays/tensors in the list + """ + + if not num_tensors: + num_tensors = len(weights) + temp_list = _flatten_list(zip(weights, grads, mean, var)) + return ndarray._internal._multi_lans_update(*temp_list, + out=out, + num_tensors=num_tensors, + step_count=step_count, + learning_rates=lrs, + wds=wds, + **kwargs) + + +def multi_mp_lans_update(weights, grads, mean, var, weights32, step_count, + lrs, wds, out=None, num_tensors=0, **kwargs): + """Given a list of gradients, update weights, mean and variance of multiple tensors + following LANS Optimizer implementation, and using Mixed-Precision. 
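For reference, the wrappers above hand the per-tensor NDArrays to the internal operator in interleaved order. A small plain-Python sketch of the layout produced by ``_flatten_list(zip(...))``, using toy placeholder values rather than real NDArrays (the helper is assumed to be a simple list flattener, judging from the call site):

    # Interleaved input layout expected by _multi_lans_update:
    # [w0, g0, m0, v0, w1, g1, m1, v1, ...]
    def flatten_list(nested):
        # assumed to mirror mxnet.ndarray.contrib._flatten_list
        return [item for sublist in nested for item in sublist]

    weights, grads = ['w0', 'w1'], ['g0', 'g1']
    mean, var = ['m0', 'm1'], ['v0', 'v1']
    print(flatten_list(zip(weights, grads, mean, var)))
    # ['w0', 'g0', 'm0', 'v0', 'w1', 'g1', 'm1', 'v1']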
+ + Parameters + ---------- + weights : List of NDArrays containing the input weights of multiple tensors + + grads : List of NDArrays containing input gradients + + mean : List of NDArrays containing mean of multiple tensors to be updated + + var : List of NDArrays containing variance of multiple tensors to be updated + + weights32 : Master copy of weights in FP32 + + step_count : List of scalars with the number of update step for each tensor + + lrs : List of learning rates (one for each tensor) + + wds : List of weight decays (one for each tensor) + + out: List of NDArrays where the updated weights will be stored + + num_tensors : Number of NDArrays/tensors in the list + """ + + if not num_tensors: + num_tensors = len(weights) + temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + return ndarray._internal._multi_mp_lans_update(*temp_list, + out=out, + num_tensors=num_tensors, + step_count=step_count, + learning_rates=lrs, + wds=wds, + **kwargs) diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py index 89b37de1c873..4f11d78a56e4 100644 --- a/python/mxnet/optimizer/__init__.py +++ b/python/mxnet/optimizer/__init__.py @@ -19,7 +19,7 @@ from . import (optimizer, contrib, updater, utils, sgd, sgld, signum, dcasgd, nag, adagrad, adadelta, adam, adamax, nadam, ftrl, - ftml, lars, lamb, rmsprop) + ftml, lars, lamb, rmsprop, lans) # pylint: disable=wildcard-import from .optimizer import * @@ -57,7 +57,9 @@ from .rmsprop import * +from .lans import * + __all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \ + signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \ + adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \ - + lars.__all__ + lamb.__all__ + rmsprop.__all__ + + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__ diff --git a/python/mxnet/optimizer/lans.py b/python/mxnet/optimizer/lans.py new file mode 100644 index 000000000000..54ea4d33e6ca --- /dev/null +++ b/python/mxnet/optimizer/lans.py @@ -0,0 +1,220 @@ +# coding: utf-8 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""LANS optimizer.""" +from __future__ import absolute_import +import numpy +from ..ndarray import (zeros, clip, sqrt, where, square, ones_like, + maximum, minimum) +from ..ndarray.contrib import (multi_lans_update, multi_mp_lans_update) +from .optimizer import Optimizer, register + +__all__ = ['LANS'] + + +@register +class LANS(Optimizer): + """LANS Optimizer. + + Referenced from 'Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes' + (http://arxiv.org/abs/2006.13484) + + Parameters + ---------- + learning_rate : float, default 0.001 + The initial learning rate. 
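As a quick usage sketch (assuming the ``@register`` decorator above exposes the class under the lowercased name ``'lans'``, as with the other optimizers), the optimizer can be constructed directly or through a Gluon Trainer; hyperparameter names mirror the constructor documented below:

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'lans',
                            {'learning_rate': 0.001, 'wd': 0.01,
                             'aggregate_num': 4, 'use_fused_step': True})
    # or build the optimizer object explicitly
    opt = mx.optimizer.LANS(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6)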
If None, the optimization will use the + learning rate from ``lr_scheduler``. If not None, it will overwrite + the learning rate in ``lr_scheduler``. If None and ``lr_scheduler`` + is also None, then it will be set to 0.01 by default. + beta1 : float, default 0.9 + Exponential decay rate for the first moment estimates. + beta2 : float, default 0.999 + Exponential decay rate for the second moment estimates. + epsilon : float, default 1e-6 + Small value to avoid division by 0. + lower_bound : float, default None + Lower limit of norm of weight. + upper_bound : float, default None + Upper limit of norm of weight. + aggregate_num : int, default 4 + Number of weights to be aggregated in a list. + They are passed to the optimizer for a single optimization step. + By default, all the weights are aggregated. + use_fused_step : bool, default True + Whether or not to use fused kernels for optimizer. + When use_fused_step=False, step is called, + otherwise, fused_step is called. + """ + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, + lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True, + **kwargs): + assert aggregate_num <= 45,\ + 'When use_fused_step is True, LANS only supports aggregate_num <= 45,' \ + ' and receives {}'.format(aggregate_num) + super(LANS, self).__init__(learning_rate=learning_rate, + aggregate_num=aggregate_num, + use_fused_step=use_fused_step, + **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def create_state(self, index, weight): + stype = weight.stype + return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype), # mean + zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype)) # var + + def step(self, indices, weights, grads, states): + """Perform an optimization step using gradients and states. + + Parameters + ---------- + indices : list of int + List of unique indices of the parameters into the individual learning rates + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` + and `set_wd_mult()`, respectively. + weights : list of NDArray + List of parameters to be updated. + grads : list of NDArray + List of gradients of the objective with respect to this parameter. + states : List of any obj + List of state returned by `create_state()`. + """ + for index, weight, grad, state in zip(indices, weights, grads, states): + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + t = self._index_update_count[index] + + # preprocess grad + grad *= self.rescale_grad + grad /= grad.norm() + if self.clip_gradient is not None: + grad = clip(grad, -self.clip_gradient, self.clip_gradient) + + # update mean, var + mean, var = state + mean[:] *= self.beta1 + mean[:] += (1. - self.beta1) * grad + var[:] *= self.beta2 + var[:] += (1. - self.beta2) * square(grad) + + r1 = weight.norm() + if self.lower_bound is not None: + r1 = maximum(r1, self.lower_bound) + if self.upper_bound is not None: + r1 = minimum(r1, self.upper_bound) + + # apply bias correction + coef1 = 1. - self.beta1 ** t + coef2 = 1.
- self.beta2 ** t + mean_hat = mean / coef1 + var_hat = var / coef2 + sqrt(var_hat, out=var_hat) + var_hat += self.epsilon + mean_hat /= var_hat + mean_hat += wd * weight + + g = mean_hat + r2 = g.norm() + + # calculate lans_trust_ratio for first part + ratio_m = r1 / r2 + # becomes NaN if ratio == NaN or 0, otherwise 0 + nan_or_zero = 1 - ratio_m / ratio_m + r_m = where(nan_or_zero, ones_like(ratio_m), ratio_m) + + # update weight using first part of the estimator + g *= lr * r_m * self.beta1 + weight[:] -= g + + # calculate the second part of the estimator + mean_hat = grad / var_hat + mean_hat += wd * weight + + g = mean_hat + r2 = g.norm() + + # calculate lans_trust_ratio for second part + ratio_g = r1 / r2 + # becomes NaN if ratio == NaN or 0, otherwise 0 + nan_or_zero = 1 - ratio_g / ratio_g + r_g = where(nan_or_zero, ones_like(ratio_g), ratio_g) + + # update weight using second part of the estimator + g *= lr * r_g * (1 - self.beta1) + weight[:] -= g + + def fused_step(self, indices, weights, grads, states): + """Perform a fused optimization step using gradients and states. + Fused kernel is used for update. + + Parameters + ---------- + indices : list of int + List of unique indices of the parameters into the individual learning rates + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` + and `set_wd_mult()`, respectively. + weights : list of NDArray + List of parameters to be updated. + grads : list of NDArray + List of gradients of the objective with respect to this parameter. + states : List of any obj + List of state returned by `create_state()`. + """ + self._update_count(indices) + lrs = self._get_lrs(indices) + wds = self._get_wds(indices) + + kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, + 'rescale_grad': self.rescale_grad} + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient + if self.lower_bound: + kwargs['lower_bound'] = self.lower_bound + if self.upper_bound: + kwargs['upper_bound'] = self.upper_bound + + step_counts = [] + for index in indices: + step_counts.append(self._index_update_count[index]) + + multi_precision = self.multi_precision and weights[0].dtype == numpy.float16 + + if not multi_precision: + mean, var = list(zip(*states)) + multi_lans_update(weights, grads, mean, var, + out=weights, step_count=step_counts, + lrs=lrs, wds=wds, **kwargs) + else: + weights32, mean_var = list(zip(*states)) + mean, var = list(zip(*mean_var)) + multi_mp_lans_update(weights, grads, + mean, var, weights32, + out=weights, step_count=step_counts, + lrs=lrs, wds=wds, **kwargs) + + def update_multi_precision(self, indices, weights, grads, states): + """Override update_multi_precision. + """ + if self.use_fused_step: + self.update(indices, weights, grads, states) + else: + super(LANS, self).update_multi_precision(indices, weights, grads, states) diff --git a/src/operator/contrib/multi_lans-inl.h b/src/operator/contrib/multi_lans-inl.h new file mode 100644 index 000000000000..958b6bd9f25e --- /dev/null +++ b/src/operator/contrib/multi_lans-inl.h @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
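To make the two-part update in ``step()`` above concrete, here is a simplified NumPy sketch of the per-tensor computation (``clip_gradient`` handling omitted; the helper name and signature are illustrative, not part of the patch):

    import numpy as np

    def lans_step(weight, grad, mean, var, t, lr, wd, beta1=0.9, beta2=0.999,
                  epsilon=1e-6, rescale_grad=1.0, lower_bound=None, upper_bound=None):
        # LANS normalizes the gradient before the Adam-style moment updates
        grad = grad * rescale_grad
        grad = grad / np.linalg.norm(grad)
        mean[:] = beta1 * mean + (1. - beta1) * grad
        var[:] = beta2 * var + (1. - beta2) * grad * grad

        r1 = np.linalg.norm(weight)
        if lower_bound is not None:
            r1 = max(r1, lower_bound)
        if upper_bound is not None:
            r1 = min(r1, upper_bound)

        denom = np.sqrt(var / (1. - beta2 ** t)) + epsilon

        # first part: bias-corrected momentum direction, scaled by its trust ratio
        update_m = mean / (1. - beta1 ** t) / denom + wd * weight
        r2 = np.linalg.norm(update_m)
        ratio_m = r1 / r2 if r1 > 0 and r2 > 0 else 1.0
        weight[:] -= lr * beta1 * ratio_m * update_m

        # second part: normalized-gradient direction; as in step() above, the weight
        # decay term here sees the weight already updated by the first part
        update_g = grad / denom + wd * weight
        r2 = np.linalg.norm(update_g)
        ratio_g = r1 / r2 if r1 > 0 and r2 > 0 else 1.0
        weight[:] -= lr * (1. - beta1) * ratio_g * update_g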
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file multi_lans-inl.h + * \brief multi-tensor LANS optimizer + * \author Shuai Zheng + */ +#ifndef MXNET_OPERATOR_CONTRIB_MULTI_LANS_INL_H_ +#define MXNET_OPERATOR_CONTRIB_MULTI_LANS_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../tensor/init_op.h" +#include "../tensor/util/tensor_util-inl.h" +#include "multi_sum_sq-inl.h" + +namespace mxnet { +namespace op { + +namespace multilans { +enum MultiLANSUpdateResource {kTempSpace}; +} // namespace multilans + +struct MultiLANSParam : public dmlc::Parameter { + mxnet::Tuple learning_rates; + mxnet::Tuple wds; + float beta1; + float beta2; + float epsilon; + float rescale_grad; + float lower_bound; + float upper_bound; + float clip_gradient; + int num_tensors; + mxnet::Tuple step_count; + + DMLC_DECLARE_PARAMETER(MultiLANSParam) { + DMLC_DECLARE_FIELD(learning_rates) + .describe("List of learning rates"); + DMLC_DECLARE_FIELD(beta1) + .set_default(0.9f) + .describe("Exponential decay rate for the first moment estimates."); + DMLC_DECLARE_FIELD(beta2) + .set_default(0.999f) + .describe("Exponential decay rate for the second moment estimates."); + DMLC_DECLARE_FIELD(epsilon) + .set_default(1e-6f) + .describe("Small value to avoid division by 0."); + DMLC_DECLARE_FIELD(wds) + .describe("List of Weight decays." + "Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Gradient rescaling factor"); + DMLC_DECLARE_FIELD(lower_bound) + .set_default(-1.0f) + .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set"); + DMLC_DECLARE_FIELD(upper_bound) + .set_default(-1.0f) + .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set"); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. 
" + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(step_count) + .describe("Step count for each tensor"); + DMLC_DECLARE_FIELD(num_tensors) + .set_default(1) + .describe("Number of tensors"); + } +}; + +template +inline bool MultiLANSInferShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors); + CHECK_EQ(out_attrs->size(), param.num_tensors); + + bool all_inferred = true; + auto& input_shapes = *in_attrs; + auto& output_shapes = *out_attrs; + + CHECK_LE(param.num_tensors, 45) + << "Invalid number of tensors, the maximum value is 45, and got " + << param.num_tensors; + CHECK_EQ(param.learning_rates.ndim(), param.num_tensors) + << "Number of learning rates is inconsistent with num_tensors " + << "parameter passed. Expected number of learning rates: " + << param.num_tensors << ", and got " << param.learning_rates.ndim(); + CHECK_EQ(param.wds.ndim(), param.num_tensors) + << "Number of weight decays is inconsistent with num_tensors " + << "parameter passed. Expected number of weight decays: " + << param.num_tensors << ", and got " << param.wds.ndim(); + CHECK_EQ(param.step_count.ndim(), param.num_tensors) + << "Number of step counts is inconsistent with num_tensors." + << "Expected number of step counts: " + << param.num_tensors << ", and got " << param.step_count.ndim(); + + // Weights, gradients, mean and variance + for (int i = 0; i < param.num_tensors; ++i) { + mxnet::ShapeVector input_vec; + mxnet::ShapeVector output_vec({output_shapes[i]}); + for (int j = 0; j < input_stride; ++j) { + input_vec.push_back(input_shapes[i * input_stride + j]); + } + all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); + } + return all_inferred; +} + +template +inline bool MPMultiLANSInferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_tensors); + CHECK_EQ(out_attrs->size(), param.num_tensors); + + bool all_inferred = true; + auto& input_types = *in_attrs; + auto& output_types = *out_attrs; + + // weights, gradients + for (int i = 0; i < param.num_tensors; ++i) { + std::vector input_vec; + std::vector output_vec({output_types[i]}); + for (int j = 0; j < 2; ++j) { + input_vec.push_back(input_types[i * input_stride + j]); + } + all_inferred = all_inferred && + ElemwiseType<2, 1>(attrs, &input_vec, &output_vec); + } + + // mean, var, weights32 (master copies of weights) + for (int i = 0; i < param.num_tensors; ++i) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1, mshadow::kFloat32); + } + return all_inferred; +} + +template +class LANSTypeIdentity { + public: + using type = T; +}; + +template +class LANSSinglePrecision { + public: + using type = float; +}; + +template +struct MultiLANSKernelParam { + static const int N = 45; + size_t ntensors; + size_t max_size; + size_t total_size; + size_t sizes[N]; + size_t tensor2temp_g[N]; + DType* weights[N]; + DType* grads[N]; + MPDType* mean[N]; + MPDType* var[N]; + MPDType* weights32[N]; + DType* out_data[N]; + int step_count[N]; + MPDType learning_rates[N]; + MPDType wds[N]; + + // gpu + int chunk_size = 65536; + int nchunks; +}; + 
+template +void FillMultiLANSKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + MultiLANSKernelParam *multi_param) { + const ParamType& p = nnvm::get(attrs.parsed); + mxnet_op::Stream* s = ctx.get_stream(); + + multi_param->ntensors = p.num_tensors; + multi_param->total_size = 0; + multi_param->max_size = 0; + multi_param->nchunks = 0; + + constexpr bool is_same = std::is_same::value; + for (size_t i = 0; i < multi_param->ntensors; ++i) { + const auto idx = i * input_stride; + multi_param->sizes[i] = inputs[idx].shape_.Size(); + multi_param->tensor2temp_g[i] = multi_param->total_size; + multi_param->total_size += multi_param->sizes[i]; + if (multi_param->max_size < multi_param->sizes[i]) + multi_param->max_size = multi_param->sizes[i]; + + multi_param->weights[i] = inputs[idx].FlatTo2D(s).dptr_; + multi_param->grads[i] = inputs[idx + 1].FlatTo2D(s).dptr_; + multi_param->mean[i] = inputs[idx + 2].FlatTo2D(s).dptr_; + multi_param->var[i] = inputs[idx + 3].FlatTo2D(s).dptr_; + + // if mixed precision, then the last input in a set + // is 32-bit master copy of the weights + if (!is_same) + multi_param->weights32[i] = inputs[idx + input_stride - 1].FlatTo2D(s).dptr_; + multi_param->out_data[i] = outputs[i].FlatTo2D(s).dptr_; + multi_param->nchunks += (multi_param->sizes[i] + multi_param->chunk_size - 1) + / multi_param->chunk_size; + multi_param->learning_rates[i] = static_cast(p.learning_rates[i]); + multi_param->wds[i] = static_cast(p.wds[i]); + } + memcpy(multi_param->step_count, p.step_count.begin(), multi_param->ntensors * sizeof(int)); +} + +using namespace mxnet_op; +template +void CallKernel1(Stream* s); +template +void CallKernel1(Stream* s); + +template +void CallKernel2(Stream* s); +template +void CallKernel2(Stream* s); + +template class MPTypeChooser, int input_stride> +inline void MultiLANS(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + auto param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiLANSKernelParam kernel_params; + FillMultiLANSKernelParam + (attrs, ctx, inputs, outputs, &kernel_params); + + // create vector of TBlob with all the weights contiguous + std::vector weights; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + weights.emplace_back(inputs[index*input_stride]); + } + + // create vector of TBlob with all the weights contiguous + std::vector grads; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + grads.emplace_back(inputs[index*input_stride+1]); + } + + // Calculate amount of temporary storage (temp_m, temp_g, r1, r2_m, r2_g, g_sq_norm, + // block_to_tensor, block_to_chunk) + size_t workspace_size = 2 * kernel_params.total_size * sizeof(float) + + 4 * kernel_params.ntensors * sizeof(float) + + 2 * kernel_params.nchunks * sizeof(int); + + // take into account the required storage required within MultiSumSqRun + size_t required_storage_multi_sum_sq = 0; + required_storage_multi_sum_sq = GetRequiredStorageMultiSumSq(inputs); + workspace_size += required_storage_multi_sum_sq; + + // Request temporary storage + Tensor workspace = + ctx.requested[multilans::kTempSpace].get_space_typed( + Shape1(workspace_size), s); + + // Create tensors + size_t pos_wspace = required_storage_multi_sum_sq; + Tensor 
temp_m(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.total_size), s); + // create vector of TBlob with all the temp_m and temp_g contiguous + std::vector temp_m_tblobs; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + Tensor aux(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.sizes[index]), s); + TBlob newtblob(aux); + temp_m_tblobs.emplace_back(newtblob); + pos_wspace += kernel_params.sizes[index] * sizeof(float); + } + Tensor temp_g(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.total_size), s); + std::vector temp_g_tblobs; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + Tensor aux(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.sizes[index]), s); + TBlob newtblob(aux); + temp_g_tblobs.emplace_back(newtblob); + pos_wspace += kernel_params.sizes[index] * sizeof(float); + } + Tensor r1(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.ntensors), s); + pos_wspace += kernel_params.ntensors * sizeof(float); + Tensor r2_m(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.ntensors), s); + pos_wspace += kernel_params.ntensors * sizeof(float); + Tensor r2_g(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.ntensors), s); + pos_wspace += kernel_params.ntensors * sizeof(float); + Tensor g_sq_norm(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.ntensors), s); + pos_wspace += kernel_params.ntensors * sizeof(float); + Tensor block_to_tensor(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.nchunks), s); + pos_wspace += kernel_params.nchunks * sizeof(int); + Tensor block_to_chunk(reinterpret_cast(&workspace[pos_wspace]), + Shape1(kernel_params.nchunks), s); + + MultiSumSqRun(weights, kernel_params.ntensors, r1.dptr_, ctx); + MultiSumSqRun(grads, kernel_params.ntensors, g_sq_norm.dptr_, ctx, param.rescale_grad); + CallKernel1(s, kernel_params, param, + g_sq_norm.dptr_, + temp_m.dptr_, + temp_g.dptr_, + block_to_tensor.dptr_, + block_to_chunk.dptr_); + MultiSumSqRun(temp_m_tblobs, kernel_params.ntensors, r2_m.dptr_, ctx); + MultiSumSqRun(temp_g_tblobs, kernel_params.ntensors, r2_g.dptr_, ctx); + CallKernel2(s, kernel_params, param, r1.dptr_, + r2_m.dptr_, r2_g.dptr_, + temp_m.dptr_, temp_g.dptr_, + block_to_tensor.dptr_, block_to_chunk.dptr_, + req[0]); + }); +} + +template +inline void MultiLANSUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (!MP) { + MultiLANS + (attrs, ctx, inputs, req, outputs); + } else { + MultiLANS + (attrs, ctx, inputs, req, outputs); + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CONTRIB_MULTI_LANS_INL_H_ diff --git a/src/operator/contrib/multi_lans.cc b/src/operator/contrib/multi_lans.cc new file mode 100644 index 000000000000..37d6001312c4 --- /dev/null +++ b/src/operator/contrib/multi_lans.cc @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
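For orientation, the host-side flow in ``MultiLANS`` above, written as a NumPy-style reference over a list of tensors (helper name is illustrative; lower/upper bound clipping omitted; the real code fuses these loops into two kernels plus the MultiSumSq reductions):

    import numpy as np

    def multi_lans_reference(weights, grads, means, variances, steps, lrs, wds,
                             beta1=0.9, beta2=0.999, epsilon=1e-6, rescale_grad=1.0):
        # MultiSumSqRun: per-tensor L2 norms of weights and rescaled gradients
        # (the operator returns sums of squares; the kernels take the square root)
        r1 = [np.linalg.norm(w) for w in weights]
        g_norm = [np.linalg.norm(g * rescale_grad) for g in grads]

        # KernelStep1: build the two update directions (temp_m, temp_g) per tensor;
        # note both use the pre-update weight, unlike the unfused step() sketch earlier
        temp_m, temp_g = [], []
        for i, (w, g, m, v) in enumerate(zip(weights, grads, means, variances)):
            g = g * rescale_grad / g_norm[i]
            m[:] = beta1 * m + (1. - beta1) * g
            v[:] = beta2 * v + (1. - beta2) * g * g
            denom = np.sqrt(v / (1. - beta2 ** steps[i])) + epsilon
            temp_m.append(m / (1. - beta1 ** steps[i]) / denom + wds[i] * w)
            temp_g.append(g / denom + wds[i] * w)

        # second MultiSumSqRun pair: norms of the two directions
        r2_m = [np.linalg.norm(x) for x in temp_m]
        r2_g = [np.linalg.norm(x) for x in temp_g]

        # KernelStep2: layer-wise trust ratios, then the combined weight update
        for i, w in enumerate(weights):
            ratio_m = r1[i] / r2_m[i] if r1[i] > 0 and r2_m[i] > 0 else 1.0
            ratio_g = r1[i] / r2_g[i] if r1[i] > 0 and r2_g[i] > 0 else 1.0
            w[:] -= lrs[i] * (beta1 * ratio_m * temp_m[i] +
                              (1. - beta1) * ratio_g * temp_g[i])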
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file multi_lans.cc + * \brief multi-tensor LANS optimizer + * \author Shuai Zheng + */ + +#include "./multi_lans-inl.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +template +struct MultiLANSKernelStep1 { + template + MSHADOW_XINLINE static void Map(int i, + const MultiLANSKernelParam& kernel_params, + const float beta1, const float beta2, + const float epsilon, + const float clip_gradient, + const float rescale_grad, + float* g_sq_norm, + float* temp_m, float* temp_g) { + using namespace mshadow_op; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + if ((size_t)i < kernel_params.sizes[index]) { + MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]: + MPDType(kernel_params.weights[index][i]); + float g_norm = sqrt(g_sq_norm[index]); + MPDType scaled_grad = static_cast(kernel_params.grads[index][i]) * rescale_grad; + scaled_grad /= g_norm; + if (clip_gradient >= 0.0f) + scaled_grad = mshadow_op::clip::Map(scaled_grad, static_cast(clip_gradient)); + MPDType mean = static_cast(beta1) * kernel_params.mean[index][i] + + (static_cast(1.0f) - static_cast(beta1)) * scaled_grad; + MPDType var = static_cast(beta2) * kernel_params.var[index][i] + + (static_cast(1.0f) - static_cast(beta2)) * scaled_grad * scaled_grad; + kernel_params.mean[index][i] = mean; + kernel_params.var[index][i] = var; + + MPDType m, g; + MPDType mean_hat = mean / (static_cast(1.0f) - + power::Map(static_cast(beta1), + static_cast(kernel_params.step_count[index]))); + MPDType var_hat = var / (static_cast(1.0f) - + power::Map(static_cast(beta2), + static_cast(kernel_params.step_count[index]))); + var_hat = sqrt(var_hat) + static_cast(epsilon); + MPDType scaled_w = kernel_params.wds[index] * w; + m = mean_hat / var_hat + scaled_w; + g = scaled_grad / var_hat + scaled_w; + temp_m[kernel_params.tensor2temp_g[index]+i] = m; + temp_g[kernel_params.tensor2temp_g[index]+i] = g; + } + } + } +}; + +template +struct MultiLANSKernelStep2 { + template + MSHADOW_XINLINE static void Map(int i, + const MultiLANSKernelParam& kernel_params, + const float beta1, + const float* sum_sq_weigths, + const float* sum_sq_temp_m, + const float* sum_sq_temp_g, + const float* temp_m, + const float* temp_g, + const float lower_bound, + const float upper_bound, + const OpReqType req) { + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + if ((size_t)i < kernel_params.sizes[index]) { + MPDType w = has_mixed_precision ? kernel_params.weights32[index][i]: + MPDType(kernel_params.weights[index][i]); + float r1 = sqrt(sum_sq_weigths[index]); + float r2_m = sqrt(sum_sq_temp_m[index]); + float r2_g = sqrt(sum_sq_temp_g[index]); + if (lower_bound >= 0) + r1 = std::max(r1, lower_bound); + if (upper_bound >= 0) + r1 = std::min(r1, upper_bound); + + // calculate nesterov lamb_trust_ratio + MPDType r_m, r_g; + if (r1 == 0.0f || r2_m == 0.0f) + r_m = 1.0f; + else + r_m = r1/r2_m; + if (r1 == 0.0f || r2_g == 0.0f) + r_g = 1.0f; + else + r_g = r1/r2_g; + r_m *= static_cast(beta1); + r_g *= (1. 
- static_cast(beta1)); + + MPDType lr_adjusted_m = kernel_params.learning_rates[index] * r_m; + MPDType lr_adjusted_g = kernel_params.learning_rates[index] * r_g; + w -= lr_adjusted_m * temp_m[kernel_params.tensor2temp_g[index]+i] + + lr_adjusted_g * temp_g[kernel_params.tensor2temp_g[index]+i]; + + // update weights + if (has_mixed_precision) + kernel_params.weights32[index][i] = w; + KERNEL_ASSIGN(kernel_params.out_data[index][i], req, w); + } + } + } +}; + +template +void CallKernel1(Stream* s, + const MultiLANSKernelParam& kernel_params, + const MultiLANSParam ¶m, + float* g_sq_norm, + float* temp_m, + float* temp_g, + int* block_to_tensor, + int* block_to_chunk) { + Kernel::value>, cpu>:: + Launch(s, kernel_params.max_size, + kernel_params, + param.beta1, param.beta2, + param.epsilon, + param.clip_gradient, + param.rescale_grad, + g_sq_norm, + temp_m, + temp_g); +} + +template +void CallKernel2(Stream* s, + const MultiLANSKernelParam& kernel_params, + const MultiLANSParam ¶m, + float* r1, float* r2_m, float* r2_g, + float* temp_m, float* temp_g, + int* block_to_tensor, + int* block_to_chunk, + const OpReqType req) { + Kernel::value>, cpu>:: + Launch(s, kernel_params.max_size, + kernel_params, + param.beta1, + r1, r2_m, r2_g, + temp_m, temp_g, + param.lower_bound, param.upper_bound, + req); +} + +DMLC_REGISTER_PARAMETER(MultiLANSParam); + +std::vector LANSParamToVector(uint32_t num_tensors, + const char *p_names[], + size_t n_params) { + std::vector ret; + for (uint32_t i = 0; i < num_tensors; ++i) { + const auto idx = std::to_string(i); + for (size_t j = 0; j < n_params; ++j) + ret.push_back(std::string(p_names[i]) + idx); + } + return ret; +} + +inline uint32_t NumTensors(const nnvm::NodeAttrs& attrs) { + return static_cast(dmlc::get(attrs.parsed).num_tensors); +} + +NNVM_REGISTER_OP(_multi_lans_update) +.describe(R"code(Compute the LANS coefficients of multiple weights and grads" +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + return NumTensors(attrs) * 4; + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + return NumTensors(attrs); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiLANSInferShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const char *param_names[] = {"weight_", "grad_", "mean_", "var_"}; + return LANSParamToVector(NumTensors(attrs), param_names, + sizeof(param_names)/sizeof(param_names[0])); + }) +// mutable: mean, var +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto i_max = NumTensors(attrs); + for (size_t i = 0; i < i_max; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", MultiLANSUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "data") +.add_arguments(MultiLANSParam::__FIELDS__()); + + +NNVM_REGISTER_OP(_multi_mp_lans_update) +.describe(R"code(Compute the Nesterov LAMB coefficients of multiple weights and grads with Mix Precision" +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + return NumTensors(attrs) * 5; + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + return NumTensors(attrs); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiLANSInferShape) +.set_attr("FInferType", MPMultiLANSInferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const char 
*param_names[] = {"weight_", "grad_", "mean_", "var_", "weight32_"}; + return LANSParamToVector(NumTensors(attrs), param_names, + sizeof(param_names)/sizeof(param_names[0])); + }) +// mutable: mean, var, weights32 +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto i_max = NumTensors(attrs); + for (size_t i = 0; i < i_max; ++i) { + ret.push_back(i * 5 + 2); + ret.push_back(i * 5 + 3); + ret.push_back(i * 5 + 4); + } + return ret; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", MultiLANSUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "data") +.add_arguments(MultiLANSParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/multi_lans.cu b/src/operator/contrib/multi_lans.cu new file mode 100644 index 000000000000..64de72116514 --- /dev/null +++ b/src/operator/contrib/multi_lans.cu @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file multi_lans.cu + * \brief multi-tensor LANS optimizer + * \author Shuai Zheng + */ + +#include "./multi_lans-inl.h" + +namespace mxnet { +namespace op { + +#define BLOCK_SIZE_LAMB 512 +#define ILP_LAMB 4 + +template +__global__ void KernelStep1(const MultiLANSKernelParam kernel_params, + const float beta1, const float beta2, + const MPDType beta3, const MPDType beta4, + const float epsilon, + const float clip_gradient, + const float rescale_grad, + float* g_sq_norm, + float* temp_m, + float* temp_g, + int* block_to_tensor, + int* block_to_chunk) { + const int tensor_id = block_to_tensor[blockIdx.x]; + const int chunck_id = block_to_chunk[blockIdx.x]; + const int start_pos = chunck_id * kernel_params.chunk_size + threadIdx.x; + const int stop_pos = chunck_id * kernel_params.chunk_size + kernel_params.chunk_size; + + MPDType g_norm = sqrtf(g_sq_norm[tensor_id]); + + MPDType biascorrection1, biascorrection2; + + biascorrection1 = 1.0 - + static_cast(std::pow(beta1, kernel_params.step_count[tensor_id])); + biascorrection2 = 1.0 - + static_cast(std::pow(beta2, kernel_params.step_count[tensor_id])); + + MPDType r_weight[ILP_LAMB]; + MPDType r_grad[ILP_LAMB]; + MPDType r_mean[ILP_LAMB]; + MPDType r_var[ILP_LAMB]; + MPDType r_m[ILP_LAMB]; + MPDType r_g[ILP_LAMB]; + + for (size_t i = start_pos; i < stop_pos && i < kernel_params.sizes[tensor_id]; + i+= blockDim.x*ILP_LAMB) { +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + int load_pos = i + ii*blockDim.x; + if (load_pos < stop_pos && load_pos < kernel_params.sizes[tensor_id]) { + r_weight[ii] = has_mixed_precision ? 
kernel_params.weights32[tensor_id][load_pos]: + static_cast(kernel_params.weights[tensor_id][load_pos]); + r_grad[ii] = static_cast(kernel_params.grads[tensor_id][load_pos]); + r_mean[ii] = kernel_params.mean[tensor_id][load_pos]; + r_var[ii] = kernel_params.var[tensor_id][load_pos]; + } else { + r_weight[ii] = static_cast(0); + r_grad[ii] = static_cast(0); + r_mean[ii] = static_cast(0); + r_var[ii] = static_cast(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + r_grad[ii] = (r_grad[ii] * rescale_grad) / g_norm; + if (clip_gradient >= 0.0f) + r_grad[ii] = max(min(r_grad[ii], clip_gradient), -clip_gradient); + r_mean[ii] = static_cast(beta1) * r_mean[ii] + beta3 * r_grad[ii]; + r_var[ii] = static_cast(beta2) * r_var[ii] + beta4 * r_grad[ii] * r_grad[ii]; + MPDType r_var_hat = sqrt(r_var[ii] / biascorrection2) + static_cast(epsilon); + r_m[ii] = (r_mean[ii] / biascorrection1) / r_var_hat; + r_g[ii] = r_grad[ii] / r_var_hat; + r_m[ii] = __fmaf_rn(kernel_params.wds[tensor_id], r_weight[ii], r_m[ii]); + r_g[ii] = __fmaf_rn(kernel_params.wds[tensor_id], r_weight[ii], r_g[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + int store_pos = i + ii*blockDim.x; + if (store_pos < stop_pos && store_pos < kernel_params.sizes[tensor_id]) { + kernel_params.mean[tensor_id][store_pos] = r_mean[ii]; + kernel_params.var[tensor_id][store_pos] = r_var[ii]; + temp_m[kernel_params.tensor2temp_g[tensor_id]+store_pos] = r_m[ii]; + temp_g[kernel_params.tensor2temp_g[tensor_id]+store_pos] = r_g[ii]; + } + } + } +} + +template +__global__ void KernelStep2(const MultiLANSKernelParam kernel_params, + const float beta1, + const MPDType beta3, + const float* sum_sq_weigths, + const float* sum_sq_temp_m, + const float* sum_sq_temp_g, + const float* temp_m, + const float* temp_g, + const float lower_bound, + const float upper_bound, + int* block_to_tensor, + int* block_to_chunk, + const OpReqType req) { + const int tensor_id = block_to_tensor[blockIdx.x]; + const int chunck_id = block_to_chunk[blockIdx.x]; + const int start_pos = chunck_id * kernel_params.chunk_size + threadIdx.x; + const int stop_pos = chunck_id * kernel_params.chunk_size + kernel_params.chunk_size; + + MPDType r1 = sqrtf(sum_sq_weigths[tensor_id]); + MPDType r2_m = sqrtf(sum_sq_temp_m[tensor_id]); + MPDType r2_g = sqrtf(sum_sq_temp_g[tensor_id]); + if (lower_bound >= 0) + r1 = max(r1, lower_bound); + if (upper_bound >= 0) + r1 = min(r1, upper_bound); + + MPDType lr_adjusted_m, lr_adjusted_g; + if (r1 == 0.0f || r2_m == 0.0f) + lr_adjusted_m = kernel_params.learning_rates[tensor_id]; + else + lr_adjusted_m = kernel_params.learning_rates[tensor_id] * r1/r2_m; + if (r1 == 0.0f || r2_g == 0.0f) + lr_adjusted_g = kernel_params.learning_rates[tensor_id]; + else + lr_adjusted_g = kernel_params.learning_rates[tensor_id] * r1/r2_g; + lr_adjusted_m *= static_cast(beta1); + lr_adjusted_g *= beta3; + + MPDType r_weight[ILP_LAMB]; + MPDType r_m[ILP_LAMB]; + MPDType r_g[ILP_LAMB]; + + for (size_t i=start_pos; i < stop_pos && i < kernel_params.sizes[tensor_id]; + i+= blockDim.x*ILP_LAMB) { +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + int load_pos = i + ii*blockDim.x; + if (load_pos < stop_pos && load_pos < kernel_params.sizes[tensor_id]) { + r_weight[ii] = has_mixed_precision ? 
kernel_params.weights32[tensor_id][load_pos]: + static_cast(kernel_params.weights[tensor_id][load_pos]); + r_m[ii] = temp_m[kernel_params.tensor2temp_g[tensor_id]+load_pos]; + r_g[ii] = temp_g[kernel_params.tensor2temp_g[tensor_id]+load_pos]; + } + } +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + r_weight[ii] -= lr_adjusted_m * r_m[ii] + lr_adjusted_g * r_g[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP_LAMB; ii++) { + int store_pos = i + ii*blockDim.x; + if (store_pos < stop_pos && store_pos < kernel_params.sizes[tensor_id]) { + if (has_mixed_precision) + kernel_params.weights32[tensor_id][store_pos] = r_weight[ii]; + KERNEL_ASSIGN(kernel_params.out_data[tensor_id][store_pos], req, r_weight[ii]); + } + } + } +} + +template +void CallKernel1(Stream* s, + const MultiLANSKernelParam& kernel_params, + const MultiLANSParam ¶m, + float* g_sq_norm, + float* temp_m, + float* temp_g, + int* block_to_tensor, + int* block_to_chunk) { + int nblocks = kernel_params.nchunks; + int* host_block2tensor = reinterpret_cast(malloc(kernel_params.nchunks*sizeof(int))); + int* host_block2chunk = reinterpret_cast(malloc(kernel_params.nchunks*sizeof(int))); + int chunk_id = 0; + for (size_t index = 0; index < kernel_params.ntensors; ++index) { + int current_chunk = 0; + for (size_t j = 0; j < kernel_params.sizes[index]; j+=kernel_params.chunk_size) { + host_block2tensor[chunk_id] = index; + host_block2chunk[chunk_id] = current_chunk; + current_chunk++; + chunk_id++; + } + } + cudaMemcpyAsync(block_to_tensor, host_block2tensor, kernel_params.nchunks*sizeof(int), + cudaMemcpyHostToDevice, Stream::GetStream(s)); + cudaMemcpyAsync(block_to_chunk, host_block2chunk, kernel_params.nchunks*sizeof(int), + cudaMemcpyHostToDevice, Stream::GetStream(s)); + + bool has_mixed_precision = !std::is_same::value; + MPDType beta3 = 1.0 - param.beta1; + MPDType beta4 = 1.0 - param.beta2; + + if (has_mixed_precision) + KernelStep1<<::GetStream(s)>>>( + kernel_params, + param.beta1, param.beta2, + beta3, beta4, + param.epsilon, + param.clip_gradient, + param.rescale_grad, + g_sq_norm, + temp_m, + temp_g, + block_to_tensor, + block_to_chunk); + else + KernelStep1<<::GetStream(s)>>>( + kernel_params, + param.beta1, param.beta2, + beta3, beta4, + param.epsilon, + param.clip_gradient, + param.rescale_grad, + g_sq_norm, + temp_m, + temp_g, + block_to_tensor, + block_to_chunk); + } + +template +void CallKernel2(Stream* s, + const MultiLANSKernelParam& kernel_params, + const MultiLANSParam ¶m, + float* r1, float* r2_m, float* r2_g, + float* temp_m, float* temp_g, + int* block_to_tensor, + int* block_to_chunk, + const OpReqType req) { + size_t nblocks = kernel_params.nchunks; + bool has_mixed_precision = !std::is_same::value; + MPDType beta3 = 1.0 - param.beta1; + + if (has_mixed_precision) + KernelStep2<<::GetStream(s)>>>( + kernel_params, + param.beta1, + beta3, + r1, r2_m, r2_g, + temp_m, temp_g, + param.lower_bound, param.upper_bound, + block_to_tensor, + block_to_chunk, + req); + else + KernelStep2<<::GetStream(s)>>>( + kernel_params, + param.beta1, + beta3, + r1, r2_m, r2_g, + temp_m, temp_g, + param.lower_bound, param.upper_bound, + block_to_tensor, + block_to_chunk, + req); +} + + +NNVM_REGISTER_OP(_multi_lans_update) +.set_attr("FCompute", MultiLANSUpdate); + +NNVM_REGISTER_OP(_multi_mp_lans_update) +.set_attr("FCompute", MultiLANSUpdate); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/multi_sum_sq-inl.h b/src/operator/contrib/multi_sum_sq-inl.h index f4aabc97ba3e..caf3ea239ad1 100644 
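The host loop in ``CallKernel1`` above assigns one CUDA block per chunk and records, for each block, which tensor it belongs to and which chunk of that tensor it covers. A small sketch of that mapping (illustrative function name, not part of the patch):

    def build_block_maps(tensor_sizes, chunk_size=65536):
        block_to_tensor, block_to_chunk = [], []
        for tensor_id, size in enumerate(tensor_sizes):
            # one entry per chunk_size-sized slice of this tensor
            for chunk, _ in enumerate(range(0, size, chunk_size)):
                block_to_tensor.append(tensor_id)
                block_to_chunk.append(chunk)
        return block_to_tensor, block_to_chunk

    print(build_block_maps([100, 200000]))
    # ([0, 1, 1, 1, 1], [0, 0, 1, 2, 3])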
--- a/src/operator/contrib/multi_sum_sq-inl.h +++ b/src/operator/contrib/multi_sum_sq-inl.h @@ -18,10 +18,10 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2020 by Contributors * \file multi_l2_norm-inl.h * \brief vectorized L2 norm over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez, Shuai Zheng */ @@ -41,9 +41,14 @@ namespace op { struct MultiSumSqParam : public dmlc::Parameter { int num_arrays; + float scale; + DMLC_DECLARE_PARAMETER(MultiSumSqParam) { DMLC_DECLARE_FIELD(num_arrays) .describe("number of input arrays."); + DMLC_DECLARE_FIELD(scale) + .set_default(1.0f) + .describe("Scaling factor for l2 norm"); } }; @@ -88,7 +93,7 @@ size_t GetRequiredStorageMultiSumSq(const std::vector &inputs, template void MultiSumSqRun(const std::vector &inputs, int nInputs, - float *out_ptr, const OpContext &ctx); + float *out_ptr, const OpContext &ctx, float scale=1.0f); template void MultiSumSq(const nnvm::NodeAttrs& attrs, @@ -99,7 +104,7 @@ void MultiSumSq(const nnvm::NodeAttrs& attrs, auto s = ctx.get_stream(); const auto& p = dmlc::get(attrs.parsed); float* out_ptr = outputs[0].FlatTo2D(s).dptr_; - MultiSumSqRun(inputs, p.num_arrays, out_ptr, ctx); + MultiSumSqRun(inputs, p.num_arrays, out_ptr, ctx, p.scale); } } // namespace op diff --git a/src/operator/contrib/multi_sum_sq.cc b/src/operator/contrib/multi_sum_sq.cc index 9d15bf65a9e1..71c228572e82 100644 --- a/src/operator/contrib/multi_sum_sq.cc +++ b/src/operator/contrib/multi_sum_sq.cc @@ -18,10 +18,10 @@ */ /*! - * Copyright (c) 2019 by Contributors + * Copyright (c) 2020 by Contributors * \file multi_sum_sq.cc * \brief vectorized sum or squared over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez, Shuai Zheng */ #include "./multi_sum_sq-inl.h" @@ -67,7 +67,7 @@ size_t GetRequiredStorageMultiSumSq(const std::vector &inputs, template inline void CalcSumSq(const std::vector &inputs, int n_inputs, - float *out_ptr, mshadow::Stream *s) { + float *out_ptr, mshadow::Stream *s, float scale) { int i; size_t j; #pragma omp parallel for private(i, j) @@ -75,18 +75,22 @@ inline void CalcSumSq(const std::vector &inputs, int n_inputs, float sum = 0; const auto address = inputs[i].FlatTo2D(s).dptr_; const auto j_max = inputs[i].shape_.Size(); - for (j = 0; j < j_max; ++j) - sum += address[j] * address[j]; - + for (j = 0; j < j_max; ++j) { + auto val = static_cast(address[j]); + if (scale != 1.0f) { + val *= scale; + } + sum += val * val; + } out_ptr[i] = sum; } } template<> void MultiSumSqRun(const std::vector &inputs, int n_inputs, - float *out_ptr, const OpContext &ctx) { + float *out_ptr, const OpContext &ctx, float scale) { MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, - CalcSumSq(inputs, n_inputs, out_ptr, ctx.get_stream()); + CalcSumSq(inputs, n_inputs, out_ptr, ctx.get_stream(), scale); ) } diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu index 8d9a26676ea2..6d14e1926696 100644 --- a/src/operator/contrib/multi_sum_sq.cu +++ b/src/operator/contrib/multi_sum_sq.cu @@ -18,10 +18,10 @@ */ /*! 
- * Copyright (c) 2019 by Contributors + * Copyright (c) 2020 by Contributors * \file multi_sum_sq.cu * \brief vectorized sums of squares norm over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez, Shuai Zheng */ #include "./multi_sum_sq-inl.h" #include @@ -85,7 +85,8 @@ template __global__ void MultiSumSqKernel(int chunk_size, MultiSumSqKernelParam param, float* block_reductions, - int start_tensor_id) { + int start_tensor_id, + float scale) { const int tensor_loc = param.block_to_tensor[blockIdx.x]; const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size; const int n = param.sizes[tensor_loc] - chunk_len; @@ -101,7 +102,10 @@ __global__ void MultiSumSqKernel(int chunk_size, int i = i_start + threadIdx.x; #pragma unroll for (int ii = 0; ii < ILP && i < i_max; ++ii, i += blockDim.x) { - const auto incoming_val = static_cast(x[i]); + auto incoming_val = static_cast(x[i]); + if (scale != 1.0f) { + incoming_val *= scale; + } val += incoming_val * incoming_val; } } @@ -146,7 +150,7 @@ size_t GetRequiredStorageMultiSumSq(const std::vector &inputs, template<> void MultiSumSqRun(const std::vector &inputs, int n_inputs, - float *out_ptr, const OpContext &ctx) { + float *out_ptr, const OpContext &ctx, float scale) { const int block_size = 512; using namespace mxnet_op; auto s = ctx.get_stream(); @@ -184,7 +188,7 @@ void MultiSumSqRun(const std::vector &inputs, int n_inputs, if (!(tensors_full || blocks_full || last_chunk)) continue; MultiSumSqKernel<<>> - (chunk_size, param, block_reductions.dptr_, start_tensor_id); + (chunk_size, param, block_reductions.dptr_, start_tensor_id, scale); MSHADOW_CUDA_POST_KERNEL_CHECK(MultiSumSqKernel); loc_block_info = 0; diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index e560f13647b7..294b80a65adb 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -292,6 +292,36 @@ def test_lamb(): shapes, dtype, rtol=1e-3, atol=1e-3) +@xfail_when_nonstandard_decimal_separator +@with_seed() +def test_lans(): + opt1 = mx.optimizer.LANS + opt2 = mx.optimizer.LANS + + shapes = [(3, 4, 5), (10, 4), (7,)] + beta1_options = [{}, {'beta1': 0.5}] + beta2_options = [{}, {'beta2': 0.8}] + cg_options = [{}, {'clip_gradient': 0.4}] + rg_options = [{}, {'rescale_grad': 0.14}] + wd_options = [{}, {'wd': 0.03}] + lb_options = [{'lower_bound': None}, {'lower_bound': 1e-3}] + ub_options = [{'upper_bound': None}, {'upper_bound': 10}] + mp_options = [{'multi_precision': False}, {'multi_precision': True}] + agg_options = [{'aggregate_num': 0}, {'aggregate_num': 1}, + {'aggregate_num': 4}] + for dtype in [np.float16, np.float32]: + for params in itertools.product(beta1_options, beta2_options, cg_options, rg_options, + wd_options, lb_options, ub_options, + mp_options, agg_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and ('multi_precision' not in kwarg or + not kwarg['multi_precision'])): + continue + compare_optimizer(opt1(use_fused_step=False, **kwarg), + opt2(use_fused_step=True, **kwarg), + shapes, dtype, rtol=1e-3, atol=1e-3) + + @with_seed() def test_sgld(): opt1 = mx.optimizer.SGLD From 98b1d6d56947ddaec719dcb77af88eac4e5a25d9 Mon Sep 17 00:00:00 2001 From: Zheng Date: Thu, 25 Jun 2020 12:30:17 -0700 Subject: [PATCH 2/3] fix --- src/operator/contrib/multi_sum_sq-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/operator/contrib/multi_sum_sq-inl.h b/src/operator/contrib/multi_sum_sq-inl.h index caf3ea239ad1..46031d282d29 100644 --- a/src/operator/contrib/multi_sum_sq-inl.h +++ b/src/operator/contrib/multi_sum_sq-inl.h @@ -93,7 +93,7 @@ size_t GetRequiredStorageMultiSumSq(const std::vector &inputs, template void MultiSumSqRun(const std::vector &inputs, int nInputs, - float *out_ptr, const OpContext &ctx, float scale=1.0f); + float *out_ptr, const OpContext &ctx, float scale = 1.0f); template void MultiSumSq(const nnvm::NodeAttrs& attrs, From e9e5ae4eb7737f2834280aef7069d0fda9889d70 Mon Sep 17 00:00:00 2001 From: Zheng Date: Thu, 25 Jun 2020 13:35:17 -0700 Subject: [PATCH 3/3] fix --- python/mxnet/ndarray/contrib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b1700bbe8c52..0975013cce63 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -683,7 +683,7 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count, def multi_lans_update(weights, grads, mean, var, step_count, - lrs, wds, out=None, num_tensors=0, **kwargs): + lrs, wds, out=None, num_tensors=0, **kwargs): """Given a list of gradients, update weights, mean and variance of multiple tensors following LANS Optimizer implementation. @@ -721,7 +721,7 @@ def multi_lans_update(weights, grads, mean, var, step_count, def multi_mp_lans_update(weights, grads, mean, var, weights32, step_count, - lrs, wds, out=None, num_tensors=0, **kwargs): + lrs, wds, out=None, num_tensors=0, **kwargs): """Given a list of gradients, update weights, mean and variance of multiple tensors following LANS Optimizer implementation, and using Mixed-Precision.
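Finally, the ``scale`` argument added to multi_sum_sq in this patch lets gradient norms be computed on rescaled values without materializing them; LANS passes ``rescale_grad`` for the gradient pass. A NumPy sketch of the quantity the kernel reduces (illustrative helper, not the operator itself):

    import numpy as np

    def multi_sum_sq(arrays, scale=1.0):
        # per-array sum of squares of (scale * x); callers take sqrt to get the norm
        return [float(np.sum((scale * a.astype(np.float32)) ** 2)) for a in arrays]

    grads = [np.ones(4, dtype=np.float16), 2 * np.ones(3, dtype=np.float16)]
    print(multi_sum_sq(grads, scale=0.5))   # [1.0, 3.0]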