-
Notifications
You must be signed in to change notification settings - Fork 31
/
optimizer_op-inl.h
247 lines (230 loc) · 9.49 KB
/
optimizer_op-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file optimizer_op-inl.h
* \brief Optimizer operators
* \author Leonard Lausen
*/
#ifndef MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_
#define MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_
#include <dmlc/parameter.h>
#include <mshadow/base.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include "../elemwise_op_common.h"
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../tensor/init_op.h"
#include "../tensor/util/tensor_util-inl.h"
namespace mxnet {
namespace op {
/*!
 * \brief Hyper-parameters of the group Adagrad update operator.
 *
 * The struct is read by GroupAdagradUpdateEx via
 * nnvm::get<GroupAdagradParam>(attrs.parsed) and its fields are forwarded to
 * GroupAdagradDnsRspKernel.
 */
struct GroupAdagradParam : public dmlc::Parameter<GroupAdagradParam> {
// Step size multiplied with the (rescaled, clipped) gradient in the kernel.
float lr;
// Added to the accumulated state inside the square root in the kernel.
float epsilon;
// Factor the raw gradient is multiplied with before clipping.
float rescale_grad;
// Clipping bound handed to clip::Map in the kernel.
float clip_gradient;
DMLC_DECLARE_PARAMETER(GroupAdagradParam) {
// No default is declared for lr.
DMLC_DECLARE_FIELD(lr).describe("Learning rate");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
// NOTE(review): the description below says clipping is off for
// clip_gradient <= 0, but the kernel in this file applies clip::Map whenever
// clip_gradient >= 0.0f, so the two disagree at exactly 0 -- confirm which
// one is intended. The default of -1.0f skips clipping under both readings.
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe(
"Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
// The default is written as a double literal and stored in a float field.
DMLC_DECLARE_FIELD(epsilon).set_default(1.0e-5).describe(
"Epsilon for numerical stability");
}
};
/*!
 * \brief Storage type inference for the group Adagrad update.
 *
 * Expects exactly three inputs (weight, grad, state) and one output.
 *  - If all inputs use default storage, the output is assigned default
 *    storage with the FCompute dispatch mode.
 *  - Otherwise, if grad is row_sparse and weight and state share the same
 *    storage type (row_sparse or default), the output is assigned the weight's
 *    storage type with the FComputeEx dispatch mode.
 *
 * \param attrs node attributes (not read here)
 * \param dev_mask device mask (not read here)
 * \param dispatch_mode dispatch mode to assign
 * \param in_attrs storage types of weight, grad and state
 * \param out_attrs storage type of the single output
 * \return whether a storage type / dispatch mode assignment succeeded
 */
inline bool GroupAdagradStorageType(const nnvm::NodeAttrs &attrs,
                                    const int dev_mask,
                                    DispatchMode *dispatch_mode,
                                    std::vector<int> *in_attrs,
                                    std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 1U);

  const std::vector<int> &in_stypes = *in_attrs;
  const int w_stype = in_stypes.at(0);
  const int g_stype = in_stypes.at(1);
  const int s_stype = in_stypes.at(2);

  // dns, dns, dns -> dns
  if (common::ContainsOnlyStorage(in_stypes, kDefaultStorage) &&
      storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode,
                          DispatchMode::kFCompute)) {
    return true;
  }

  // Row-sparse gradient; weight and state must agree on rsp or dns storage.
  const bool weight_supported =
      w_stype == kRowSparseStorage || w_stype == kDefaultStorage;
  if (g_stype != kRowSparseStorage || !weight_supported ||
      s_stype != w_stype) {
    return false;
  }
  return storage_type_assign(out_attrs,
                             static_cast<NDArrayStorageType>(w_stype),
                             dispatch_mode, DispatchMode::kFComputeEx);
}
/*!
 * \brief Kernel for the sparse Adagrad update with group sparsity
 * regularization.
 *
 * One invocation of Map handles the i-th stored row of a row_sparse gradient:
 *  1. every gradient element of the row is multiplied by rescale_grad and,
 *     when clip_gradient >= 0, passed through clip::Map;
 *  2. the mean of the squared processed gradients of the row is added to the
 *     single history value state_data[grad_idx[i]];
 *  3. each weight of the row is updated as
 *     out = weight - lr * grad / sqrt(state + eps).
 */
template <typename xpu> struct GroupAdagradDnsRspKernel {
  /*!
   * \param i position of the row inside the gradient's stored rows
   * \param row_length number of elements per row
   * \param out_data output weights, indexed like weight_data
   * \param state_data history, one value per weight row; updated in place
   * \param weight_data current weights
   * \param grad_idx maps stored gradient row i to its weight/state row
   * \param grad_data gradient values, row_length elements per stored row
   * \param clip_gradient clipping bound; clipping is applied when >= 0
   * \param rescale_grad factor applied to the raw gradient
   * \param lr learning rate
   * \param eps added to the history value inside the square root
   */
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void
  Map(int i, const index_t row_length, DType *out_data, DType *state_data,
      DType *weight_data, const IType *grad_idx, const DType *grad_data,
      const DType clip_gradient, const DType rescale_grad, const DType lr,
      const DType eps) {
    using namespace mshadow_op;

    // Row targeted by this gradient row, and the offsets of the first element
    // of that row in the weight/out arrays and in the gradient array. These
    // do not depend on the column, so they are computed once.
    const IType row = grad_idx[i];
    const index_t data_offset = row * row_length;
    const index_t grad_offset = i * row_length;

    // Helper to obtain explicit rescaled and clipped grad
    auto get_grad_rescaled = [&grad_offset, &grad_data, &rescale_grad,
                              &clip_gradient](index_t j) -> DType {
      DType grad_rescaled = grad_data[grad_offset + j] * rescale_grad;
      if (clip_gradient >= 0.0f) {
        grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
      }
      return grad_rescaled;
    };

    // Update history states: accumulate the mean squared gradient of the row
    DType grad_ssq = 0;
    for (index_t j = 0; j < row_length; j++) {
      const DType grad_rescaled = get_grad_rescaled(j);
      grad_ssq += grad_rescaled * grad_rescaled;
    }
    state_data[row] += grad_ssq / row_length;

    // The denominator is identical for every element of the row. Hoisting it
    // avoids one square root per element; this relies on state_data not
    // aliasing out_data, which would also have been required for the
    // per-element form to be meaningful.
    const DType denom = square_root::Map(state_data[row] + eps);

    // Standard Adagrad Update
    for (index_t j = 0; j < row_length; j++) {
      const index_t data_j = data_offset + j;
      const DType div = lr * get_grad_rescaled(j) / denom;
      out_data[data_j] = weight_data[data_j] - div;
    }
  }
};
/*
* \brief Group Adagrad update implementation for dense weight and row_sparse
* grad.
*/
template <typename xpu>
inline void GroupAdagradUpdateDnsRspDnsImpl(
const GroupAdagradParam ¶m, const OpContext &ctx, const TBlob &weight,
const NDArray &grad, const TBlob &state, const OpReqType &req, TBlob *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(grad.storage_type(), kRowSparseStorage);
// if gradients are zeros, no weights are updated
if (req == kNullOp) {
return;
}
CHECK_EQ(req, kWriteInplace)
<< "kWriteInplace is expected for sparse group_adagrad_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(state.shape_.Size(), 0);
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
DType *weight_data = weight.dptr<DType>();
DType *out_data = out->dptr<DType>();
const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
const DType *grad_val = grad.data().dptr<DType>();
DType *state_data = state.dptr<DType>();
const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0];
const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
if (!grad.storage_initialized()) {
// Lazy update with 0 gradient
return;
}
Kernel<GroupAdagradDnsRspKernel<xpu>, xpu>::Launch(
s, num_grad, row_length, out_data, state_data, weight_data, grad_idx,
grad_val, static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.rescale_grad), static_cast<DType>(param.lr),
static_cast<DType>(param.epsilon));
});
});
}
/*!
 * \brief Group Adagrad update implementation for row_sparse weight, grad and
 * state.
 *
 * The weight is checked with CheckAllRowsPresent. A state without initialized
 * storage is filled with zeros first; otherwise the state is checked the same
 * way as the weight. The data blobs are then handed to
 * GroupAdagradUpdateDnsRspDnsImpl.
 *
 * \param param optimizer hyper-parameters
 * \param ctx operator context providing the stream
 * \param weight row_sparse weights
 * \param grad row_sparse gradient
 * \param state row_sparse history values
 * \param req write request for the output
 * \param out array receiving the updated weights
 */
template <typename xpu>
inline void
GroupAdagradUpdateRspRspRspImpl(const GroupAdagradParam &param,
                                const OpContext &ctx, const NDArray &weight,
                                const NDArray &grad, const NDArray &state,
                                const OpReqType &req, NDArray *out) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace rowsparse;
  CheckAllRowsPresent(weight, "GroupAdagradUpdate", "weights");
  Stream<xpu> *s = ctx.get_stream<xpu>();
  // fill history with zero values
  if (!state.storage_initialized()) {
    // `state` is const, so the fill goes through a non-const NDArray copy.
    // NOTE(review): presumably the copy shares storage with `state` so the
    // fill is visible through state.data() below -- confirm against NDArray.
    NDArray state_zeros = state;
    FillDnsZerosRspImpl(s, &state_zeros);
  } else {
    CheckAllRowsPresent(state, "GroupAdagradUpdate", "states");
  }
  // reuse dns rsp implementation when storage_shape == shape
  TBlob out_blob = out->data();
  GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
                                       state.data(), req, &out_blob);
}
/*!
 * \brief Storage-type dispatching entry point of the group Adagrad update.
 *
 * Handles a row_sparse gradient where state and output share the weight's
 * storage type:
 *  - row_sparse weight -> GroupAdagradUpdateRspRspRspImpl
 *  - dense weight      -> GroupAdagradUpdateDnsRspDnsImpl
 * Every other combination is reported through LogUnimplementedOp.
 *
 * \param attrs node attributes holding the parsed GroupAdagradParam
 * \param ctx operator context
 * \param inputs weight, grad and state (in this order)
 * \param req write requests; only req[0] is used
 * \param outputs single output receiving the updated weights
 */
template <typename xpu>
inline void GroupAdagradUpdateEx(const nnvm::NodeAttrs &attrs,
                                 const OpContext &ctx,
                                 const std::vector<NDArray> &inputs,
                                 const std::vector<OpReqType> &req,
                                 const std::vector<NDArray> &outputs) {
  // Same arity as enforced by GroupAdagradStorageType; checked before the
  // vectors are indexed below.
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 1U);
  const GroupAdagradParam &param = nnvm::get<GroupAdagradParam>(attrs.parsed);
  const auto weight_stype = inputs[0].storage_type();
  const auto grad_stype = inputs[1].storage_type();
  const auto state_stype = inputs[2].storage_type();
  const auto output_stype = outputs[0].storage_type();

  // Precondition shared by both supported combinations.
  const bool rsp_grad_matching_stypes = grad_stype == kRowSparseStorage &&
                                        state_stype == weight_stype &&
                                        output_stype == weight_stype;

  if (rsp_grad_matching_stypes && weight_stype == kRowSparseStorage) {
    NDArray out = outputs[0];
    GroupAdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
                                         inputs[2], req[0], &out);
  } else if (rsp_grad_matching_stypes && weight_stype == kDefaultStorage) {
    TBlob out_blob = outputs[0].data();
    GroupAdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(),
                                         inputs[1], inputs[2].data(), req[0],
                                         &out_blob);
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CONTRIB_OPTIMIZER_OP_INL_H_