Refactor code using black
lijialin03 committed Dec 11, 2024
1 parent 95de72a commit 01429e3
Showing 2 changed files with 24 additions and 8 deletions.
16 changes: 12 additions & 4 deletions deepxde/model.py
@@ -517,13 +517,21 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
-        regularizer = getattr(self.net, 'regularizer', None)
+        regularizer = getattr(self.net, "regularizer", None)
         if regularizer is not None:
-            weight_decay = self.net.regularizer_value if self.opt_name == "adamw" else self.net.regularizer
+            weight_decay = (
+                self.net.regularizer_value
+                if self.opt_name == "adamw"
+                else self.net.regularizer
+            )
         else:
             weight_decay = None
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay, weight_decay=weight_decay,
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=weight_decay,
         )
 
         def train_step(inputs, targets, auxiliary_vars):
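Note: the block reformatted above decides what is forwarded to optimizers.get as weight_decay: the network's regularizer_value when the optimizer is AdamW, the regularizer itself for other optimizers, and None when the network defines no regularizer. A minimal standalone sketch of that branching (the _FakeNet class and its placeholder attribute values are invented for illustration and are not DeepXDE's real network API):

# Sketch of the weight-decay selection above; _FakeNet and its attribute
# values are hypothetical placeholders, not DeepXDE's real API.
class _FakeNet:
    regularizer = ("l2", 1e-4)   # placeholder regularizer config
    regularizer_value = [1e-4]   # placeholder, shaped so weight_decay[0] works


def select_weight_decay(net, opt_name):
    regularizer = getattr(net, "regularizer", None)
    if regularizer is None:
        # No regularization configured on the network -> no weight decay.
        return None
    # AdamW receives only the decay value; other optimizers receive the
    # regularizer object itself.
    return net.regularizer_value if opt_name == "adamw" else net.regularizer


print(select_weight_decay(_FakeNet(), "adamw"))  # [1e-4]
print(select_weight_decay(_FakeNet(), "adam"))   # ("l2", 1e-4)
print(select_weight_decay(object(), "adam"))     # None (no regularizer attribute)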
16 changes: 12 additions & 4 deletions deepxde/optimizers/paddle/optimizers.py
@@ -48,17 +48,25 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
     elif optimizer == "sgd":
-        return paddle.optimizer.SGD(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
     elif optimizer == "rmsprop":
         return paddle.optimizer.RMSProp(
-            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay,
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
         )
     elif optimizer == "adamw":
         if weight_decay[0] == 0:
             raise ValueError("AdamW optimizer requires non-zero weight decay")
         return paddle.optimizer.AdamW(
-            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay[0],
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay[0],
         )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
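For reference, the "adamw" branch hands Paddle a plain float, which is why it unpacks weight_decay[0] and rejects a zero value. A small sketch of the equivalent direct Paddle call (the toy Linear layer and the 1e-4 decay are invented for illustration; the sequence-shaped weight_decay in the commented get() call is an assumption inferred from the weight_decay[0] indexing above):

import paddle

# A toy layer standing in for a DeepXDE network (illustrative only).
net = paddle.nn.Linear(2, 1)

# What the "adamw" branch above effectively constructs: Paddle's AdamW
# accepts a float weight_decay, hence the weight_decay[0] unpacking and
# the guard against a zero value in the wrapper.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=net.parameters(),
    weight_decay=1e-4,
)

# Through the wrapper, the same optimizer would be reached with something
# like: get(net.parameters(), "adamw", learning_rate=1e-3, weight_decay=[1e-4])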
