diff --git a/deepxde/model.py b/deepxde/model.py
index 568a16d10..458dfe5cf 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -367,11 +367,22 @@ def closure():
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
 
+        def train_step_nncg(inputs, targets, auxiliary_vars):
+            def closure():
+                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                total_loss = torch.sum(losses)
+                self.opt.zero_grad()
+                return total_loss
+
+            self.opt.step(closure)
+            if self.lr_scheduler is not None:
+                self.lr_scheduler.step()
+
         # Callables
         self.outputs = outputs
         self.outputs_losses_train = outputs_losses_train
         self.outputs_losses_test = outputs_losses_test
-        self.train_step = train_step
+        self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg
 
     def _compile_jax(self, lr, loss_fn, decay):
         """jax"""
@@ -652,7 +663,10 @@ def train(
             elif backend_name == "tensorflow":
                 self._train_tensorflow_tfp(verbose=verbose)
             elif backend_name == "pytorch":
-                self._train_pytorch_lbfgs(verbose=verbose)
+                if self.opt_name == "L-BFGS":
+                    self._train_pytorch_lbfgs(verbose=verbose)
+                elif self.opt_name == "NNCG":
+                    self._train_sgd(iterations, display_every, verbose=verbose)
             elif backend_name == "paddle":
                 self._train_paddle_lbfgs(verbose=verbose)
             else:
diff --git a/deepxde/optimizers/__init__.py b/deepxde/optimizers/__init__.py
index e1fcfced1..556761a86 100644
--- a/deepxde/optimizers/__init__.py
+++ b/deepxde/optimizers/__init__.py
@@ -1,7 +1,7 @@
 import importlib
 import sys
 
-from .config import LBFGS_options, set_LBFGS_options
+from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
 from ..backend import backend_name
 
 
diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py
index 01ba8bd1f..41bb19b86 100644
--- a/deepxde/optimizers/config.py
+++ b/deepxde/optimizers/config.py
@@ -1,9 +1,10 @@
-__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
+__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]
 
 from ..backend import backend_name
 from ..config import hvd
 
 LBFGS_options = {}
+NNCG_options = {}
 if hvd is not None:
     hvd_opt_options = {}
 
@@ -60,6 +61,60 @@ def set_LBFGS_options(
     LBFGS_options["maxls"] = maxls
 
 
+def set_NNCG_options(
+    lr=1,
+    rank=50,
+    mu=1e-1,
+    updatefreq=20,
+    chunksz=1,
+    cgtol=1e-16,
+    cgmaxiter=1000,
+    lsfun="armijo",
+    verbose=False,
+):
+    """Sets the hyperparameters of NysNewtonCG (NNCG).
+
+    The NNCG optimizer is only available for the PyTorch backend.
+
+    Args:
+        lr (float):
+            Learning rate (before line search).
+        rank (int):
+            Rank of the preconditioner matrix used in preconditioned conjugate gradient.
+        mu (float):
+            Hessian damping parameter.
+        updatefreq (int):
+            How often the preconditioner matrix used in preconditioned
+            conjugate gradient is updated, i.e., the preconditioner is
+            recomputed every `updatefreq` optimizer steps.
+        chunksz (int):
+            Number of Hessian-vector products to compute in parallel when constructing
+            the preconditioner. If `chunksz` is 1, the Hessian-vector products are
+            computed serially.
+        cgtol (float):
+            Convergence tolerance for the conjugate gradient method. The iteration stops
+            when `||r||_2 <= cgtol`, where `r` is the residual. Note that this is an
+            absolute tolerance, not a relative one.
+        cgmaxiter (int):
+            Maximum number of iterations for the conjugate gradient method.
+        lsfun (str):
+            The line search function used to find the step size. The default value is
+            "armijo". The other option is None.
+        verbose (bool):
+            If `True`, prints the eigenvalues of the Nyström approximation
+            of the Hessian.
+    """
+    NNCG_options["lr"] = lr
+    NNCG_options["rank"] = rank
+    NNCG_options["mu"] = mu
+    NNCG_options["updatefreq"] = updatefreq
+    NNCG_options["chunksz"] = chunksz
+    NNCG_options["cgtol"] = cgtol
+    NNCG_options["cgmaxiter"] = cgmaxiter
+    NNCG_options["lsfun"] = lsfun
+    NNCG_options["verbose"] = verbose
+
+
 def set_hvd_opt_options(
     compression=None,
     op=None,
@@ -91,6 +146,7 @@ def set_hvd_opt_options(
 
 
 set_LBFGS_options()
+set_NNCG_options()
 if hvd is not None:
     set_hvd_opt_options()
 
diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py
new file mode 100644
index 000000000..93704ec9d
--- /dev/null
+++ b/deepxde/optimizers/pytorch/nncg.py
@@ -0,0 +1,317 @@
+from functools import reduce
+
+import torch
+from torch.func import vmap
+from torch.optim import Optimizer
+
+
+def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5):
+    """Line search to find a step size that satisfies the Armijo condition."""
+    f0 = f(x, 0, dx)
+    f1 = f(x, t, dx)
+    while f1 > f0 + alpha * t * gx.dot(dx):
+        t *= beta
+        f1 = f(x, t, dx)
+    return t
+
+
+def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x):
+    """Applies the inverse of the Nyström approximation of the Hessian to a vector."""
+    z = U.T @ x
+    z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z)
+    return z
+
+
+def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters):
+    """Solves a positive-definite linear system using Nyström PCG.
+
+    `Frangella et al. Randomized Nyström Preconditioning.
+    SIAM Journal on Matrix Analysis and Applications, 2023.
+    `
+    """
+    lambd_r = S[r - 1]
+    S_mu_inv = (S + mu) ** (-1)
+
+    resid = b - (hess(x) + mu * x)
+    with torch.no_grad():
+        z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
+        p = z.clone()
+
+    i = 0
+
+    while torch.norm(resid) > tol and i < max_iters:
+        v = hess(p) + mu * p
+        with torch.no_grad():
+            alpha = torch.dot(resid, z) / torch.dot(p, v)
+            x += alpha * p
+
+            rTz = torch.dot(resid, z)
+            resid -= alpha * v
+            z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
+            beta = torch.dot(resid, z) / rTz
+
+            p = z + beta * p
+
+            i += 1
+
+    if torch.norm(resid) > tol:
+        print(
+            "Warning: PCG did not converge to tolerance. "
+            f"Tolerance was {tol} but norm of residual is {torch.norm(resid)}"
+        )
+
+    return x
+
+
+class NNCG(Optimizer):
+    """Implementation of NysNewtonCG, a damped Newton-CG method
+    that uses Nyström preconditioning.
+
+    `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective.
+    Preprint, 2024. `
+
+    .. warning::
+        This optimizer doesn't support per-parameter options and parameter
+        groups (there can be only one).
+
+    NOTE: This optimizer is currently a beta version.
+
+    Our implementation is inspired by the PyTorch implementation of `L-BFGS
+    `.
+
+    The parameters rank and mu will probably need to be tuned for your specific problem.
+    If the optimizer is running very slowly, you can try one of the following:
+    - Increase the rank (this should increase the
+      accuracy of the Nyström approximation in PCG)
+    - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution)
+    - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations)
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1.0)
+        rank (int, optional): rank of the Nyström approximation (default: 10)
+        mu (float, optional): damping parameter (default: 1e-4)
+        update_freq (int, optional): frequency of updating the preconditioner (default: 20)
+        chunk_size (int, optional): number of Hessian-vector products
+            to be computed in parallel (default: 1)
+        cg_tol (float, optional): tolerance for PCG (default: 1e-16)
+        cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000)
+        line_search_fn (str, optional): either 'armijo' or None (default: None)
+        verbose (bool, optional): verbosity (default: False)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1.0,
+        rank=10,
+        mu=1e-4,
+        update_freq=20,
+        chunk_size=1,
+        cg_tol=1e-16,
+        cg_max_iters=1000,
+        line_search_fn=None,
+        verbose=False,
+    ):
+        defaults = {
+            "lr": lr,
+            "rank": rank,
+            "mu": mu,
+            "update_freq": update_freq,
+            "chunk_size": chunk_size,
+            "cg_tol": cg_tol,
+            "cg_max_iters": cg_max_iters,
+            "line_search_fn": line_search_fn,
+        }
+        self.rank = rank
+        self.mu = mu
+        self.update_freq = update_freq
+        self.chunk_size = chunk_size
+        self.cg_tol = cg_tol
+        self.cg_max_iters = cg_max_iters
+        self.line_search_fn = line_search_fn
+        self.verbose = verbose
+        self.U = None
+        self.S = None
+        self.n_iters = 0
+        super().__init__(params, defaults)
+
+        if len(self.param_groups) > 1:
+            raise ValueError(
+                "NNCG doesn't currently support "
+                "per-parameter options (parameter groups)"
+            )
+
+        if self.line_search_fn is not None and self.line_search_fn != "armijo":
+            raise ValueError("NNCG only supports Armijo line search")
+
+        self._params = self.param_groups[0]["params"]
+        self._params_list = list(self._params)
+        self._numel_cache = None
+
+    def step(self, closure):
+        """Perform a single optimization step.
+
+        Args:
+            closure (callable): A closure that reevaluates the model
+                and returns the loss w.r.t. the parameters.
+ """ + if self.n_iters == 0: + # Store the previous direction for warm starting PCG + self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) + + loss = closure() + # Compute gradient via torch.autograd.grad + g_tuple = torch.autograd.grad(loss, self._params_list, create_graph=True) + g = torch.cat([gi.view(-1) for gi in g_tuple if gi is not None]) + + if self.n_iters % self.update_freq == 0: + self._update_preconditioner(g) + + # One step update + for group_idx, group in enumerate(self.param_groups): + + def hvp_temp(x): + return self._hvp(g, self._params_list, x) + + # Calculate the Newton direction + d = _nystrom_pcg( + hvp_temp, + g, + self.old_dir, + self.mu, + self.U, + self.S, + self.rank, + self.cg_tol, + self.cg_max_iters, + ) + + # Store the previous direction for warm starting PCG + self.old_dir = d + + # Check if d is a descent direction + if torch.dot(d, g) <= 0: + print("Warning: d is not a descent direction") + + if self.line_search_fn == "armijo": + x_init = self._clone_param() + + def obj_func(x, t, dx): + self._add_grad(t, dx) + loss = float(closure()) + self._set_param(x) + return loss + + # Use -d for convention + t = _armijo(obj_func, x_init, g, -d, group["lr"]) + else: + t = group["lr"] + + self.state[group_idx]["t"] = t + + # update parameters + ls = 0 + for p in group["params"]: + np = torch.numel(p) + dp = d[ls : ls + np].view(p.shape) + ls += np + p.data.add_(-dp, alpha=t) + + self.n_iters += 1 + + return loss + + def _update_preconditioner(self, grad): + """Update the Nyström approximation of the Hessian. + + Args: + grad (torch.Tensor): gradient of the loss w.r.t. the parameters. + """ + # Generate test matrix (NOTE: This is transposed test matrix) + p = grad.shape[0] + Phi = torch.randn((self.rank, p), device=grad.device) / (p**0.5) + Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t() + + Y = self._hvp_vmap(grad, self._params_list)(Phi) + + # Calculate shift + shift = torch.finfo(Y.dtype).eps + Y_shifted = Y + shift * Phi + + # Calculate Phi^T * H * Phi (w/ shift) for Cholesky + choleskytarget = torch.mm(Y_shifted, Phi.t()) + + # Perform Cholesky, if fails, do eigendecomposition + # The new shift is the abs of smallest eigenvalue (negative) + # plus the original shift + try: + C = torch.linalg.cholesky(choleskytarget) + except torch.linalg.LinAlgError: + # eigendecomposition, eigenvalues and eigenvector matrix + eigs, eigvectors = torch.linalg.eigh(choleskytarget) + shift = shift + torch.abs(torch.min(eigs)) + # add shift to eigenvalues + eigs = eigs + shift + # put back the matrix for Cholesky by eigenvector * eigenvalues + # after shift * eigenvector^T + C = torch.linalg.cholesky( + torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)) + ) + + try: + B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True) + # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 + except RuntimeError: + B = torch.linalg.solve_triangular( + C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True + ).to(C.device) + + # B = V * S * U^T b/c we have been using transposed sketch + _, S, UT = torch.linalg.svd(B, full_matrices=False) + self.U = UT.t() + self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0)) + + self.rho = self.S[-1] + + if self.verbose: + print(f"Approximate eigenvalues = {self.S}") + + def _hvp_vmap(self, grad_params, params): + return vmap( + lambda v: self._hvp(grad_params, params, v), + in_dims=0, + chunk_size=self.chunk_size, + ) + + def _hvp(self, grad_params, params, v): + Hv = 
+        Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, retain_graph=True)
+        Hv = tuple(Hvi.detach() for Hvi in Hv)
+        return torch.cat([Hvi.reshape(-1) for Hvi in Hv])
+
+    def _numel(self):
+        if self._numel_cache is None:
+            self._numel_cache = reduce(
+                lambda total, p: total + p.numel(), self._params, 0
+            )
+        return self._numel_cache
+
+    def _add_grad(self, step_size, update):
+        offset = 0
+        for p in self._params:
+            numel = p.numel()
+            # Avoid in-place operation by creating a new tensor
+            p.data = p.data.add(
+                update[offset : offset + numel].view_as(p), alpha=step_size
+            )
+            offset += numel
+        assert offset == self._numel()
+
+    def _clone_param(self):
+        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
+
+    def _set_param(self, params_data):
+        for p, pdata in zip(self._params, params_data):
+            # Replace the .data attribute of the tensor
+            p.data = pdata.data
diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py
index 6329912dd..35ab88d24 100644
--- a/deepxde/optimizers/pytorch/optimizers.py
+++ b/deepxde/optimizers/pytorch/optimizers.py
@@ -2,11 +2,12 @@
 
 import torch
 
-from ..config import LBFGS_options
+from .nncg import NNCG
+from ..config import LBFGS_options, NNCG_options
 
 
 def is_external_optimizer(optimizer):
-    return optimizer in ["L-BFGS", "L-BFGS-B"]
+    return optimizer in ["L-BFGS", "L-BFGS-B", "NNCG"]
 
 
 def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0):
@@ -29,6 +30,23 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0):
             history_size=LBFGS_options["maxcor"],
             line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None),
         )
+    elif optimizer == "NNCG":
+        if weight_decay > 0:
+            raise ValueError("NNCG optimizer doesn't support weight_decay > 0")
+        if learning_rate is not None or decay is not None:
+            print("Warning: learning rate is ignored for {}".format(optimizer))
+        optim = NNCG(
+            params,
+            lr=NNCG_options["lr"],
+            rank=NNCG_options["rank"],
+            mu=NNCG_options["mu"],
+            update_freq=NNCG_options["updatefreq"],
+            chunk_size=NNCG_options["chunksz"],
+            cg_tol=NNCG_options["cgtol"],
+            cg_max_iters=NNCG_options["cgmaxiter"],
+            line_search_fn=NNCG_options["lsfun"],
+            verbose=NNCG_options["verbose"],
+        )
     else:
         if learning_rate is None:
             raise ValueError("No learning rate for {}.".format(optimizer))
diff --git a/docs/demos/pinn_forward/burgers.rst b/docs/demos/pinn_forward/burgers.rst
index 76ded3e31..08e56b38a 100644
--- a/docs/demos/pinn_forward/burgers.rst
+++ b/docs/demos/pinn_forward/burgers.rst
@@ -87,7 +87,17 @@ After we train the network using Adam, we continue to train the network using L-
 .. code-block:: python
 
     model.compile("L-BFGS-B")
-    losshistory, train_state = model.train()
+    losshistory, train_state = model.train()
+
+However, L-BFGS can stall out early in optimization if it is unable to find a step size satisfying the strong Wolfe conditions. In such cases, we can use the NNCG optimizer (available for the PyTorch backend only) to continue reducing the loss:
+
+.. code-block:: python
+
+    dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
+    model.compile("NNCG")
+    losshistory, train_state = model.train(iterations=1000, display_every=100)
+
+By default, this demo does not run NNCG; uncomment the NNCG code block in the example script to run it after Adam and L-BFGS. Note that the NNCG optimizer may require some hyperparameter tuning to reach its best performance.
 
 Complete code
 --------------
diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py
index 64ee46bb6..7b6883ed1 100644
--- a/examples/pinn_forward/Burgers.py
+++ b/examples/pinn_forward/Burgers.py
@@ -1,4 +1,5 @@
 """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
+
 import deepxde as dde
 import numpy as np
 
@@ -38,6 +39,11 @@ def pde(x, y):
 model.train(iterations=15000)
 model.compile("L-BFGS")
 losshistory, train_state = model.train()
+# """Backend supported: pytorch"""
+# # Run NNCG after Adam and L-BFGS
+# dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
+# model.compile("NNCG")
+# losshistory, train_state = model.train(iterations=1000, display_every=100)
 dde.saveplot(losshistory, train_state, issave=True, isplot=True)
 
 X, y_true = gen_testdata()
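A quick way to sanity-check the new optimizer outside of DeepXDE's `Model.compile`/`Model.train` workflow is to drive the `NNCG` class directly through its closure-based `step()` API. The sketch below is illustrative only and is not part of this patch: the toy least-squares problem and all hyperparameter values are invented for the example, and the import path simply mirrors the new file `deepxde/optimizers/pytorch/nncg.py`.

```python
# Illustrative sketch only (not part of this PR): drive the NNCG optimizer
# directly on a tiny least-squares problem. The import path mirrors the new
# file added in this diff; the problem and hyperparameters are arbitrary.
import torch

from deepxde.optimizers.pytorch.nncg import NNCG

torch.manual_seed(0)
A = torch.randn(100, 10)
b = torch.randn(100)
x = torch.zeros(10, requires_grad=True)

opt = NNCG(
    [x],
    lr=1.0,
    rank=5,  # rank of the Nyström approximation
    mu=1e-4,  # Hessian damping
    update_freq=10,  # rebuild the preconditioner every 10 steps
    line_search_fn="armijo",
)


def closure():
    # NNCG computes gradients itself with torch.autograd.grad(create_graph=True),
    # so the closure only needs to return a loss tensor that is still attached
    # to the autograd graph (no .backward(), no .detach()).
    return torch.sum((A @ x - b) ** 2)


for _ in range(20):
    loss = opt.step(closure)
print(float(loss))
```

Because `NNCG.step` recomputes gradients itself with `create_graph=True`, the closure must return a graph-connected loss tensor; this is the same contract used by `train_step_nncg` in `deepxde/model.py` above.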