This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
optimizer_op-inl.h
2939 lines (2782 loc) · 128 KB
/
optimizer_op-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file optimizer_op-inl.h
* \brief Optimizer operators
* \author Junyuan Xie
*/
#ifndef MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
#define MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <mshadow/base.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include "./operator_common.h"
#include "./mshadow_op.h"
#include "./elemwise_op_common.h"
#include "mxnet_op.h"
#include "./tensor/init_op.h"
#include "./tensor/util/tensor_util-inl.h"
namespace mxnet {
namespace op {
/*!
 * \brief Report, through common::LogOnce, that an optimizer is running with
 *        lazy_update enabled. With a row_sparse gradient, lazy update differs
 *        from the standard update and can change empirical results.
 */
inline void LogLazyUpdate() {
  const char* const kLazyUpdateWarning =
      "Optimizer with lazy_update = True detected. "
      "Be aware that lazy update with row_sparse gradient is different from "
      "standard update, and may lead to different empirical results. See "
      "https://mxnet.apache.org/api/python/optimization/optimization.html "
      "for more details.";
  common::LogOnce(kLazyUpdateWarning);
}
/*!
 * \brief Hyper-parameters of the single-tensor SGD update. In this file they
 *        are read by SGDUpdate, MP_SGDUpdate and the row_sparse path
 *        (SGDUpdateEx -> SGDUpdateRspImpl -> SGDUpdateDnsRspImpl).
 */
struct SGDParam : public dmlc::Parameter<SGDParam> {
float lr;             // learning rate; no default is declared (dmlc treats it as required)
float wd;             // weight decay coefficient (default 0)
float rescale_grad;   // the raw gradient is multiplied by this before anything else (default 1)
// Gradient clipping threshold (default -1).
// NOTE(review): the kernels in this file clip whenever clip_gradient >= 0, so a
// value of exactly 0 clips every gradient to 0 instead of disabling clipping as
// the description below says.
float clip_gradient;
// Only consulted on the row_sparse-gradient path (SGDUpdateDnsRspImpl). When
// false, weight decay is applied to every row of the weight, not only to the
// rows present in the gradient.
bool lazy_update;
DMLC_DECLARE_PARAMETER(SGDParam) {
DMLC_DECLARE_FIELD(lr).describe("Learning rate");
DMLC_DECLARE_FIELD(wd).set_default(0.0f).describe(
"Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe(
"Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(lazy_update)
.set_default(true)
.describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
}
};
/*!
 * \brief Hyper-parameters of the multi-tensor SGD update (no momentum). It is
 *        the default ParamType of FillMultiSGDKernelParam and is used by
 *        MultiSGDUpdate. lrs and wds hold one entry per weight; MultiSGDShape
 *        checks that both have exactly num_weights entries.
 */
struct MultiSGDParam : public dmlc::Parameter<MultiSGDParam> {
mxnet::Tuple<float> lrs;  // per-weight learning rates
mxnet::Tuple<float> wds;  // per-weight weight decay coefficients
float rescale_grad;       // gradient multiplier (default 1)
// Clipping threshold (default -1). NOTE(review): MultiSGDKernel clips whenever
// clip_gradient >= 0, so 0 clips rather than disables clipping.
float clip_gradient;
int num_weights;          // number of weight groups updated by one call (default 1)
DMLC_DECLARE_PARAMETER(MultiSGDParam) {
DMLC_DECLARE_FIELD(lrs).describe("Learning rates.");
DMLC_DECLARE_FIELD(wds).describe(
"Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe(
"Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(num_weights).set_default(1).describe("Number of updated weights.");
}
};
/*!
 * \brief Hyper-parameters of the multi-tensor SGD update with momentum. They
 *        are read by FillMultiSGDMomKernelParam, which MultiSGDMomUpdate uses.
 *        lrs and wds hold one entry per weight.
 */
struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
mxnet::Tuple<float> lrs;  // per-weight learning rates
mxnet::Tuple<float> wds;  // per-weight weight decay coefficients
float momentum;           // momentum coefficient applied to the state (default 0)
float rescale_grad;       // gradient multiplier (default 1)
// Clipping threshold (default -1). NOTE(review): MultiSGDKernel clips whenever
// clip_gradient >= 0, so 0 clips rather than disables clipping.
float clip_gradient;
int num_weights;          // number of weight groups updated by one call (default 1)
DMLC_DECLARE_PARAMETER(MultiSGDMomParam) {
DMLC_DECLARE_FIELD(lrs).describe("Learning rates.");
DMLC_DECLARE_FIELD(wds).describe(
"Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(momentum).set_default(0.0f).describe(
"The decay rate of momentum estimates at each epoch.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe(
"Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(num_weights).set_default(1).describe("Number of updated weights.");
}
};
/*!
 * \brief Shape inference for the multi-tensor SGD operators.
 *
 * Inputs are num_weights consecutive groups of input_stride arrays
 * (weight, grad, ...). Output i is the updated weight of group i. Every array in
 * group i and output i must share a single shape.
 *
 * \tparam ParamType    parameter struct exposing lrs, wds and num_weights
 * \tparam input_stride number of inputs per weight group
 * \return true if every input and output shape is known after inference
 */
template <typename ParamType, int input_stride>
inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector* in_attrs,
                          mxnet::ShapeVector* out_attrs) {
  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
  CHECK_EQ(out_attrs->size(), param.num_weights);
  bool all_inferred = true;
  auto& input_shapes = *in_attrs;
  auto& output_shapes = *out_attrs;
  // Learning rates
  CHECK_EQ(param.lrs.ndim(), param.num_weights)
      << "Number of learning rates is inconsistent with num_weights "
      << "parameter passed. Expected number of learning rates: " << param.num_weights
      << ", and got " << param.lrs.ndim();
  // Weight decays
  CHECK_EQ(param.wds.ndim(), param.num_weights)
      << "Number of weight decays is inconsistent with num_weights "
      << "parameter passed. Expected number of weight decays: " << param.num_weights
      << ", and got " << param.wds.ndim();
  // Weights and gradients
  for (int i = 0; i < param.num_weights; ++i) {
    // ElemwiseShape works on a per-group vector, so gather this group's shapes.
    mxnet::ShapeVector input_vec;
    input_vec.reserve(input_stride);
    for (int j = 0; j < input_stride; ++j) {
      input_vec.push_back(input_shapes[i * input_stride + j]);
    }
    mxnet::ShapeVector output_vec({output_shapes[i]});
    // Run inference for every group; do not short-circuit once a group is
    // incomplete, or later groups would never be inferred or checked.
    const bool group_inferred = ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
    // Scatter the inferred shapes back. Without this the results of
    // ElemwiseShape would be dropped with the local copies.
    for (int j = 0; j < input_stride; ++j) {
      input_shapes[i * input_stride + j] = input_vec[j];
    }
    output_shapes[i] = output_vec[0];
    all_inferred = all_inferred && group_inferred;
  }
  return all_inferred;
}
/*!
 * \brief Type inference for the mixed-precision multi-tensor SGD operators.
 *
 * Each of the num_weights groups has input_stride inputs. The first
 * (input_stride - num_fp32_inputs) inputs of a group share the type of output i.
 * The last num_fp32_inputs inputs of every group (e.g. momentum and the master
 * weight copy) are forced to float32.
 *
 * \tparam ParamType       parameter struct exposing num_weights
 * \tparam input_stride    number of inputs per weight group
 * \tparam num_fp32_inputs trailing inputs per group that must be float32
 * \return true if every type in the shared-type part is known after inference
 */
template <typename ParamType, int input_stride, int num_fp32_inputs>
inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int>* in_attrs,
                                  std::vector<int>* out_attrs) {
  const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
  CHECK_EQ(out_attrs->size(), param.num_weights);
  bool all_inferred = true;
  auto& input_types = *in_attrs;
  auto& output_types = *out_attrs;
  constexpr int kSharedInputs = input_stride - num_fp32_inputs;
  // Weights and gradients
  for (int i = 0; i < param.num_weights; ++i) {
    std::vector<int> input_vec;
    input_vec.reserve(kSharedInputs);
    for (int j = 0; j < kSharedInputs; ++j) {
      input_vec.push_back(input_types[i * input_stride + j]);
    }
    std::vector<int> output_vec({output_types[i]});
    // Infer every group (no short-circuit), then write the results back so
    // they are not lost with the local copies.
    const bool group_inferred =
        ElemwiseType<kSharedInputs, 1>(attrs, &input_vec, &output_vec);
    for (int j = 0; j < kSharedInputs; ++j) {
      input_types[i * input_stride + j] = input_vec[j];
    }
    output_types[i] = output_vec[0];
    all_inferred = all_inferred && group_inferred;
  }
  // master copies of weights
  for (int i = 0; i < param.num_weights; ++i) {
    for (int j = 0; j < num_fp32_inputs; ++j) {
      TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32);
    }
  }
  return all_inferred;
}
/*!
 * \brief Plain argument bundle for MultiSGDKernel. It describes up to N weight
 *        arrays that are updated in a single kernel launch and is passed to the
 *        kernel by value.
 *
 * Only the first `count` entries of each per-weight array are meaningful. They
 * are filled by FillMultiSGDKernelParam / FillMultiSGDMomKernelParam, so count
 * must not exceed N.
 * NOTE(review): N = 60 is presumably chosen to keep this struct within the
 * CUDA kernel-argument size limit; confirm before enlarging it.
 */
template <typename DType, typename MPDType>
struct MultiSGDKernelParam {
static const int N = 60;  // capacity of every per-weight array below
int count;                // number of weight arrays in use (must be <= N)
size_t max_size;          // largest sizes[i]; used as the kernel launch size
size_t sizes[N];          // element count of each weight array
DType* weights[N];        // weight data read by the kernel
DType* grads[N];          // gradient data
MPDType* mom[N];          // momentum buffers; only set by FillMultiSGDMomKernelParam
MPDType* weights32[N];    // master weight copies; only set when DType != MPDType
DType* out_data[N];       // destination of the updated weights
MPDType lrs[N];           // per-weight learning rate
MPDType wds[N];           // per-weight weight decay
MPDType clip_gradient;    // clipping threshold; the kernel clips when it is >= 0
MPDType rescale_grad;     // gradient multiplier
MPDType momentum;         // momentum coefficient; 0 unless set by FillMultiSGDMomKernelParam
};
/*!
 * \brief Element-wise kernel that applies one SGD step to up to param.count
 *        weight arrays at once, optionally with momentum and optionally on a
 *        master copy of the weights.
 *
 * Thread i updates element i of every array whose size exceeds i, so the kernel
 * is meant to be launched with param.max_size threads.
 *
 * \tparam MPDType             type the update arithmetic is done in
 * \tparam has_momentum        if true, update param.mom and step the weight by it
 * \tparam has_mixed_precision if true, read and write param.weights32 (the
 *                             MPDType master weights) instead of param.weights
 */
template <typename MPDType, bool has_momentum, bool has_mixed_precision>
struct MultiSGDKernel {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  const MultiSGDKernelParam<DType, MPDType>& param,
                                  const OpReqType req) {
    for (int k = 0; k < param.count; ++k) {
      // Array k is shorter than the launch size: nothing to do for this thread.
      if (i >= static_cast<index_t>(param.sizes[k])) {
        continue;
      }
      MPDType w = has_mixed_precision ? param.weights32[k][i]
                                      : MPDType(param.weights[k][i]);
      MPDType g = param.rescale_grad * static_cast<MPDType>(param.grads[k][i]);
      if (param.clip_gradient >= 0.0f) {
        g = mshadow_op::clip::Map(g, param.clip_gradient);
      }
      g += param.wds[k] * w;
      if (has_momentum) {
        MPDType& velocity = param.mom[k][i];
        velocity *= param.momentum;
        velocity -= param.lrs[k] * g;
        w = w + velocity;
      } else {
        w -= param.lrs[k] * g;
      }
      if (has_mixed_precision) {
        param.weights32[k][i] = w;
      }
      KERNEL_ASSIGN(param.out_data[k][i], req, w);
    }
  }
};
/*!
 * \brief Build the MultiSGDKernelParam for a multi-tensor SGD call.
 *
 * Each group of input_stride inputs starts with (weight, grad). When DType and
 * MPDType differ, the last input of the group is the MPDType master weight copy.
 * Output i receives the updated weight of group i. Momentum is set to 0 and
 * mom[] is left unset; FillMultiSGDMomKernelParam fills those.
 *
 * \tparam ParamType    MultiSGDParam or MultiSGDMomParam
 * \tparam input_stride number of inputs per weight group
 */
template <typename xpu,
          typename DType,
          typename MPDType,
          typename ParamType = MultiSGDParam,
          int input_stride = 2>
MultiSGDKernelParam<DType, MPDType> FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs,
                                                            const OpContext& ctx,
                                                            const std::vector<TBlob>& inputs,
                                                            const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MultiSGDKernelParam<DType, MPDType> param;
  // The kernel parameter stores fixed-size arrays. More weights than their
  // capacity would write past their end. Copy N into a local so CHECK_LE (which
  // binds references) does not ODR-use the non-inline static member.
  const int max_weights = MultiSGDKernelParam<DType, MPDType>::N;
  CHECK_LE(p.num_weights, max_weights)
      << "Too many weights for one multi-tensor SGD update: got " << p.num_weights
      << ", the maximum is " << max_weights;
  param.clip_gradient = p.clip_gradient;
  param.rescale_grad = p.rescale_grad;
  param.momentum = 0;
  param.count = p.num_weights;
  param.max_size = 0;
  for (int i = 0; i < param.count; ++i) {
    param.sizes[i] = inputs[i * input_stride].shape_.Size();
    if (param.max_size < param.sizes[i]) {
      param.max_size = param.sizes[i];
    }
    param.weights[i] = inputs[i * input_stride].FlatTo2D<xpu, DType>(s).dptr_;
    param.grads[i] = inputs[i * input_stride + 1].FlatTo2D<xpu, DType>(s).dptr_;
    // if mixed precision, then the last input in a set
    // is 32-bit master copy of the weights
    if (!std::is_same<DType, MPDType>::value) {
      param.weights32[i] =
          inputs[i * input_stride + input_stride - 1].FlatTo2D<xpu, MPDType>(s).dptr_;
    }
    param.out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
    param.lrs[i] = p.lrs[i];
    param.wds[i] = p.wds[i];
  }
  return param;
}
/*!
 * \brief Build the MultiSGDKernelParam for multi-tensor SGD with momentum.
 *
 * Reuses FillMultiSGDKernelParam for everything shared with plain multi-SGD,
 * then sets the momentum coefficient. The momentum buffer of each group is taken
 * from the group's third input (offset 2).
 */
template <typename xpu, typename DType, typename MPDType, int input_stride = 3>
MultiSGDKernelParam<DType, MPDType> FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs,
                                                               const OpContext& ctx,
                                                               const std::vector<TBlob>& inputs,
                                                               const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  const MultiSGDMomParam& mom_param = nnvm::get<MultiSGDMomParam>(attrs.parsed);
  Stream<xpu>* stream = ctx.get_stream<xpu>();
  MultiSGDKernelParam<DType, MPDType> kernel_param =
      FillMultiSGDKernelParam<xpu, DType, MPDType, MultiSGDMomParam, input_stride>(
          attrs, ctx, inputs, outputs);
  kernel_param.momentum = mom_param.momentum;
  for (int group = 0; group < kernel_param.count; ++group) {
    const TBlob& mom_blob = inputs[group * input_stride + 2];
    kernel_param.mom[group] = mom_blob.FlatTo2D<xpu, MPDType>(stream).dptr_;
  }
  return kernel_param;
}
/*!
 * \brief Type chooser that maps T to itself. As the MPTypeChooser argument of
 *        MultiSGDUpdate / MultiSGDMomUpdate it makes the update run in the
 *        weight's own type (MPDType == DType).
 */
template <typename T>
struct type_identity {
  using type = T;
};
/*!
 * \brief Type chooser that maps any T to float. As the MPTypeChooser argument
 *        of MultiSGDUpdate / MultiSGDMomUpdate it makes the update run in
 *        float32 (mixed precision whenever the weight type is not float).
 */
template <typename T>
struct single_precision {
  using type = float;
};
/*!
 * \brief FCompute for multi-tensor SGD without momentum.
 *
 * \tparam MPTypeChooser maps the weight type DType to the type MPDType in which
 *         the update is computed. When the two differ, the last input of each
 *         group is read and written as the MPDType master weight copy.
 * \tparam input_stride  number of inputs per weight group
 */
template <typename xpu, template <typename> class MPTypeChooser, int input_stride>
inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    using MPDType = typename MPTypeChooser<DType>::type;
    constexpr bool kUseMasterWeights = !std::is_same<DType, MPDType>::value;
    using UpdateKernel = MultiSGDKernel<MPDType, false, kUseMasterWeights>;
    MultiSGDKernelParam<DType, MPDType> kernel_param =
        FillMultiSGDKernelParam<xpu, DType, MPDType, MultiSGDParam, input_stride>(
            attrs, ctx, inputs, outputs);
    // One thread per element of the largest weight; shorter weights skip the rest.
    Kernel<UpdateKernel, xpu>::Launch(s, kernel_param.max_size, kernel_param, req[0]);
  });
}
/*!
 * \brief FCompute for multi-tensor SGD with momentum.
 *
 * \tparam MPTypeChooser maps the weight type DType to the type MPDType in which
 *         the update (and the momentum buffers) are computed. When they differ,
 *         the last input of each group is the MPDType master weight copy.
 * \tparam input_stride  number of inputs per weight group
 */
template <typename xpu, template <typename> class MPTypeChooser, int input_stride>
inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    using MPDType = typename MPTypeChooser<DType>::type;
    constexpr bool kUseMasterWeights = !std::is_same<DType, MPDType>::value;
    using UpdateKernel = MultiSGDKernel<MPDType, true, kUseMasterWeights>;
    MultiSGDKernelParam<DType, MPDType> kernel_param =
        FillMultiSGDMomKernelParam<xpu, DType, MPDType, input_stride>(attrs, ctx, inputs, outputs);
    Kernel<UpdateKernel, xpu>::Launch(s, kernel_param.max_size, kernel_param, req[0]);
  });
}
/*!
 * \brief Element-wise dense SGD step:
 *        out[i] = weight[i] - lr * (g + wd * weight[i]),
 *        with g = rescale_grad * grad[i], clipped to
 *        [-clip_gradient, clip_gradient] when param_clip_gradient >= 0.
 */
struct SGDKernel {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  const DType* weight_data,
                                  const DType* grad_data,
                                  const DType param_clip_gradient,
                                  const DType param_lr,
                                  const DType param_wd,
                                  const DType param_rescale_grad,
                                  const OpReqType req) {
    const DType w = weight_data[i];
    DType g = param_rescale_grad * grad_data[i];
    if (param_clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, param_clip_gradient);
    }
    g += param_wd * w;
    KERNEL_ASSIGN(out_data[i], req, w - (param_lr * g));
  }
};
/*!
 * \brief FCompute for dense SGD. Inputs are (weight, grad); the new weight is
 *        written to outputs[0] according to req[0].
 */
template <typename xpu>
inline void SGDUpdate(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> w = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> dst = outputs[0].FlatTo2D<xpu, DType>(s);
    // Hyper-parameters converted once to the weight type.
    const DType clip = static_cast<DType>(param.clip_gradient);
    const DType lr = static_cast<DType>(param.lr);
    const DType wd = static_cast<DType>(param.wd);
    const DType rescale = static_cast<DType>(param.rescale_grad);
    Kernel<SGDKernel, xpu>::Launch(
        s, w.shape_.Size(), dst.dptr_, w.dptr_, g.dptr_, clip, lr, wd, rescale, req[0]);
  });
}
/*! \brief kernel for sparse sgd: applies an SGD step to the rows of a dense
 *  weight that are listed in a row_sparse gradient's index array.
 *  The primary template is only declared; it is specialized per device below.
 *  SGDUpdateDnsRspImpl launches the gpu specialization with one thread per
 *  gradient element and the cpu specialization with one thread per gradient row.
 *  \tparam req assignment request forwarded to KERNEL_ASSIGN
 *  \tparam xpu device tag (cpu or gpu)
*/
template <int req, typename xpu>
struct SGDDnsRspKernel;
/*!
 * \brief GPU specialization of the row_sparse SGD kernel. Thread i handles
 *        element i of the gradient's value array and updates the matching
 *        element of the dense weight.
 *
 * DType is the weight/gradient value type; IType is the row_sparse index type.
 */
template <int req>
struct SGDDnsRspKernel<req, gpu> {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  const index_t row_length,
                                  DType* out,
                                  const DType* weight,
                                  const IType* grad_idx,
                                  const DType* grad_val,
                                  const DType clip_gradient,
                                  const DType lr,
                                  const DType wd,
                                  const DType rescale_grad) {
    using nnvm::dim_t;
    // Position of this gradient value: (stored row, column within the row).
    const dim_t stored_row = i / row_length;
    const dim_t column = i % row_length;
    // Dense element addressed by that stored row.
    const dim_t dense_base = grad_idx[stored_row] * row_length;
    const dim_t dst = dense_base + column;
    DType g = rescale_grad * grad_val[i];
    if (clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, clip_gradient);
    }
    g += wd * weight[dst];
    KERNEL_ASSIGN(out[dst], req, weight[dst] - (lr * g));
  }
};
/*!
 * \brief CPU specialization of the row_sparse SGD kernel. Thread i handles the
 *        i-th stored row of the gradient and updates every column of the dense
 *        weight row grad_idx[i].
 *
 * DType is the weight/gradient value type; IType is the row_sparse index type.
 */
template <int req>
struct SGDDnsRspKernel<req, cpu> {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  const index_t row_length,
                                  DType* out,
                                  const DType* weight,
                                  const IType* grad_idx,
                                  const DType* grad_val,
                                  const DType clip_gradient,
                                  const DType lr,
                                  const DType wd,
                                  const DType rescale_grad) {
    for (index_t col = 0; col < row_length; ++col) {
      const index_t dst = grad_idx[i] * row_length + col;  // dense weight element
      const index_t src = i * row_length + col;             // gradient value element
      DType g = rescale_grad * grad_val[src];
      if (clip_gradient >= 0.0f) {
        g = mshadow_op::clip::Map(g, clip_gradient);
      }
      g += wd * weight[dst];
      KERNEL_ASSIGN(out[dst], req, weight[dst] - (lr * g));
    }
  }
};
/*
 * \brief SGD update implementation for dense weight and row_sparse grad.
 * Both standard update and lazy update are supported.
 *
 * Lazy update (param.lazy_update == true): only the rows present in grad are
 * touched, and weight decay is applied to those rows only.
 * Standard update: the whole weight is first scaled in place by (1 - lr * wd).
 * The gradient rows are then applied without a further wd term. If grad has no
 * stored rows, only the decay step happens.
 *
 * \param weight dense weight; modified in place by the standard-update decay
 * \param grad   row_sparse gradient
 * \param req    kNullOp (no-op) or kWriteInplace; anything else is rejected
 * \param out    destination of the updated rows (with kWriteInplace this is
 *               presumably the same buffer as weight)
 */
template <typename xpu>
inline void SGDUpdateDnsRspImpl(const SGDParam& param,
                                const OpContext& ctx,
                                const TBlob& weight,
                                const NDArray& grad,
                                const OpReqType& req,
                                TBlob* out) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  using namespace mxnet_op;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(grad.storage_type(), kRowSparseStorage);
  // Nothing is written when the output is not requested.
  if (req == kNullOp)
    return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_update";
  CHECK_GT(weight.shape_.Size(), 0);
  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        DType* weight_data = weight.dptr<DType>();
        float wd = param.wd;
        // apply standard weight decay if not lazy update
        if (!param.lazy_update) {
          Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(
              s,
              weight.Size(),
              weight_data,
              weight_data,
              static_cast<DType>(1 - param.lr * param.wd));
          // Decay has already been applied to every row; do not apply it again.
          wd = 0;
        }
        // An all-zero gradient has no stored rows: nothing further to update.
        if (!grad.storage_initialized())
          return;
        const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        // cpu kernel: one thread per stored row; gpu kernel: one per stored element.
        size_t num_threads = num_rows;
        if (std::is_same<xpu, gpu>::value) {
          num_threads = num_rows * row_length;
        }
        Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s,
                                                            num_threads,
                                                            row_length,
                                                            out->dptr<DType>(),
                                                            weight_data,
                                                            grad_idx,
                                                            grad_val,
                                                            static_cast<DType>(param.clip_gradient),
                                                            static_cast<DType>(param.lr),
                                                            static_cast<DType>(wd),
                                                            static_cast<DType>(param.rescale_grad));
      });
    });
  });
}
/*
* \brief SGD update implementation for row_sparse grad.
* Both standard update and lazy update are supported.
*/
template <typename xpu>
inline void SGDUpdateRspImpl(const SGDParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const OpReqType& req,
NDArray* out) {
CheckAllRowsPresent(weight, "SGDUpdate", "weights");
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
}
/*!
 * \brief FComputeEx for SGD. Only a row_sparse gradient is supported, with a
 *        dense or row_sparse weight whose storage type matches the output.
 *        Every other combination is reported via LogUnimplementedOp.
 */
template <typename xpu>
inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<NDArray>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<NDArray>& outputs) {
  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
  const auto weight_stype = inputs[0].storage_type();
  const auto grad_stype = inputs[1].storage_type();
  const auto out_stype = outputs[0].storage_type();
  const bool weight_stype_ok =
      weight_stype == kDefaultStorage || weight_stype == kRowSparseStorage;
  const bool rsp_grad_case =
      out_stype == weight_stype && grad_stype == kRowSparseStorage && weight_stype_ok;
  if (!rsp_grad_case) {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
    return;
  }
  // std update and lazy update with rsp grad
  NDArray out = outputs[0];
  SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
}
/*!
 * \brief Hyper-parameters of single-tensor SGD with momentum. In this file they
 *        are read by SGDMomUpdate, MP_SGDMomUpdate and the lazy row_sparse path
 *        (SGDMomLazyUpdateRspImpl / SGDMomLazyUpdateDnsRspDnsImpl).
 */
struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
float lr;             // learning rate; no default is declared (dmlc treats it as required)
float momentum;       // coefficient the momentum state is multiplied by each step (default 0)
float wd;             // weight decay coefficient (default 0)
float rescale_grad;   // gradient multiplier (default 1)
// Clipping threshold (default -1). NOTE(review): the kernels in this file clip
// whenever clip_gradient >= 0, so 0 clips rather than disables clipping.
float clip_gradient;
// Default true. Read by StdOptStorageType only to decide whether to log the
// lazy-update warning; the warning is logged only when it is set.
bool lazy_update;
DMLC_DECLARE_PARAMETER(SGDMomParam) {
DMLC_DECLARE_FIELD(lr).describe("Learning rate");
DMLC_DECLARE_FIELD(momentum).set_default(0.0f).describe(
"The decay rate of momentum estimates at each epoch.");
DMLC_DECLARE_FIELD(wd).set_default(0.0f).describe(
"Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe(
"Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(lazy_update)
.set_default(true)
.describe(
"If true, lazy updates are applied if gradient's stype is row_sparse "
"and both weight and momentum have the same stype");
}
};
/*!
 * \brief Element-wise dense SGD-with-momentum step:
 *        mom[i] = momentum * mom[i] - lr * (g + wd * weight[i]);
 *        out[i] = weight[i] + mom[i],
 *        with g = rescale_grad * grad[i], clipped when param_clip_gradient >= 0.
 *        mom_data is updated in place.
 */
struct SGDMomKernel {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  DType* mom_data,
                                  const DType* weight_data,
                                  const DType* grad_data,
                                  const DType param_clip_gradient,
                                  const DType param_momentum,
                                  const DType param_lr,
                                  const DType param_wd,
                                  const DType param_rescale_grad,
                                  const OpReqType req) {
    DType g = param_rescale_grad * grad_data[i];
    if (param_clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, param_clip_gradient);
    }
    g += param_wd * weight_data[i];
    DType& velocity = mom_data[i];
    velocity *= param_momentum;
    velocity -= param_lr * g;
    KERNEL_ASSIGN(out_data[i], req, weight_data[i] + velocity);
  }
};
/*!
 * \brief FCompute for dense SGD with momentum. Inputs are (weight, grad, mom).
 *        mom is updated in place by SGDMomKernel, and the new weight is written
 *        to outputs[0] according to req[0].
 */
template <typename xpu>
inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  // Bind by const reference (as SGDUpdate does) instead of copying the struct.
  const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<SGDMomKernel, xpu>::Launch(s,
                                      weight.shape_.Size(),
                                      out.dptr_,
                                      mom.dptr_,
                                      weight.dptr_,
                                      grad.dptr_,
                                      static_cast<DType>(param.clip_gradient),
                                      static_cast<DType>(param.momentum),
                                      static_cast<DType>(param.lr),
                                      static_cast<DType>(param.wd),
                                      static_cast<DType>(param.rescale_grad),
                                      req[0]);
  });
}
/*!
 * \brief Type inference for mixed-precision optimizers.
 *
 * Inputs [n_in, total_in) are forced to float32 (state / master weights). The
 * first n_in inputs and the n_out outputs are then unified by ElemwiseAttr.
 *
 * \tparam n_in     inputs that share the weight type
 * \tparam n_out    number of outputs
 * \tparam total_in total number of inputs
 */
template <int n_in, int n_out, int total_in>
inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
                         std::vector<int>* in_attrs,
                         std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
  int fp32_slot = n_in;
  while (fp32_slot < total_in) {
    TYPE_ASSIGN_CHECK(*in_attrs, fp32_slot, mshadow::kFloat32);
    ++fp32_slot;
  }
  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
      attrs, in_attrs, out_attrs, -1);
}
/*!
 * \brief Mixed-precision SGD step. The update is computed in float on the
 *        master copy weight32[i], which is written back. The result is then
 *        converted to DType for out_data[i]. weight_data is not read.
 */
struct MP_SGDKernel {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  const DType* weight_data,
                                  const DType* grad_data,
                                  float* weight32,
                                  const float param_clip_gradient,
                                  const float param_lr,
                                  const float param_wd,
                                  const float param_rescale_grad,
                                  const OpReqType req) {
    float master = weight32[i];
    float g = param_rescale_grad * static_cast<float>(grad_data[i]);
    if (param_clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, param_clip_gradient);
    }
    g += param_wd * master;
    master -= param_lr * g;
    weight32[i] = master;
    KERNEL_ASSIGN(out_data[i], req, static_cast<DType>(master));
  }
};
/*!
 * \brief FCompute for mixed-precision SGD. Inputs are (weight, grad, weight32),
 *        where weight32 is the float32 master copy that MP_SGDKernel updates in
 *        place. The new weight is written to outputs[0] according to req[0].
 */
template <typename xpu>
inline void MP_SGDUpdate(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> w = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, float> master = inputs[2].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, DType> dst = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<MP_SGDKernel, xpu>::Launch(s,
                                      w.shape_.Size(),
                                      dst.dptr_,
                                      w.dptr_,
                                      g.dptr_,
                                      master.dptr_,
                                      param.clip_gradient,
                                      param.lr,
                                      param.wd,
                                      param.rescale_grad,
                                      req[0]);
  });
}
/*!
 * \brief Mixed-precision SGD-with-momentum step. The momentum mom_data[i] and
 *        the master weight weight32[i] are float and both are updated in place.
 *        The new master weight is assigned to out_data[i]. weight_data is not
 *        read.
 */
struct MP_SGDMomKernel {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  float* mom_data,
                                  const DType* weight_data,
                                  const DType* grad_data,
                                  float* weight32,
                                  const float param_clip_gradient,
                                  const float param_momentum,
                                  const float param_lr,
                                  const float param_wd,
                                  const float param_rescale_grad,
                                  const OpReqType req) {
    float master = weight32[i];
    float velocity = mom_data[i];
    float g = param_rescale_grad * static_cast<float>(grad_data[i]);
    if (param_clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, param_clip_gradient);
    }
    g += param_wd * master;
    velocity *= param_momentum;
    velocity -= param_lr * g;
    mom_data[i] = velocity;
    master = master + velocity;
    weight32[i] = master;
    KERNEL_ASSIGN(out_data[i], req, master);
  }
};
/*!
 * \brief FCompute for mixed-precision SGD with momentum.
 *        Inputs are (weight, grad, mom, weight32). mom and weight32 are float32
 *        buffers that MP_SGDMomKernel updates in place; the new weight is
 *        written to outputs[0] according to req[0].
 */
template <typename xpu>
inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  // Bind by const reference instead of copying the parameter struct.
  const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<MP_SGDMomKernel, xpu>::Launch(s,
                                         weight.shape_.Size(),
                                         out.dptr_,
                                         mom.dptr_,
                                         weight.dptr_,
                                         grad.dptr_,
                                         weight32.dptr_,
                                         param.clip_gradient,
                                         param.momentum,
                                         param.lr,
                                         param.wd,
                                         param.rescale_grad,
                                         req[0]);
  });
}
/*! \brief Lazy SGD-momentum kernel for dense weight, row_sparse gradient and
 *  dense momentum. Only the rows listed in the gradient's index array are read
 *  or written. The primary template is only declared; it is specialized per
 *  device below. SGDMomLazyUpdateDnsRspDnsImpl launches it with one thread per
 *  gradient row on cpu and one thread per gradient element on gpu.
 *  \tparam req assignment request forwarded to KERNEL_ASSIGN
 *  \tparam xpu device tag (cpu or gpu)
 */
template <int req, typename xpu>
struct SGDMomDnsRspDnsKernel;
/*!
 * \brief CPU specialization of the lazy SGD-momentum kernel. Thread i handles
 *        the i-th stored gradient row: for every column it updates the
 *        momentum of dense row grad_idx[i] in place and writes the new weight.
 */
template <int req>
struct SGDMomDnsRspDnsKernel<req, cpu> {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  index_t row_length,
                                  DType* out_data,
                                  DType* mom_data,
                                  const DType* weight_data,
                                  const IType* grad_idx,
                                  const DType* grad_data,
                                  const DType clip_gradient,
                                  const DType momentum,
                                  const DType lr,
                                  const DType wd,
                                  const DType rescale_grad) {
    for (index_t col = 0; col < row_length; ++col) {
      const index_t dst = grad_idx[i] * row_length + col;  // dense weight/momentum element
      const index_t src = i * row_length + col;             // gradient value element
      DType g = rescale_grad * grad_data[src];
      if (clip_gradient >= 0.0f) {
        g = mshadow_op::clip::Map(g, clip_gradient);
      }
      g += wd * weight_data[dst];
      DType& velocity = mom_data[dst];
      velocity *= momentum;
      velocity -= lr * g;
      KERNEL_ASSIGN(out_data[dst], req, weight_data[dst] + velocity);
    }
  }
};
/*!
 * \brief GPU specialization of the lazy SGD-momentum kernel. Thread i handles
 *        element i of the gradient's value array and updates the matching
 *        momentum element in place before writing the new weight element.
 */
template <int req>
struct SGDMomDnsRspDnsKernel<req, gpu> {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  index_t row_length,
                                  DType* out_data,
                                  DType* mom_data,
                                  const DType* weight_data,
                                  const IType* grad_idx,
                                  const DType* grad_data,
                                  const DType clip_gradient,
                                  const DType momentum,
                                  const DType lr,
                                  const DType wd,
                                  const DType rescale_grad) {
    using nnvm::dim_t;
    // (stored row, column) of this gradient value, then its dense position.
    const dim_t stored_row = i / row_length;
    const dim_t column = i % row_length;
    const dim_t dst = grad_idx[stored_row] * row_length + column;
    DType g = rescale_grad * grad_data[i];
    if (clip_gradient >= 0.0f) {
      g = mshadow_op::clip::Map(g, clip_gradient);
    }
    g += wd * weight_data[dst];
    DType& velocity = mom_data[dst];
    velocity *= momentum;
    velocity -= lr * g;
    KERNEL_ASSIGN(out_data[dst], req, weight_data[dst] + velocity);
  }
};
/*
* \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
*/
template <typename xpu>
inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mom,
const OpReqType& req,
TBlob* out) {
using namespace mxnet_op;
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
if (!grad.storage_initialized() || req == kNullOp)
return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mom.shape_.Size(), 0);
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
DType* weight_data = weight.dptr<DType>();
IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
DType* grad_val = grad.data().dptr<DType>();
DType* mom_data = mom.dptr<DType>();
DType* out_data = out->dptr<DType>();
index_t num_rows = grad.aux_shape(kIdx)[0];
auto row_length = weight.shape_.ProdShape(1, weight.ndim());
size_t num_threads = num_rows;
if (std::is_same<xpu, gpu>::value) {
num_threads = num_rows * row_length;
}
Kernel<SGDMomDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(
s,
num_threads,
row_length,
out_data,
mom_data,
weight_data,
grad_idx,
grad_val,
static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.momentum),
static_cast<DType>(param.lr),
static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
});
});
});
}
/*
* \brief sgd momentum lazy update for row_sparse grad.
*/
template <typename xpu>
inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const NDArray& mom,
const OpReqType& req,
NDArray* out) {
using namespace mxnet_op;
using namespace rowsparse;
CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values (if it's in rsp storage)
// in order to reuse the sgd mom dns impl
if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
NDArray mom_zeros = mom;
FillDnsZerosRspImpl(s, &mom_zeros);
}
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mom.data(), req, &out_blob);
}
/*!
 * \brief Storage type inference function for optimizers which support both
 * lazy update and standard update, with states (e.g. 2nd order moment).
 *
 * Inputs are (weight, grad, state 0, state 1, ...); the single output is the
 * updated weight. All states must share one storage type. Dispatch rules are
 * tried in order:
 *  1. every input dense                                 -> dense output, FCompute
 *  2. grad rsp, weight dns or rsp, states same as weight -> output like weight,
 *     FComputeEx (logs the lazy-update warning when param.lazy_update is set)
 *  3. grad rsp, weight rsp, states dense                 -> rsp output, FComputeEx
 *  4. anything else                                      -> fallback dispatch
 *
 * \tparam num_states The number of states that could be row_sparse or dense.
 *                    Must be at least 1 because state 0 is always inspected.
 * \tparam ParamType  parameter struct exposing a lazy_update flag
 */
template <size_t num_states, typename ParamType>
inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
                              const int dev_mask,
                              DispatchMode* dispatch_mode,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  // With no state, in_attrs->at(2) below would always throw at runtime.
  static_assert(num_states >= 1, "StdOptStorageType requires at least one optimizer state");
  using namespace common;
  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
  // weight, grad, state 0, state 1, ... -> weight
  CHECK_EQ(in_attrs->size(), 2 + num_states);
  CHECK_EQ(out_attrs->size(), 1U);
  const int weight_stype = in_attrs->at(0);
  const int grad_stype = in_attrs->at(1);
  const int state_stype = in_attrs->at(2);
  // the storage type of all states should be the same
  for (size_t i = 3; i < 2 + num_states; i++) {
    CHECK_EQ(state_stype, in_attrs->at(i)) << "Inconsistent storage types detected in state " << i;
  }
  bool dispatched = false;
  if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    // dns, ... -> dns
    dispatched =
        storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && grad_stype == kRowSparseStorage &&
      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
      state_stype == weight_stype) {
    // weight and state share stype, grad's stype = rsp
    dispatched = storage_type_assign(out_attrs,
                                     static_cast<NDArrayStorageType>(weight_stype),
                                     dispatch_mode,
                                     DispatchMode::kFComputeEx);
    // warn users if lazy_update is turned on
    if (dispatched && param.lazy_update)
      LogLazyUpdate();
  }
  if (!dispatched && grad_stype == kRowSparseStorage && weight_stype == kRowSparseStorage &&
      state_stype == kDefaultStorage) {
    // weight, grad, state, ... -> weight
    // rsp, rsp, dns, ... -> rsp, standard update
    dispatched = storage_type_assign(out_attrs,
                                     static_cast<NDArrayStorageType>(weight_stype),
                                     dispatch_mode,
                                     DispatchMode::kFComputeEx);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}
/*
* \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
* Only the primary template is declared here.
* NOTE(review): its device specializations are expected next to the definition
* of SGDMomStdUpdateDnsRspDnsImpl; confirm their location before changing the signature.
*/
template <int req, typename xpu>
struct SGDMomStdDnsRspDnsKernel;
/*
* \brief standard momentum update for dense weight, row_sparse grad and dense states.
* Only declared here (not inline); the definition lives elsewhere.
* NOTE(review): "standard", as opposed to the lazy SGDMomLazyUpdateDnsRspDnsImpl,
* presumably means every row of weight and mom is updated, not only the rows
* present in grad; confirm against the definition.
* \param weight dense weight
* \param grad   row_sparse gradient
* \param mom    dense momentum state
* \param req    write request for the output
* \param out    destination of the updated weight
*/
template <typename xpu>
void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mom,
const OpReqType& req,
TBlob* out);
/*
* \brief standard momentum update for row_sparse grad.
* both row_sparse and dense weight are supported.
*/
template <typename xpu>
inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,