diff --git a/deepxde/model.py b/deepxde/model.py
index 1644c2529..70853009c 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -518,7 +518,11 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
             list(self.net.parameters()) + self.external_trainable_variables
         )
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=self.net.regularizer,
         )
 
         def train_step(inputs, targets, auxiliary_vars):
diff --git a/deepxde/nn/paddle/nn.py b/deepxde/nn/paddle/nn.py
index 6609027dc..74e6e5723 100644
--- a/deepxde/nn/paddle/nn.py
+++ b/deepxde/nn/paddle/nn.py
@@ -6,6 +6,7 @@ class NN(paddle.nn.Layer):
 
     def __init__(self):
         super().__init__()
+        self.regularizer = None
         self._input_transform = None
         self._output_transform = None
 
diff --git a/deepxde/optimizers/paddle/optimizers.py b/deepxde/optimizers/paddle/optimizers.py
index 9e843dab3..699afcc75 100644
--- a/deepxde/optimizers/paddle/optimizers.py
+++ b/deepxde/optimizers/paddle/optimizers.py
@@ -19,7 +19,7 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]
 
 
-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer
@@ -27,6 +27,8 @@ def get(params, optimizer, learning_rate=None, decay=None):
     if optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         optim = paddle.optimizer.LBFGS(
             learning_rate=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -46,5 +48,28 @@ def get(params, optimizer, learning_rate=None, decay=None):
         learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "sgd":
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
+        )
+    if optimizer == "adamw":
+        if (
+            not isinstance(weight_decay, paddle.regularizer.L2Decay)
+            or weight_decay._coeff == 0
+        ):
+            raise ValueError("AdamW optimizer requires non-zero L2 regularizer")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay._coeff,
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
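
A minimal usage sketch, not part of the patch, of how the extended get() signature could be exercised directly with the Paddle backend. The paddle.nn.Linear stand-in network, the 1e-4 coefficient, and the learning rate are illustrative placeholders, and the direct import assumes the Paddle backend is selected (e.g. DDE_BACKEND=paddle) so the module imports cleanly.

# Hypothetical sketch exercising the patched get() above; values are placeholders.
import paddle

from deepxde.optimizers.paddle.optimizers import get

# Stand-in for a DeepXDE network; its parameters play the role of the
# `trainable_variables` list that Model.compile passes to optimizers.get.
net = paddle.nn.Linear(2, 1)

# The NN base class now defaults `regularizer` to None; here we mimic a
# network that carries an L2 penalty with coefficient 1e-4.
l2 = paddle.regularizer.L2Decay(coeff=1e-4)

# "adam" forwards the regularizer as Paddle's weight_decay argument.
opt = get(net.parameters(), "adam", learning_rate=1e-3, weight_decay=l2)

# L-BFGS rejects any weight_decay, matching the new ValueError branch:
# get(net.parameters(), "L-BFGS", weight_decay=l2)  # would raise ValueError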