-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_lr.py
194 lines (174 loc) · 7.27 KB
/
custom_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from paddle.optimizer.lr import *
"""
PaddleVideo Learning Rate Schedule:
You can use paddle.optimizer.lr
or define your custom_lr in this file.
"""
class CustomWarmupCosineDecay(LRScheduler):
    r"""
    Combine linear warmup with a cosine decay schedule (used in the SlowFast model).

    The schedule is evaluated on a fractional epoch counter. When ``step()``
    is called without an explicit epoch, ``last_epoch`` advances by
    ``1 / num_iters``, so the learning rate changes on every iteration.

    * While ``last_epoch < warmup_epochs``, the learning rate grows linearly
      from ``warmup_start_lr`` to the cosine value at ``warmup_epochs``.
    * Afterwards it is
      ``cosine_base_lr * (cos(pi * last_epoch / max_epoch) + 1) / 2``.

    Args:
        warmup_start_lr (float): start learning rate used in the warmup stage.
        warmup_epochs (int): the number of epochs of warmup. A value ``<= 0``
            disables warmup.
        cosine_base_lr (float|int): base learning rate of the cosine schedule.
        max_epoch (int): total training epochs.
        num_iters (int): number of iterations in each epoch.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CustomWarmupCosineDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 warmup_start_lr,
                 warmup_epochs,
                 cosine_base_lr,
                 max_epoch,
                 num_iters,
                 last_epoch=-1,
                 verbose=False):
        self.warmup_start_lr = warmup_start_lr
        self.warmup_epochs = warmup_epochs
        self.cosine_base_lr = cosine_base_lr
        self.max_epoch = max_epoch
        self.num_iters = num_iters
        # The base-class constructor calls step(), which updates
        # last_lr/last_epoch/base_lr. The attributes above must already
        # exist, because step() -> get_lr() reads them.
        super(CustomWarmupCosineDecay, self).__init__(last_epoch=last_epoch,
                                                      verbose=verbose)

    def step(self, epoch=None):
        """
        Advance the schedule and refresh ``last_lr``.

        ``step`` should be called after ``optimizer.step``. The new learning
        rate takes effect on the next ``optimizer.step``.

        Args:
            epoch (int|float, None): if given, ``last_epoch`` is set to it
                directly. If None, the first call moves ``last_epoch`` from
                -1 to 0, and every later call adds ``1 / num_iters``.

        Returns:
            None
        """
        if epoch is not None:
            self.last_epoch = epoch
        elif self.last_epoch == -1:
            self.last_epoch += 1
        else:
            # Called once per iteration, so num_iters calls advance one epoch.
            self.last_epoch += 1 / self.num_iters
        self.last_lr = self.get_lr()
        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def _lr_func_cosine(self, cur_epoch, cosine_base_lr, max_epoch):
        """Return the half-cosine learning rate at ``cur_epoch``."""
        return cosine_base_lr * (math.cos(math.pi * cur_epoch / max_epoch) +
                                 1.0) * 0.5

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        lr = self._lr_func_cosine(self.last_epoch, self.cosine_base_lr,
                                  self.max_epoch)
        # Warmup only applies when warmup_epochs is positive. The guard
        # avoids dividing by zero when warmup_epochs == 0 and last_epoch < 0.
        if self.warmup_epochs > 0 and self.last_epoch < self.warmup_epochs:
            lr_start = self.warmup_start_lr
            # Ramp linearly so the curve meets the cosine value at warmup end.
            lr_end = self._lr_func_cosine(self.warmup_epochs,
                                          self.cosine_base_lr, self.max_epoch)
            alpha = (lr_end - lr_start) / self.warmup_epochs
            lr = self.last_epoch * alpha + lr_start
        return lr
class CustomWarmupPiecewiseDecay(LRScheduler):
    r"""
    Combine linear warmup with piecewise-constant (step) decay (used in the SlowFast model).

    The schedule is evaluated on a fractional epoch counter. Each ``step()``
    call without an explicit epoch advances ``last_epoch`` by
    ``1 / num_iters``.

    * Outside warmup, the learning rate is ``lrs[i] * step_base_lr``, where
      ``i`` is the index of the last entry of ``steps`` that is
      ``<= last_epoch``. ``lrs[0]`` is used before ``steps[0]``.
    * While ``last_epoch < warmup_epochs``, the learning rate grows linearly
      from ``warmup_start_lr`` to the step-schedule value at ``warmup_epochs``.

    Args:
        warmup_start_lr (float): start learning rate used in the warmup stage.
        warmup_epochs (int): the number of epochs of warmup. A value ``<= 0``
            disables warmup.
        step_base_lr (float|int): base learning rate of the step schedule.
        lrs (list[float]): relative learning rates. ``lrs[i]`` multiplies
            ``step_base_lr`` from ``steps[i]`` onward.
        gamma (float): stored on the instance; not used by this scheduler's
            own computation.
        steps (list[int|float]): epoch boundaries at which the relative
            learning rate changes; expected in increasing order.
        max_epoch (int): total training epochs.
        num_iters (int): number of iterations in each epoch.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: 0.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CustomWarmupPiecewiseDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 warmup_start_lr,
                 warmup_epochs,
                 step_base_lr,
                 lrs,
                 gamma,
                 steps,
                 max_epoch,
                 num_iters,
                 last_epoch=0,
                 verbose=False):
        self.warmup_start_lr = warmup_start_lr
        self.warmup_epochs = warmup_epochs
        self.step_base_lr = step_base_lr
        self.lrs = lrs
        self.gamma = gamma
        self.steps = steps
        self.max_epoch = max_epoch
        self.num_iters = num_iters
        self.last_epoch = last_epoch
        # The base-class constructor is not called, so no initial step()
        # runs here. The first iteration uses warmup_start_lr as-is.
        self.last_lr = self.warmup_start_lr
        self.verbose = verbose
        self._var_name = None

    def step(self, epoch=None, rebuild=False):
        """
        Advance the schedule and refresh ``last_lr``.

        ``step`` should be called after ``optimizer.step``. The new learning
        rate takes effect on the next ``optimizer.step``.

        Args:
            epoch (int|float, None): if given, ``last_epoch`` is set to it
                directly. If None, ``last_epoch`` advances by
                ``1 / num_iters`` (unless ``rebuild`` is True).
            rebuild (bool): if True and ``epoch`` is None, ``last_epoch`` is
                left unchanged and only ``last_lr`` is recomputed.

        Returns:
            None
        """
        if epoch is not None:
            self.last_epoch = epoch
        elif not rebuild:
            # Called once per iteration, so num_iters calls advance one epoch.
            self.last_epoch += 1 / self.num_iters
        self.last_lr = self.get_lr()
        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def _lr_func_steps_with_relative_lrs(self, cur_epoch, lrs, base_lr, steps,
                                         max_epoch):
        """Return ``lrs[i] * base_lr`` for the step segment containing ``cur_epoch``."""
        # list() lets callers pass a tuple; max_epoch closes the last segment.
        boundaries = list(steps) + [max_epoch]
        for ind, boundary in enumerate(boundaries):
            if cur_epoch < boundary:
                break
        # ind is the first boundary strictly greater than cur_epoch (or the
        # last index if there is none), so the active segment is ind - 1.
        # Clamp at 0: before steps[0], use the first LR instead of wrapping
        # around to lrs[-1].
        return lrs[max(ind - 1, 0)] * base_lr

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        lr = self._lr_func_steps_with_relative_lrs(
            self.last_epoch,
            self.lrs,
            self.step_base_lr,
            self.steps,
            self.max_epoch,
        )
        # Warmup only applies when warmup_epochs is positive. The guard
        # avoids dividing by zero when warmup_epochs == 0 and last_epoch < 0.
        if self.warmup_epochs > 0 and self.last_epoch < self.warmup_epochs:
            lr_start = self.warmup_start_lr
            # Ramp linearly so the curve meets the step value at warmup end.
            lr_end = self._lr_func_steps_with_relative_lrs(
                self.warmup_epochs,
                self.lrs,
                self.step_base_lr,
                self.steps,
                self.max_epoch,
            )
            alpha = (lr_end - lr_start) / self.warmup_epochs
            lr = self.last_epoch * alpha + lr_start
        return lr
class CustomPiecewiseDecay(PiecewiseDecay):
    """``PiecewiseDecay`` that accepts and discards an extra ``num_iters`` argument.

    The other schedulers in this file take ``num_iters``. Presumably this
    wrapper exists so that a config shared with them can be passed through
    unchanged (verify against the scheduler builder). All other arguments
    are forwarded to ``PiecewiseDecay`` as-is.
    """

    def __init__(self, *args, **kwargs):
        # num_iters is optional: a missing key no longer raises KeyError.
        kwargs.pop('num_iters', None)
        super().__init__(*args, **kwargs)