Skip to content

Commit

Permalink
Add NNCG to optimizers submodule (#1661)
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikrathore8 authored Nov 26, 2024
1 parent bb1d3ac commit 8275aeb
Show file tree
Hide file tree
Showing 7 changed files with 428 additions and 7 deletions.
18 changes: 16 additions & 2 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,22 @@ def closure():
if self.lr_scheduler is not None:
self.lr_scheduler.step()

def train_step_nncg(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
return total_loss

self.opt.step(closure)
if self.lr_scheduler is not None:
self.lr_scheduler.step()

# Callables
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = train_step
self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg

def _compile_jax(self, lr, loss_fn, decay):
"""jax"""
Expand Down Expand Up @@ -652,7 +663,10 @@ def train(
elif backend_name == "tensorflow":
self._train_tensorflow_tfp(verbose=verbose)
elif backend_name == "pytorch":
self._train_pytorch_lbfgs(verbose=verbose)
if self.opt_name == "L-BFGS":
self._train_pytorch_lbfgs(verbose=verbose)
elif self.opt_name == "NNCG":
self._train_sgd(iterations, display_every, verbose=verbose)
elif backend_name == "paddle":
self._train_paddle_lbfgs(verbose=verbose)
else:
Expand Down
2 changes: 1 addition & 1 deletion deepxde/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
import sys

from .config import LBFGS_options, set_LBFGS_options
from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
from ..backend import backend_name


Expand Down
58 changes: 57 additions & 1 deletion deepxde/optimizers/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]

from ..backend import backend_name
from ..config import hvd

LBFGS_options = {}
NNCG_options = {}
if hvd is not None:
hvd_opt_options = {}

Expand Down Expand Up @@ -60,6 +61,60 @@ def set_LBFGS_options(
LBFGS_options["maxls"] = maxls


def set_NNCG_options(
lr=1,
rank=50,
mu=1e-1,
updatefreq=20,
chunksz=1,
cgtol=1e-16,
cgmaxiter=1000,
lsfun="armijo",
verbose=False,
):
"""Sets the hyperparameters of NysNewtonCG (NNCG).
The NNCG optimizer only supports PyTorch.
Args:
lr (float):
Learning rate (before line search).
rank (int):
Rank of preconditioner matrix used in preconditioned conjugate gradient.
mu (float):
Hessian damping parameter.
updatefreq (int):
How often the preconditioner matrix in preconditioned
conjugate gradient is updated. This parameter is not directly used in NNCG,
instead it is used in _train_pytorch_nncg in deepxde/model.py.
chunksz (int):
Number of Hessian-vector products to compute in parallel when constructing
preconditioner. If `chunk_size` is 1, the Hessian-vector products are
computed serially.
cgtol (float):
Convergence tolerance for the conjugate gradient method. The iteration stops
when `||r||_2 <= cgtol`, where `r` is the residual. Note that this condition
is based on the absolute tolerance, not the relative tolerance.
cgmaxiter (int):
Maximum number of iterations for the conjugate gradient method.
lsfun (str):
The line search function used to find the step size. The default value is
"armijo". The other option is None.
verbose (bool):
If `True`, prints the eigenvalues of the Nyström approximation
of the Hessian.
"""
NNCG_options["lr"] = lr
NNCG_options["rank"] = rank
NNCG_options["mu"] = mu
NNCG_options["updatefreq"] = updatefreq
NNCG_options["chunksz"] = chunksz
NNCG_options["cgtol"] = cgtol
NNCG_options["cgmaxiter"] = cgmaxiter
NNCG_options["lsfun"] = lsfun
NNCG_options["verbose"] = verbose


def set_hvd_opt_options(
compression=None,
op=None,
Expand Down Expand Up @@ -91,6 +146,7 @@ def set_hvd_opt_options(


set_LBFGS_options()
set_NNCG_options()
if hvd is not None:
set_hvd_opt_options()

Expand Down
Loading

0 comments on commit 8275aeb

Please sign in to comment.