Skip to content

Commit

Permalink
fix RMSProp one_dim_param_no_weight_decay
Browse files Browse the repository at this point in the history
  • Loading branch information
flytocc committed Apr 17, 2023
1 parent a283b51 commit 55e2114
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions ppcls/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,26 +228,26 @@ def __init__(self,
def __call__(self, model_list):
# model_list is None in static graph
parameters = None
if len(self.no_weight_decay_name_list) > 0:
if model_list:
params_with_decay = []
params_without_decay = []
for m in model_list:
params = [p for n, p in m.named_parameters() \
if not any(nd in n for nd in self.no_weight_decay_name_list)]
params_with_decay.extend(params)
params = [p for n, p in m.named_parameters() \
if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)]
params_without_decay.extend(params)
parameters = [{
"params": params_with_decay,
"weight_decay": self.weight_decay
}, {
"params": params_without_decay,
"weight_decay": 0.0
}]
else:
parameters = sum([m.parameters() for m in model_list],
[]) if model_list else None
for n, p in m.named_parameters():
if any(nd in n for nd in self.no_weight_decay_name_list) \
or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
params_without_decay.append(p)
else:
params_with_decay.append(p)
if params_without_decay:
parameters = [{
"params": params_with_decay,
"weight_decay": self.weight_decay
}, {
"params": params_without_decay,
"weight_decay": 0.0
}]
else:
parameters = params_with_decay
opt = optim.RMSProp(
learning_rate=self.learning_rate,
momentum=self.momentum,
Expand Down

0 comments on commit 55e2114

Please sign in to comment.