Backend paddle: add optimizers with regularizer support
lijialin03 committed Nov 28, 2024
1 parent 8275aeb commit ad52e77
Showing 2 changed files with 22 additions and 3 deletions.
7 changes: 6 additions & 1 deletion deepxde/model.py
@@ -506,8 +506,13 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
+        regularizer = getattr(self.net, "regularizer", None)
+        if regularizer is not None:
+            weight_decay = self.net.regularizer_value if self.opt_name == "adamw" else self.net.regularizer
+        else:
+            weight_decay = None
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables, self.opt_name, learning_rate=lr, decay=decay, weight_decay=weight_decay,
         )
 
         def train_step(inputs, targets, auxiliary_vars):
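For context, a minimal sketch (not part of the commit) of how Paddle itself consumes weight_decay, which is why the new branch treats "adamw" differently from the other optimizers: Adam, SGD, and RMSProp accept either a decay coefficient or a paddle.regularizer instance, whereas AdamW applies decoupled weight decay and takes a plain float coefficient. The network and coefficients below are placeholders.

    # Sketch only, assuming a standalone Paddle script; net and values are placeholders.
    import paddle

    net = paddle.nn.Linear(2, 1)

    # Adam (likewise SGD/RMSProp): weight_decay may be a coefficient or a
    # regularizer object such as paddle.regularizer.L2Decay, so the change can
    # forward self.net.regularizer unchanged for these optimizers.
    adam = paddle.optimizer.Adam(
        learning_rate=1e-3,
        parameters=net.parameters(),
        weight_decay=paddle.regularizer.L2Decay(1e-4),
    )

    # AdamW implements decoupled weight decay and expects a float coefficient,
    # which is why the change passes a different value (self.net.regularizer_value)
    # when self.opt_name == "adamw".
    adamw = paddle.optimizer.AdamW(
        learning_rate=1e-3,
        parameters=net.parameters(),
        weight_decay=1e-4,
    )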
18 changes: 16 additions & 2 deletions deepxde/optimizers/paddle/optimizers.py
@@ -19,12 +19,14 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]
 
 
-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer
 
     if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
         optim = paddle.optimizer.LBFGS(
@@ -46,5 +48,17 @@ def get(params, optimizer, learning_rate=None, decay=None):
         learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "sgd":
+        return paddle.optimizer.SGD(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay,
+        )
+    elif optimizer == "adamw":
+        if weight_decay[0] == 0:
+            raise ValueError("AdamW optimizer requires non-zero weight decay")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay[0],
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
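A hedged usage sketch of the patched helper follows (not part of the diff). The import path mirrors the file location above; it assumes the Paddle backend is selected (e.g. DDE_BACKEND=paddle) and that leaving decay at its default is accepted by the surrounding code, as in the other backends. "adam", "sgd", and "rmsprop" forward weight_decay to Paddle unchanged, "adamw" reads weight_decay[0] and requires it to be non-zero, and "L-BFGS" rejects any non-None weight_decay.

    # Usage sketch; the call patterns are assumptions based on the diff above,
    # not examples taken from the DeepXDE tests or docs.
    import paddle
    from deepxde.optimizers.paddle.optimizers import get

    net = paddle.nn.Linear(2, 1)
    params = list(net.parameters())

    # "adam"/"sgd"/"rmsprop": weight_decay is forwarded as-is, so either a float
    # coefficient or a paddle.regularizer instance works.
    opt_adam = get(
        params, "adam", learning_rate=1e-3, weight_decay=paddle.regularizer.L2Decay(1e-4)
    )

    # "adamw": the helper unpacks weight_decay[0], so pass a sequence whose first
    # element is a non-zero coefficient; a zero value raises ValueError.
    opt_adamw = get(params, "adamw", learning_rate=1e-3, weight_decay=[1e-4])

    # "L-BFGS": any non-None weight_decay raises ValueError.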
