Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend paddle: add optimizers with supporting regularizer #1896

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,11 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
list(self.net.parameters()) + self.external_trainable_variables
)
self.opt = optimizers.get(
trainable_variables, self.opt_name, learning_rate=lr, decay=decay
trainable_variables,
self.opt_name,
learning_rate=lr,
decay=decay,
weight_decay=self.net.regularizer,
)

def train_step(inputs, targets, auxiliary_vars):
Expand Down
1 change: 1 addition & 0 deletions deepxde/nn/paddle/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class NN(paddle.nn.Layer):

def __init__(self):
super().__init__()
self.regularizer = None
self._input_transform = None
self._output_transform = None

Expand Down
31 changes: 29 additions & 2 deletions deepxde/optimizers/paddle/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def is_external_optimizer(optimizer):
return optimizer in ["L-BFGS", "L-BFGS-B"]


def get(params, optimizer, learning_rate=None, decay=None):
def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
"""Retrieves an Optimizer instance."""
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer

if optimizer in ["L-BFGS", "L-BFGS-B"]:
if weight_decay is not None:
raise ValueError("L-BFGS optimizer doesn't support weight_decay")
if learning_rate is not None or decay is not None:
print("Warning: learning rate is ignored for {}".format(optimizer))
lijialin03 marked this conversation as resolved.
Show resolved Hide resolved
optim = paddle.optimizer.LBFGS(
Expand All @@ -46,5 +48,30 @@ def get(params, optimizer, learning_rate=None, decay=None):
learning_rate = _get_lr_scheduler(learning_rate, decay)

if optimizer == "adam":
return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
return paddle.optimizer.Adam(
learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
)
if optimizer == "sgd":
return paddle.optimizer.SGD(
learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
)
if optimizer == "rmsprop":
return paddle.optimizer.RMSProp(
learning_rate=learning_rate,
parameters=params,
weight_decay=weight_decay,
)
if optimizer == "adamw":
if (
not isinstance(weight_decay, paddle.regularizer.L2Decay)
or weight_decay._coeff == 0
):
raise ValueError(
lijialin03 marked this conversation as resolved.
Show resolved Hide resolved
"AdamW optimizer requires L2 regularizer and non-zero weight decay"
lijialin03 marked this conversation as resolved.
Show resolved Hide resolved
)
return paddle.optimizer.AdamW(
learning_rate=learning_rate,
parameters=params,
weight_decay=weight_decay._coeff,
)
raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
Loading