Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NNCG to optimizers submodule #1661

Merged
merged 29 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d0e22bb
Add NNCG to optimizers submodule
pratikrathore8 Feb 22, 2024
3afed9f
Update nys_newton_cg.py
pratikrathore8 Feb 23, 2024
081d5f6
Moved NNCG to pytorch folder
pratikrathore8 Feb 27, 2024
03a77a1
Minor formatting changes in NNCG
pratikrathore8 Mar 1, 2024
88d2f7e
Update nys_newton_cg.py
pratikrathore8 Mar 4, 2024
fff6a91
Fix Codacy issues
pratikrathore8 Mar 4, 2024
19490ea
Fix more Codacy issues
pratikrathore8 Mar 4, 2024
ec59a99
Added NNCG to config.py and optimizers.py
pratikrathore8 Mar 11, 2024
8995aad
Clean up NNCG integration in optimizers.py
pratikrathore8 Mar 13, 2024
1b13a08
Fixed import order in optimizers.py
pratikrathore8 Mar 15, 2024
2d63ba3
Made demo with NNCG and Burgers equation
pratikrathore8 Apr 8, 2024
c356c90
Merge branch 'lululxvi:master' into master
pratikrathore8 Oct 28, 2024
8d0210a
refactor nncg integration
pratikrathore8 Oct 29, 2024
5c9beba
clean up commented code
pratikrathore8 Oct 29, 2024
ce1dcc8
format with black
pratikrathore8 Oct 29, 2024
d97ca16
remove unnecessary error checks
pratikrathore8 Oct 29, 2024
81c4452
fix some codacy issues in nncg
pratikrathore8 Oct 29, 2024
60d3ff3
further improvements to nncg integration
pratikrathore8 Oct 30, 2024
6e739e1
add train_step_nncg
pratikrathore8 Oct 31, 2024
f4e322e
improve documentation in nncg config
pratikrathore8 Nov 4, 2024
6338f59
added doc for nncg demo
pratikrathore8 Nov 5, 2024
838eaa8
added demo file to pinn_forward.rst
pratikrathore8 Nov 5, 2024
e0bb44d
Merge branch 'master' into master
pratikrathore8 Nov 7, 2024
61f08f9
Merge branch 'lululxvi:master' into master
pratikrathore8 Nov 13, 2024
7d58317
change ordering in pinn_forward.rst
pratikrathore8 Nov 13, 2024
5467524
merge burgers_nncg demo into burgers demo
pratikrathore8 Nov 18, 2024
f2d5c31
more cleanup in demo
pratikrathore8 Nov 22, 2024
74ace2c
comment out nncg code block
pratikrathore8 Nov 22, 2024
c1a6365
minor cleanup in burgers example
pratikrathore8 Nov 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,22 @@ def closure():
if self.lr_scheduler is not None:
self.lr_scheduler.step()

def train_step_nncg(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
return total_loss

self.opt.step(closure)
if self.lr_scheduler is not None:
self.lr_scheduler.step()

# Callables
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = train_step
self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg

def _compile_jax(self, lr, loss_fn, decay):
"""jax"""
Expand Down Expand Up @@ -648,7 +659,10 @@ def train(
elif backend_name == "tensorflow":
self._train_tensorflow_tfp()
elif backend_name == "pytorch":
self._train_pytorch_lbfgs()
if self.opt_name == "L-BFGS":
self._train_pytorch_lbfgs()
elif self.opt_name == "NNCG":
self._train_sgd(iterations, display_every)
elif backend_name == "paddle":
self._train_paddle_lbfgs()
else:
Expand Down
2 changes: 1 addition & 1 deletion deepxde/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
import sys

from .config import LBFGS_options, set_LBFGS_options
from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
from ..backend import backend_name


Expand Down
55 changes: 54 additions & 1 deletion deepxde/optimizers/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]

from ..backend import backend_name
from ..config import hvd

LBFGS_options = {}
NNCG_options = {}
if hvd is not None:
hvd_opt_options = {}

Expand Down Expand Up @@ -60,6 +61,57 @@ def set_LBFGS_options(
LBFGS_options["maxls"] = maxls


def set_NNCG_options(
lr=1,
rank=10,
mu=1e-4,
updatefreq=20,
chunksz=1,
cgtol=1e-16,
cgmaxiter=1000,
lsfun="armijo",
verbose=False,
):
"""Sets the hyperparameters of NysNewtonCG (NNCG).
pratikrathore8 marked this conversation as resolved.
Show resolved Hide resolved

Args:
lr (float): `lr` (torch).
pratikrathore8 marked this conversation as resolved.
Show resolved Hide resolved
Learning rate (before line search).
rank (int): `rank` (torch).
Rank of preconditioner matrix used in preconditioned conjugate gradient.
mu (float): `mu` (torch).
Hessian damping parameter.
updatefreq (int): How often the preconditioner matrix in preconditioned
conjugate gradient is updated. This parameter is not directly used in NNCG,
instead it is used in _train_pytorch_nncg in deepxde/model.py.
chunksz (int): `chunk_size` (torch).
Number of Hessian-vector products to compute in parallel when constructing
preconditioner. If `chunk_size` is 1, the Hessian-vector products are
computed serially.
cgtol (float): `cg_tol` (torch).
Convergence tolerance for the conjugate gradient method. The iteration stops
when `||r||_2 <= cgtol`, where `r` is the residual. Note that this condition
is based on the absolute tolerance, not the relative tolerance.
cgmaxiter (int): `cg_max_iters` (torch).
Maximum number of iterations for the conjugate gradient method.
lsfun (str): `line_search_fn` (torch).
The line search function used to find the step size. The default value is
"armijo". The other option is None.
verbose (bool): `verbose` (torch).
If `True`, prints the eigenvalues of the Nyström approximation
of the Hessian.
"""
NNCG_options["lr"] = lr
NNCG_options["rank"] = rank
NNCG_options["mu"] = mu
NNCG_options["updatefreq"] = updatefreq
NNCG_options["chunksz"] = chunksz
NNCG_options["cgtol"] = cgtol
NNCG_options["cgmaxiter"] = cgmaxiter
NNCG_options["lsfun"] = lsfun
NNCG_options["verbose"] = verbose


def set_hvd_opt_options(
compression=None,
op=None,
Expand Down Expand Up @@ -91,6 +143,7 @@ def set_hvd_opt_options(


set_LBFGS_options()
set_NNCG_options()
if hvd is not None:
set_hvd_opt_options()

Expand Down
Loading