weight decay (L2)
albertz committed Mar 22, 2022
1 parent 69cd23c commit 9738caa
Showing 1 changed file with 17 additions and 0 deletions: nn/base.py
@@ -476,6 +476,23 @@ def initial(self, value: Optional[Union[nn.Tensor, RawTensorTypes, nn.init.Varia
        self.layer_dict.pop("init_by_layer", None)
        self.layer_dict["init"] = value

    @property
    def weight_decay(self) -> float:
        """
        Weight decay, which is equivalent to L2 loss on the parameters for SGD.
        On the RETURNN side, whether this is handled separately or as part of the main loss
        can be controlled via the ``decouple_constraints`` config option.
        https://github.com/rwth-i6/returnn_common/issues/59#issuecomment-1073913421
        """
        return self.layer_dict.get("L2", 0.0)

    @weight_decay.setter
    def weight_decay(self, value: Optional[float]):
        if value:
            self.layer_dict["L2"] = value
        else:
            self.layer_dict.pop("L2", None)


class LayerState(dict):
"""
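Below is a short usage sketch for the new property (not part of this commit). It assumes the property is defined on nn.Parameter, that modules expose a parameters() iterator as in returnn_common, and that MyModule is a hypothetical module owning some parameters:

    from returnn_common import nn

    net = MyModule()  # hypothetical module owning some nn.Parameter instances
    for param in net.parameters():
        param.weight_decay = 1e-4  # sets {"L2": 1e-4} in the parameter's layer dict

    # Reading the property back falls back to 0.0 when no "L2" entry is set.
    some_param = next(iter(net.parameters()))
    assert some_param.weight_decay == 1e-4

    some_param.weight_decay = 0.0  # removes the "L2" entry again

On the RETURNN config side, decouple_constraints then controls whether this L2 term becomes part of the main loss or is applied as a separate, decoupled constraint, per the issue comment linked in the docstring.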
