From d0e22bbca72bd4f9243c4fd65d0e5f25a5cba55b Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:52:54 -0800 Subject: [PATCH 01/26] Add NNCG to optimizers submodule --- deepxde/optimizers/nys_newton_cg.py | 272 ++++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 deepxde/optimizers/nys_newton_cg.py diff --git a/deepxde/optimizers/nys_newton_cg.py b/deepxde/optimizers/nys_newton_cg.py new file mode 100644 index 000000000..feeacd4bd --- /dev/null +++ b/deepxde/optimizers/nys_newton_cg.py @@ -0,0 +1,272 @@ +import torch +from torch.optim import Optimizer +from torch.func import vmap +from functools import reduce + +def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5): + """Line search to find a step size that satisfies the Armijo condition.""" + f0 = f(x, 0, dx) + f1 = f(x, t, dx) + while f1 > f0 + alpha * t * gx.dot(dx): + t *= beta + f1 = f(x, t, dx) + return t + +def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x): + """Applies the inverse of the Nystrom approximation of the Hessian to a vector.""" + z = U.T @ x + z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z) + return z + +def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): + """Solves a positive-definite linear system using NyströmPCG. + + `Frangella et al. Randomized Nyström Preconditioning. + SIAM Journal on Matrix Analysis and Applications, 2023. + `""" + lambd_r = S[r - 1] + S_mu_inv = (S + mu) ** (-1) + + resid = b - (hess(x) + mu * x) + with torch.no_grad(): + z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid) + p = z.clone() + + i = 0 + + while torch.norm(resid) > tol and i < max_iters: + v = hess(p) + mu * p + with torch.no_grad(): + alpha = torch.dot(resid, z) / torch.dot(p, v) + x += alpha * p + + rTz = torch.dot(resid, z) + resid -= alpha * v + z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid) + beta = torch.dot(resid, z) / rTz + + p = z + beta * p + + i += 1 + + if torch.norm(resid) > tol: + print(f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}") + + return x + +class NysNewtonCG(Optimizer): + """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning. + + `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective. + Preprint, 2024. ` + + .. warning:: + This optimizer doesn't support per-parameter options and parameter + groups (there can be only one). + + NOTE: This optimizer is currently a beta version. + + Our implementation is inspired by the PyTorch implementation of `L-BFGS + `. + + The parameters rank and mu will probably need to be tuned for your specific problem. + If the optimizer is running very slowly, you can try one of the following: + - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG) + - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution) + - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations) + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1.0) + rank (int, optional): rank of the Nyström approximation (default: 10) + mu (float, optional): damping parameter (default: 1e-4) + chunk_size (int, optional): number of Hessian-vector products to be computed in parallel (default: 1) + cg_tol (float, optional): tolerance for PCG (default: 1e-16) + cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000) + line_search_fn (str, optional): either 'armijo' or None (default: None) + verbose (bool, optional): verbosity (default: False) + + """ + def __init__(self, params, lr=1.0, rank=10, mu=1e-4, chunk_size=1, + cg_tol=1e-16, cg_max_iters=1000, line_search_fn=None, verbose=False): + defaults = dict(lr=lr, rank=rank, chunk_size=chunk_size, mu=mu, cg_tol=cg_tol, + cg_max_iters=cg_max_iters, line_search_fn=line_search_fn) + self.rank = rank + self.mu = mu + self.chunk_size = chunk_size + self.cg_tol = cg_tol + self.cg_max_iters = cg_max_iters + self.line_search_fn = line_search_fn + self.verbose = verbose + self.U = None + self.S = None + self.n_iters = 0 + super(NysNewtonCG, self).__init__(params, defaults) + + if len(self.param_groups) > 1: + raise ValueError( + "NysNewtonCG doesn't currently support per-parameter options (parameter groups)") + + if self.line_search_fn is not None and self.line_search_fn != 'armijo': + raise ValueError("NysNewtonCG only supports Armijo line search") + + self._params = self.param_groups[0]['params'] + self._params_list = list(self._params) + self._numel_cache = None + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns (i) the loss and (ii) gradient w.r.t. the parameters. + The closure can compute the gradient w.r.t. the parameters by calling torch.autograd.grad on the loss with create_graph=True. + """ + if self.n_iters == 0: + # Store the previous direction for warm starting PCG + self.old_dir = torch.zeros( + self._numel(), device=self._params[0].device) + + # NOTE: The closure must return both the loss and the gradient + loss = None + if closure is not None: + with torch.enable_grad(): + loss, grad_tuple = closure() + + g = torch.cat([grad.view(-1) for grad in grad_tuple if grad is not None]) + + # One step update + for group_idx, group in enumerate(self.param_groups): + def hvp_temp(x): + return self._hvp(g, self._params_list, x) + + # Calculate the Newton direction + d = _nystrom_pcg(hvp_temp, g, self.old_dir, + self.mu, self.U, self.S, self.rank, self.cg_tol, self.cg_max_iters) + + # Store the previous direction for warm starting PCG + self.old_dir = d + + # Check if d is a descent direction + if torch.dot(d, g) <= 0: + print("Warning: d is not a descent direction") + + if self.line_search_fn == 'armijo': + x_init = self._clone_param() + + def obj_func(x, t, dx): + self._add_grad(t, dx) + loss = float(closure()[0]) + self._set_param(x) + return loss + + # Use -d for convention + t = _armijo(obj_func, x_init, g, -d, group['lr']) + else: + t = group['lr'] + + self.state[group_idx]['t'] = t + + # update parameters + ls = 0 + for p in group['params']: + np = torch.numel(p) + dp = d[ls:ls+np].view(p.shape) + ls += np + p.data.add_(-dp, alpha=t) + + self.n_iters += 1 + + return loss, g + + def update_preconditioner(self, grad_tuple): + """Update the Nystrom approximation of the Hessian. + + Args: + grad_tuple (tuple): tuple of Tensors containing the gradients of the loss w.r.t. the parameters. + This tuple can be obtained by calling torch.autograd.grad on the loss with create_graph=True. + """ + + # Flatten and concatenate the gradients + gradsH = torch.cat([gradient.view(-1) + for gradient in grad_tuple if gradient is not None]) + + # Generate test matrix (NOTE: This is transposed test matrix) + p = gradsH.shape[0] + Phi = torch.randn( + (self.rank, p), device=gradsH.device) / (p ** 0.5) + Phi = torch.linalg.qr(Phi.t(), mode='reduced')[0].t() + + Y = self._hvp_vmap(gradsH, self._params_list)(Phi) + + # Calculate shift + shift = torch.finfo(Y.dtype).eps + Y_shifted = Y + shift * Phi + + # Calculate Phi^T * H * Phi (w/ shift) for Cholesky + choleskytarget = torch.mm(Y_shifted, Phi.t()) + + # Perform Cholesky, if fails, do eigendecomposition + # The new shift is the abs of smallest eigenvalue (negative) plus the original shift + try: + C = torch.linalg.cholesky(choleskytarget) + except: + # eigendecomposition, eigenvalues and eigenvector matrix + eigs, eigvectors = torch.linalg.eigh(choleskytarget) + shift = shift + torch.abs(torch.min(eigs)) + # add shift to eigenvalues + eigs = eigs + shift + # put back the matrix for Cholesky by eigenvector * eigenvalues after shift * eigenvector^T + C = torch.linalg.cholesky( + torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T))) + + try: + B = torch.linalg.solve_triangular( + C, Y_shifted, upper=False, left=True) + # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 + except: + B = torch.linalg.solve_triangular(C.to('cpu'), Y_shifted.to( + 'cpu'), upper=False, left=True).to(C.device) + + # B = V * S * U^T b/c we have been using transposed sketch + _, S, UT = torch.linalg.svd(B, full_matrices=False) + self.U = UT.t() + self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0)) + + self.rho = self.S[-1] + + if self.verbose: + print(f'Approximate eigenvalues = {self.S}') + + def _hvp_vmap(self, grad_params, params): + return vmap(lambda v: self._hvp(grad_params, params, v), in_dims=0, chunk_size=self.chunk_size) + + def _hvp(self, grad_params, params, v): + Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, + retain_graph=True) + Hv = tuple(Hvi.detach() for Hvi in Hv) + return torch.cat([Hvi.reshape(-1) for Hvi in Hv]) + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = reduce( + lambda total, p: total + p.numel(), self._params, 0) + return self._numel_cache + + def _add_grad(self, step_size, update): + offset = 0 + for p in self._params: + numel = p.numel() + # Avoid in-place operation by creating a new tensor + p.data = p.data.add( + update[offset:offset + numel].view_as(p), alpha=step_size) + offset += numel + assert offset == self._numel() + + def _clone_param(self): + return [p.clone(memory_format=torch.contiguous_format) for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + # Replace the .data attribute of the tensor + p.data = pdata.data \ No newline at end of file From 3afed9f377dd3dbcebfc91a07139ab9466d25e9b Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:39:22 -0800 Subject: [PATCH 02/26] Update nys_newton_cg.py --- deepxde/optimizers/nys_newton_cg.py | 126 ++++++++++++++++++---------- 1 file changed, 83 insertions(+), 43 deletions(-) diff --git a/deepxde/optimizers/nys_newton_cg.py b/deepxde/optimizers/nys_newton_cg.py index feeacd4bd..885f1789a 100644 --- a/deepxde/optimizers/nys_newton_cg.py +++ b/deepxde/optimizers/nys_newton_cg.py @@ -3,6 +3,7 @@ from torch.func import vmap from functools import reduce + def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5): """Line search to find a step size that satisfies the Armijo condition.""" f0 = f(x, 0, dx) @@ -12,16 +13,18 @@ def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5): f1 = f(x, t, dx) return t + def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x): """Applies the inverse of the Nystrom approximation of the Hessian to a vector.""" z = U.T @ x z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z) return z + def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): """Solves a positive-definite linear system using NyströmPCG. - `Frangella et al. Randomized Nyström Preconditioning. + `Frangella et al. Randomized Nyström Preconditioning. SIAM Journal on Matrix Analysis and Applications, 2023. `""" lambd_r = S[r - 1] @@ -50,13 +53,16 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): i += 1 if torch.norm(resid) > tol: - print(f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}") + print( + f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}" + ) return x + class NysNewtonCG(Optimizer): """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning. - + `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective. Preprint, 2024. ` @@ -64,11 +70,11 @@ class NysNewtonCG(Optimizer): This optimizer doesn't support per-parameter options and parameter groups (there can be only one). - NOTE: This optimizer is currently a beta version. + NOTE: This optimizer is currently a beta version. - Our implementation is inspired by the PyTorch implementation of `L-BFGS + Our implementation is inspired by the PyTorch implementation of `L-BFGS `. - + The parameters rank and mu will probably need to be tuned for your specific problem. If the optimizer is running very slowly, you can try one of the following: - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG) @@ -86,12 +92,30 @@ class NysNewtonCG(Optimizer): cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000) line_search_fn (str, optional): either 'armijo' or None (default: None) verbose (bool, optional): verbosity (default: False) - + """ - def __init__(self, params, lr=1.0, rank=10, mu=1e-4, chunk_size=1, - cg_tol=1e-16, cg_max_iters=1000, line_search_fn=None, verbose=False): - defaults = dict(lr=lr, rank=rank, chunk_size=chunk_size, mu=mu, cg_tol=cg_tol, - cg_max_iters=cg_max_iters, line_search_fn=line_search_fn) + + def __init__( + self, + params, + lr=1.0, + rank=10, + mu=1e-4, + chunk_size=1, + cg_tol=1e-16, + cg_max_iters=1000, + line_search_fn=None, + verbose=False, + ): + defaults = dict( + lr=lr, + rank=rank, + chunk_size=chunk_size, + mu=mu, + cg_tol=cg_tol, + cg_max_iters=cg_max_iters, + line_search_fn=line_search_fn, + ) self.rank = rank self.mu = mu self.chunk_size = chunk_size @@ -106,12 +130,13 @@ def __init__(self, params, lr=1.0, rank=10, mu=1e-4, chunk_size=1, if len(self.param_groups) > 1: raise ValueError( - "NysNewtonCG doesn't currently support per-parameter options (parameter groups)") + "NysNewtonCG doesn't currently support per-parameter options (parameter groups)" + ) - if self.line_search_fn is not None and self.line_search_fn != 'armijo': + if self.line_search_fn is not None and self.line_search_fn != "armijo": raise ValueError("NysNewtonCG only supports Armijo line search") - self._params = self.param_groups[0]['params'] + self._params = self.param_groups[0]["params"] self._params_list = list(self._params) self._numel_cache = None @@ -124,8 +149,7 @@ def step(self, closure=None): """ if self.n_iters == 0: # Store the previous direction for warm starting PCG - self.old_dir = torch.zeros( - self._numel(), device=self._params[0].device) + self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) # NOTE: The closure must return both the loss and the gradient loss = None @@ -137,12 +161,22 @@ def step(self, closure=None): # One step update for group_idx, group in enumerate(self.param_groups): + def hvp_temp(x): return self._hvp(g, self._params_list, x) # Calculate the Newton direction - d = _nystrom_pcg(hvp_temp, g, self.old_dir, - self.mu, self.U, self.S, self.rank, self.cg_tol, self.cg_max_iters) + d = _nystrom_pcg( + hvp_temp, + g, + self.old_dir, + self.mu, + self.U, + self.S, + self.rank, + self.cg_tol, + self.cg_max_iters, + ) # Store the previous direction for warm starting PCG self.old_dir = d @@ -151,7 +185,7 @@ def hvp_temp(x): if torch.dot(d, g) <= 0: print("Warning: d is not a descent direction") - if self.line_search_fn == 'armijo': + if self.line_search_fn == "armijo": x_init = self._clone_param() def obj_func(x, t, dx): @@ -161,17 +195,17 @@ def obj_func(x, t, dx): return loss # Use -d for convention - t = _armijo(obj_func, x_init, g, -d, group['lr']) + t = _armijo(obj_func, x_init, g, -d, group["lr"]) else: - t = group['lr'] + t = group["lr"] - self.state[group_idx]['t'] = t + self.state[group_idx]["t"] = t # update parameters ls = 0 - for p in group['params']: + for p in group["params"]: np = torch.numel(p) - dp = d[ls:ls+np].view(p.shape) + dp = d[ls : ls + np].view(p.shape) ls += np p.data.add_(-dp, alpha=t) @@ -183,19 +217,19 @@ def update_preconditioner(self, grad_tuple): """Update the Nystrom approximation of the Hessian. Args: - grad_tuple (tuple): tuple of Tensors containing the gradients of the loss w.r.t. the parameters. + grad_tuple (tuple): tuple of Tensors containing the gradients of the loss w.r.t. the parameters. This tuple can be obtained by calling torch.autograd.grad on the loss with create_graph=True. """ # Flatten and concatenate the gradients - gradsH = torch.cat([gradient.view(-1) - for gradient in grad_tuple if gradient is not None]) + gradsH = torch.cat( + [gradient.view(-1) for gradient in grad_tuple if gradient is not None] + ) # Generate test matrix (NOTE: This is transposed test matrix) p = gradsH.shape[0] - Phi = torch.randn( - (self.rank, p), device=gradsH.device) / (p ** 0.5) - Phi = torch.linalg.qr(Phi.t(), mode='reduced')[0].t() + Phi = torch.randn((self.rank, p), device=gradsH.device) / (p**0.5) + Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t() Y = self._hvp_vmap(gradsH, self._params_list)(Phi) @@ -218,16 +252,17 @@ def update_preconditioner(self, grad_tuple): eigs = eigs + shift # put back the matrix for Cholesky by eigenvector * eigenvalues after shift * eigenvector^T C = torch.linalg.cholesky( - torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T))) + torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)) + ) try: - B = torch.linalg.solve_triangular( - C, Y_shifted, upper=False, left=True) + B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True) # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 except: - B = torch.linalg.solve_triangular(C.to('cpu'), Y_shifted.to( - 'cpu'), upper=False, left=True).to(C.device) - + B = torch.linalg.solve_triangular( + C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True + ).to(C.device) + # B = V * S * U^T b/c we have been using transposed sketch _, S, UT = torch.linalg.svd(B, full_matrices=False) self.U = UT.t() @@ -236,21 +271,25 @@ def update_preconditioner(self, grad_tuple): self.rho = self.S[-1] if self.verbose: - print(f'Approximate eigenvalues = {self.S}') + print(f"Approximate eigenvalues = {self.S}") def _hvp_vmap(self, grad_params, params): - return vmap(lambda v: self._hvp(grad_params, params, v), in_dims=0, chunk_size=self.chunk_size) + return vmap( + lambda v: self._hvp(grad_params, params, v), + in_dims=0, + chunk_size=self.chunk_size, + ) def _hvp(self, grad_params, params, v): - Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, - retain_graph=True) + Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, retain_graph=True) Hv = tuple(Hvi.detach() for Hvi in Hv) return torch.cat([Hvi.reshape(-1) for Hvi in Hv]) def _numel(self): if self._numel_cache is None: self._numel_cache = reduce( - lambda total, p: total + p.numel(), self._params, 0) + lambda total, p: total + p.numel(), self._params, 0 + ) return self._numel_cache def _add_grad(self, step_size, update): @@ -259,7 +298,8 @@ def _add_grad(self, step_size, update): numel = p.numel() # Avoid in-place operation by creating a new tensor p.data = p.data.add( - update[offset:offset + numel].view_as(p), alpha=step_size) + update[offset : offset + numel].view_as(p), alpha=step_size + ) offset += numel assert offset == self._numel() @@ -269,4 +309,4 @@ def _clone_param(self): def _set_param(self, params_data): for p, pdata in zip(self._params, params_data): # Replace the .data attribute of the tensor - p.data = pdata.data \ No newline at end of file + p.data = pdata.data From 081d5f6946e85e2423f0da01c5fa0096de2251d2 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:39:54 -0800 Subject: [PATCH 03/26] Moved NNCG to pytorch folder --- deepxde/optimizers/{ => pytorch}/nys_newton_cg.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename deepxde/optimizers/{ => pytorch}/nys_newton_cg.py (100%) diff --git a/deepxde/optimizers/nys_newton_cg.py b/deepxde/optimizers/pytorch/nys_newton_cg.py similarity index 100% rename from deepxde/optimizers/nys_newton_cg.py rename to deepxde/optimizers/pytorch/nys_newton_cg.py From 03a77a14699a1e29936859db09dfb0be58ac6751 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Fri, 1 Mar 2024 12:54:58 -0800 Subject: [PATCH 04/26] Minor formatting changes in NNCG --- deepxde/optimizers/pytorch/nys_newton_cg.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/deepxde/optimizers/pytorch/nys_newton_cg.py b/deepxde/optimizers/pytorch/nys_newton_cg.py index 885f1789a..20e5becd9 100644 --- a/deepxde/optimizers/pytorch/nys_newton_cg.py +++ b/deepxde/optimizers/pytorch/nys_newton_cg.py @@ -1,7 +1,8 @@ +from functools import reduce + import torch -from torch.optim import Optimizer from torch.func import vmap -from functools import reduce +from torch.optim import Optimizer def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5): @@ -26,7 +27,8 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): `Frangella et al. Randomized Nyström Preconditioning. SIAM Journal on Matrix Analysis and Applications, 2023. - `""" + ` + """ lambd_r = S[r - 1] S_mu_inv = (S + mu) ** (-1) @@ -92,7 +94,6 @@ class NysNewtonCG(Optimizer): cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000) line_search_fn (str, optional): either 'armijo' or None (default: None) verbose (bool, optional): verbosity (default: False) - """ def __init__( @@ -144,8 +145,10 @@ def step(self, closure=None): """Perform a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model and returns (i) the loss and (ii) gradient w.r.t. the parameters. - The closure can compute the gradient w.r.t. the parameters by calling torch.autograd.grad on the loss with create_graph=True. + closure (callable, optional): A closure that reevaluates the model + and returns (i) the loss and (ii) gradient w.r.t. the parameters. + The closure can compute the gradient w.r.t. the parameters by + calling torch.autograd.grad on the loss with create_graph=True. """ if self.n_iters == 0: # Store the previous direction for warm starting PCG @@ -217,8 +220,10 @@ def update_preconditioner(self, grad_tuple): """Update the Nystrom approximation of the Hessian. Args: - grad_tuple (tuple): tuple of Tensors containing the gradients of the loss w.r.t. the parameters. - This tuple can be obtained by calling torch.autograd.grad on the loss with create_graph=True. + grad_tuple (tuple): tuple of Tensors containing the gradients + of the loss w.r.t. the parameters. + This tuple can be obtained by calling torch.autograd.grad + on the loss with create_graph=True. """ # Flatten and concatenate the gradients From 88d2f7eb5e26026e9c3fcc76b7aabbd50f7cd66e Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:06:03 -0800 Subject: [PATCH 05/26] Update nys_newton_cg.py --- deepxde/optimizers/pytorch/nys_newton_cg.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/deepxde/optimizers/pytorch/nys_newton_cg.py b/deepxde/optimizers/pytorch/nys_newton_cg.py index 20e5becd9..13d82cf7e 100644 --- a/deepxde/optimizers/pytorch/nys_newton_cg.py +++ b/deepxde/optimizers/pytorch/nys_newton_cg.py @@ -56,14 +56,16 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): if torch.norm(resid) > tol: print( - f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}" + f"Warning: PCG did not converge to tolerance. + Tolerance was {tol} but norm of residual is {torch.norm(resid)}" ) return x class NysNewtonCG(Optimizer): - """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning. + """Implementation of NysNewtonCG, a damped Newton-CG method + that uses Nyström preconditioning. `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective. Preprint, 2024. ` @@ -79,7 +81,8 @@ class NysNewtonCG(Optimizer): The parameters rank and mu will probably need to be tuned for your specific problem. If the optimizer is running very slowly, you can try one of the following: - - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG) + - Increase the rank (this should increase the + accuracy of the Nyström approximation in PCG) - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution) - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations) @@ -89,7 +92,8 @@ class NysNewtonCG(Optimizer): lr (float, optional): learning rate (default: 1.0) rank (int, optional): rank of the Nyström approximation (default: 10) mu (float, optional): damping parameter (default: 1e-4) - chunk_size (int, optional): number of Hessian-vector products to be computed in parallel (default: 1) + chunk_size (int, optional): number of Hessian-vector products + to be computed in parallel (default: 1) cg_tol (float, optional): tolerance for PCG (default: 1e-16) cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000) line_search_fn (str, optional): either 'armijo' or None (default: None) @@ -131,7 +135,8 @@ def __init__( if len(self.param_groups) > 1: raise ValueError( - "NysNewtonCG doesn't currently support per-parameter options (parameter groups)" + f"NysNewtonCG doesn't currently support + per-parameter options (parameter groups)" ) if self.line_search_fn is not None and self.line_search_fn != "armijo": @@ -246,7 +251,8 @@ def update_preconditioner(self, grad_tuple): choleskytarget = torch.mm(Y_shifted, Phi.t()) # Perform Cholesky, if fails, do eigendecomposition - # The new shift is the abs of smallest eigenvalue (negative) plus the original shift + # The new shift is the abs of smallest eigenvalue (negative) + # plus the original shift try: C = torch.linalg.cholesky(choleskytarget) except: @@ -255,7 +261,8 @@ def update_preconditioner(self, grad_tuple): shift = shift + torch.abs(torch.min(eigs)) # add shift to eigenvalues eigs = eigs + shift - # put back the matrix for Cholesky by eigenvector * eigenvalues after shift * eigenvector^T + # put back the matrix for Cholesky by eigenvector * eigenvalues + # after shift * eigenvector^T C = torch.linalg.cholesky( torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)) ) From fff6a91726ae4709f57d4a6180c5f35e76b9ea89 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:30:24 -0800 Subject: [PATCH 06/26] Fix Codacy issues --- deepxde/optimizers/pytorch/nys_newton_cg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepxde/optimizers/pytorch/nys_newton_cg.py b/deepxde/optimizers/pytorch/nys_newton_cg.py index 13d82cf7e..9b2d2d903 100644 --- a/deepxde/optimizers/pytorch/nys_newton_cg.py +++ b/deepxde/optimizers/pytorch/nys_newton_cg.py @@ -56,8 +56,8 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): if torch.norm(resid) > tol: print( - f"Warning: PCG did not converge to tolerance. - Tolerance was {tol} but norm of residual is {torch.norm(resid)}" + "Warning: PCG did not converge to tolerance. " + "Tolerance was {tol} but norm of residual is {torch.norm(resid)}" ) return x @@ -255,7 +255,7 @@ def update_preconditioner(self, grad_tuple): # plus the original shift try: C = torch.linalg.cholesky(choleskytarget) - except: + except torch.linalg.LinAlgError: # eigendecomposition, eigenvalues and eigenvector matrix eigs, eigvectors = torch.linalg.eigh(choleskytarget) shift = shift + torch.abs(torch.min(eigs)) From 19490eab2a3c435e2da57c3042ff405946bae5a2 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:33:39 -0800 Subject: [PATCH 07/26] Fix more Codacy issues --- deepxde/optimizers/pytorch/nys_newton_cg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepxde/optimizers/pytorch/nys_newton_cg.py b/deepxde/optimizers/pytorch/nys_newton_cg.py index 9b2d2d903..1f2e2623e 100644 --- a/deepxde/optimizers/pytorch/nys_newton_cg.py +++ b/deepxde/optimizers/pytorch/nys_newton_cg.py @@ -135,8 +135,8 @@ def __init__( if len(self.param_groups) > 1: raise ValueError( - f"NysNewtonCG doesn't currently support - per-parameter options (parameter groups)" + "NysNewtonCG doesn't currently support " + "per-parameter options (parameter groups)" ) if self.line_search_fn is not None and self.line_search_fn != "armijo": @@ -270,7 +270,7 @@ def update_preconditioner(self, grad_tuple): try: B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True) # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 - except: + except RuntimeError: B = torch.linalg.solve_triangular( C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True ).to(C.device) From ec59a99749a93fbcc7580235df63dd02360008e1 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:44:29 -0700 Subject: [PATCH 08/26] Added NNCG to config.py and optimizers.py --- deepxde/optimizers/config.py | 48 ++++++++++++++++++- .../pytorch/{nys_newton_cg.py => nncg.py} | 8 ++-- deepxde/optimizers/pytorch/optimizers.py | 44 +++++++++++------ 3 files changed, 81 insertions(+), 19 deletions(-) rename deepxde/optimizers/pytorch/{nys_newton_cg.py => nncg.py} (98%) diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 01ba8bd1f..2a1504d7f 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -1,9 +1,10 @@ -__all__ = ["set_LBFGS_options", "set_hvd_opt_options"] +__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"] from ..backend import backend_name from ..config import hvd LBFGS_options = {} +NNCG_options = {} if hvd is not None: hvd_opt_options = {} @@ -59,6 +60,50 @@ def set_LBFGS_options( LBFGS_options["maxfun"] = maxfun if maxfun is not None else int(maxiter * 1.25) LBFGS_options["maxls"] = maxls +def set_NNCG_options( + lr=1, + rank=10, + mu=1e-4, + chunksz=1, + cgtol=1e-16, + cgmaxiter=1000, + lsfun="armijo", + verbose=False +): + """Sets the hyperparameters of NysNewtonCG (NNCG). + + Args: + lr (float): `lr` (torch). + Learning rate (before line search). + rank (int): `rank` (torch). + Rank of preconditioner matrix used in preconditioned conjugate gradient. + mu (float): `mu` (torch). + Hessian damping parameter. + chunksz (int): `chunk_size` (torch). + Number of Hessian-vector products to compute in parallel when constructing + preconditioner. If `chunk_size` is 1, the Hessian-vector products are + computed serially. + cgtol (float): `cg_tol` (torch). + Convergence tolerance for the conjugate gradient method. The iteration stops + when `||r||_2 <= cgtol`, where `r` is the residual. Note that this condition + is based on the absolute tolerance, not the relative tolerance. + cgmaxiter (int): `cg_max_iters` (torch). + Maximum number of iterations for the conjugate gradient method. + lsfun (str): `line_search_fn` (torch). + The line search function used to find the step size. The default value is + "armijo". The other option is None. + verbose (bool): `verbose` (torch). + If `True`, prints the eigenvalues of the Nyström approximation + of the Hessian. + """ + NNCG_options["lr"] = lr + NNCG_options["rank"] = rank + NNCG_options["mu"] = mu + NNCG_options["chunksz"] = chunksz + NNCG_options["cgtol"] = cgtol + NNCG_options["cgmaxiter"] = cgmaxiter + NNCG_options["lsfun"] = lsfun + NNCG_options["verbose"] = verbose def set_hvd_opt_options( compression=None, @@ -91,6 +136,7 @@ def set_hvd_opt_options( set_LBFGS_options() +set_NNCG_options() if hvd is not None: set_hvd_opt_options() diff --git a/deepxde/optimizers/pytorch/nys_newton_cg.py b/deepxde/optimizers/pytorch/nncg.py similarity index 98% rename from deepxde/optimizers/pytorch/nys_newton_cg.py rename to deepxde/optimizers/pytorch/nncg.py index 1f2e2623e..a71e7da31 100644 --- a/deepxde/optimizers/pytorch/nys_newton_cg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -63,7 +63,7 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): return x -class NysNewtonCG(Optimizer): +class NNCG(Optimizer): """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning. @@ -131,16 +131,16 @@ def __init__( self.U = None self.S = None self.n_iters = 0 - super(NysNewtonCG, self).__init__(params, defaults) + super(NNCG, self).__init__(params, defaults) if len(self.param_groups) > 1: raise ValueError( - "NysNewtonCG doesn't currently support " + "NNCG doesn't currently support " "per-parameter options (parameter groups)" ) if self.line_search_fn is not None and self.line_search_fn != "armijo": - raise ValueError("NysNewtonCG only supports Armijo line search") + raise ValueError("NNCG only supports Armijo line search") self._params = self.param_groups[0]["params"] self._params_list = list(self._params) diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py index 6329912dd..ded1411be 100644 --- a/deepxde/optimizers/pytorch/optimizers.py +++ b/deepxde/optimizers/pytorch/optimizers.py @@ -2,11 +2,12 @@ import torch -from ..config import LBFGS_options +from ..config import LBFGS_options, NNCG_options +from .nncg import NNCG def is_external_optimizer(optimizer): - return optimizer in ["L-BFGS", "L-BFGS-B"] + return optimizer in ["L-BFGS", "L-BFGS-B", "NNCG"] def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): @@ -14,21 +15,36 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): # Custom Optimizer if isinstance(optimizer, torch.optim.Optimizer): optim = optimizer - elif optimizer in ["L-BFGS", "L-BFGS-B"]: + elif optimizer in ["L-BFGS", "L-BFGS-B", "NNCG"]: if weight_decay > 0: - raise ValueError("L-BFGS optimizer doesn't support weight_decay > 0") + error_optim = "L-BFGS" if optimizer in ["L-BFGS", "L-BFGS-B"] else "NNCG" + raise ValueError(f"{error_optim} optimizer doesn't support + weight_decay > 0") if learning_rate is not None or decay is not None: print("Warning: learning rate is ignored for {}".format(optimizer)) - optim = torch.optim.LBFGS( - params, - lr=1, - max_iter=LBFGS_options["iter_per_step"], - max_eval=LBFGS_options["fun_per_step"], - tolerance_grad=LBFGS_options["gtol"], - tolerance_change=LBFGS_options["ftol"], - history_size=LBFGS_options["maxcor"], - line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None), - ) + if optimizer in ["L-BFGS", "L-BFGS-B"]: + optim = torch.optim.LBFGS( + params, + lr=1, + max_iter=LBFGS_options["iter_per_step"], + max_eval=LBFGS_options["fun_per_step"], + tolerance_grad=LBFGS_options["gtol"], + tolerance_change=LBFGS_options["ftol"], + history_size=LBFGS_options["maxcor"], + line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None), + ) + else: + optim = NNCG( + params, + lr=NNCG_options["lr"], + rank=NNCG_options["rank"], + mu=NNCG_options["mu"], + chunk_size=NNCG_options["chunksz"], + cg_tol=NNCG_options["cgtol"], + cg_max_iters=NNCG_options["cgmaxiter"], + line_search_fn=NNCG_options["lsfun"], + verbose=NNCG_options["verbose"], + ) else: if learning_rate is None: raise ValueError("No learning rate for {}.".format(optimizer)) From 8995aad87f66cd0358f513680f0c5d222527c7ce Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:40:22 -0700 Subject: [PATCH 09/26] Clean up NNCG integration in optimizers.py --- deepxde/optimizers/pytorch/optimizers.py | 55 ++++++++++++------------ 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py index ded1411be..777662578 100644 --- a/deepxde/optimizers/pytorch/optimizers.py +++ b/deepxde/optimizers/pytorch/optimizers.py @@ -15,36 +15,37 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): # Custom Optimizer if isinstance(optimizer, torch.optim.Optimizer): optim = optimizer - elif optimizer in ["L-BFGS", "L-BFGS-B", "NNCG"]: + elif optimizer in ["L-BFGS", "L-BFGS-B"]: if weight_decay > 0: - error_optim = "L-BFGS" if optimizer in ["L-BFGS", "L-BFGS-B"] else "NNCG" - raise ValueError(f"{error_optim} optimizer doesn't support - weight_decay > 0") + raise ValueError("L-BFGS optimizer doesn't support weight_decay > 0") if learning_rate is not None or decay is not None: print("Warning: learning rate is ignored for {}".format(optimizer)) - if optimizer in ["L-BFGS", "L-BFGS-B"]: - optim = torch.optim.LBFGS( - params, - lr=1, - max_iter=LBFGS_options["iter_per_step"], - max_eval=LBFGS_options["fun_per_step"], - tolerance_grad=LBFGS_options["gtol"], - tolerance_change=LBFGS_options["ftol"], - history_size=LBFGS_options["maxcor"], - line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None), - ) - else: - optim = NNCG( - params, - lr=NNCG_options["lr"], - rank=NNCG_options["rank"], - mu=NNCG_options["mu"], - chunk_size=NNCG_options["chunksz"], - cg_tol=NNCG_options["cgtol"], - cg_max_iters=NNCG_options["cgmaxiter"], - line_search_fn=NNCG_options["lsfun"], - verbose=NNCG_options["verbose"], - ) + optim = torch.optim.LBFGS( + params, + lr=1, + max_iter=LBFGS_options["iter_per_step"], + max_eval=LBFGS_options["fun_per_step"], + tolerance_grad=LBFGS_options["gtol"], + tolerance_change=LBFGS_options["ftol"], + history_size=LBFGS_options["maxcor"], + line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None), + ) + elif optimizer == "NNCG": + if weight_decay > 0: + raise ValueError("NNCG optimizer doesn't support weight_decay > 0") + if learning_rate is not None or decay is not None: + print("Warning: learning rate is ignored for {}".format(optimizer)) + optim = NNCG( + params, + lr=NNCG_options["lr"], + rank=NNCG_options["rank"], + mu=NNCG_options["mu"], + chunk_size=NNCG_options["chunksz"], + cg_tol=NNCG_options["cgtol"], + cg_max_iters=NNCG_options["cgmaxiter"], + line_search_fn=NNCG_options["lsfun"], + verbose=NNCG_options["verbose"], + ) else: if learning_rate is None: raise ValueError("No learning rate for {}.".format(optimizer)) From 1b13a0889bfb8d0e4b55fb5081a57c94de32f9f7 Mon Sep 17 00:00:00 2001 From: pratikrathore8 <76628577+pratikrathore8@users.noreply.github.com> Date: Fri, 15 Mar 2024 12:52:03 -0700 Subject: [PATCH 10/26] Fixed import order in optimizers.py --- deepxde/optimizers/pytorch/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py index 777662578..8bf795559 100644 --- a/deepxde/optimizers/pytorch/optimizers.py +++ b/deepxde/optimizers/pytorch/optimizers.py @@ -2,8 +2,8 @@ import torch -from ..config import LBFGS_options, NNCG_options from .nncg import NNCG +from ..config import LBFGS_options, NNCG_options def is_external_optimizer(optimizer): From 2d63ba30326d18d3e23e453f62510e58a3891205 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 8 Apr 2024 16:47:36 -0700 Subject: [PATCH 11/26] Made demo with NNCG and Burgers equation --- deepxde/model.py | 74 ++++++++++++++++++++++++++- deepxde/optimizers/__init__.py | 2 +- deepxde/optimizers/config.py | 5 ++ examples/pinn_forward/Burgers_NNCG.py | 67 ++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 examples/pinn_forward/Burgers_NNCG.py diff --git a/deepxde/model.py b/deepxde/model.py index 4ebdf6859..dfc0f68ae 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -363,11 +363,27 @@ def closure(): if self.lr_scheduler is not None: self.lr_scheduler.step() + def train_step_nncg(inputs, targets, auxiliary_vars): + def closure(): + return get_loss_grad_nncg(inputs, targets, auxiliary_vars) + + self.opt.step(closure) + + def get_loss_grad_nncg(inputs, targets, auxiliary_vars): + losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] + total_loss = torch.sum(losses) + self.opt.zero_grad() + grad_tuple = torch.autograd.grad(total_loss, trainable_variables, + create_graph=True) + return total_loss, grad_tuple + # Callables self.outputs = outputs self.outputs_losses_train = outputs_losses_train self.outputs_losses_test = outputs_losses_test self.train_step = train_step + self.train_step_nncg = train_step_nncg + self.get_loss_grad_nncg = get_loss_grad_nncg def _compile_jax(self, lr, loss_fn, decay): """jax""" @@ -636,12 +652,22 @@ def train( self._test() self.callbacks.on_train_begin() if optimizers.is_external_optimizer(self.opt_name): + if self.opt_name == "NNCG" and backend_name != "pytorch": + raise ValueError( + "The optimizer 'NNCG' is only supported for the backend PyTorch." + ) if backend_name == "tensorflow.compat.v1": self._train_tensorflow_compat_v1_scipy(display_every) elif backend_name == "tensorflow": self._train_tensorflow_tfp() elif backend_name == "pytorch": - self._train_pytorch_lbfgs() + if self.opt_name == "L-BFGS": + self._train_pytorch_lbfgs() + elif self.opt_name == "NNCG": + self._train_pytorch_nncg(iterations, display_every) + else: + raise ValueError("Only 'L-BFGS' and 'NNCG' are supported as \ + external optimizers for PyTorch.") elif backend_name == "paddle": self._train_paddle_lbfgs() else: @@ -785,6 +811,52 @@ def _train_pytorch_lbfgs(self): if self.stop_training: break + def _train_pytorch_nncg(self, iterations, display_every): + # Loop over the iterations -- take inspiration from _train_pytorch_lbfgs and _train_sgd + for i in range(iterations): + # 1. Perform appropriate begin callbacks + self.callbacks.on_epoch_begin() + self.callbacks.on_batch_begin() + + # 2. Update the preconditioner (if applicable) + # 2.1. We can check if the preconditioner is updated by making an + # option in NNCG_options called update_freq. Do the usual modular arithmetic + # from there + if i % optimizers.NNCG_options["updatefreq"] == 0: + self.opt.zero_grad() + # 2.2. How do we actually do this? Get the sum of the losses as in + # train_step(), and use torch.autograd.grad to get a gradient + _, grad_tuple = self.get_loss_grad_nncg( + self.train_state.X_train, + self.train_state.y_train, + self.train_state.train_aux_vars, + ) + # 2.3. Plug the gradient into the NNCG update_preconditioner function + # to perform the update + self.opt.update_preconditioner(grad_tuple) + + # 3. Call the train step + self.train_step_nncg( + self.train_state.X_train, + self.train_state.y_train, + self.train_state.train_aux_vars, + ) + + # 4. Use self._test() if needed + self.train_state.epoch += 1 + self.train_state.step += 1 + if self.train_state.step % display_every == 0 or i + 1 == iterations: + self._test() + + # 5. Perform appropriate end callbacks + self.callbacks.on_batch_end() + self.callbacks.on_epoch_end() + + # 6. Allow for training to stop (if self.stop_training) + if self.stop_training: + break + + def _train_paddle_lbfgs(self): prev_n_iter = 0 diff --git a/deepxde/optimizers/__init__.py b/deepxde/optimizers/__init__.py index e1fcfced1..556761a86 100644 --- a/deepxde/optimizers/__init__.py +++ b/deepxde/optimizers/__init__.py @@ -1,7 +1,7 @@ import importlib import sys -from .config import LBFGS_options, set_LBFGS_options +from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options from ..backend import backend_name diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 2a1504d7f..285c0fdc1 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -64,6 +64,7 @@ def set_NNCG_options( lr=1, rank=10, mu=1e-4, + updatefreq=20, chunksz=1, cgtol=1e-16, cgmaxiter=1000, @@ -79,6 +80,9 @@ def set_NNCG_options( Rank of preconditioner matrix used in preconditioned conjugate gradient. mu (float): `mu` (torch). Hessian damping parameter. + updatefreq (int): How often the preconditioner matrix in preconditioned + conjugate gradient is updated. This parameter is not directly used in NNCG, + instead it is used in _train_pytorch_nncg in deepxde/model.py. chunksz (int): `chunk_size` (torch). Number of Hessian-vector products to compute in parallel when constructing preconditioner. If `chunk_size` is 1, the Hessian-vector products are @@ -99,6 +103,7 @@ def set_NNCG_options( NNCG_options["lr"] = lr NNCG_options["rank"] = rank NNCG_options["mu"] = mu + NNCG_options["updatefreq"] = updatefreq NNCG_options["chunksz"] = chunksz NNCG_options["cgtol"] = cgtol NNCG_options["cgmaxiter"] = cgmaxiter diff --git a/examples/pinn_forward/Burgers_NNCG.py b/examples/pinn_forward/Burgers_NNCG.py new file mode 100644 index 000000000..d95eb1d22 --- /dev/null +++ b/examples/pinn_forward/Burgers_NNCG.py @@ -0,0 +1,67 @@ +"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" +import deepxde as dde +import numpy as np + + +def gen_testdata(): + data = np.load("../dataset/Burgers.npz") + t, x, exact = data["t"], data["x"], data["usol"].T + xx, tt = np.meshgrid(x, t) + X = np.vstack((np.ravel(xx), np.ravel(tt))).T + y = exact.flatten()[:, None] + return X, y + + +def pde(x, y): + dy_x = dde.grad.jacobian(y, x, i=0, j=0) + dy_t = dde.grad.jacobian(y, x, i=0, j=1) + dy_xx = dde.grad.hessian(y, x, i=0, j=0) + return dy_t + y * dy_x - 0.01 / np.pi * dy_xx + + +geom = dde.geometry.Interval(-1, 1) +timedomain = dde.geometry.TimeDomain(0, 0.99) +geomtime = dde.geometry.GeometryXTime(geom, timedomain) + +bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary) +ic = dde.icbc.IC( + geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial +) + +data = dde.data.TimePDE( + geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160 +) +net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal") +model = dde.Model(data, net) + +# Run Adam+L-BFGS +model.compile("adam", lr=1e-3) +model.train(iterations=15000) + +model.compile("L-BFGS") +losshistory, train_state = model.train() +dde.saveplot(losshistory, train_state, issave=True, isplot=True) + +# Get test data +X, y_true = gen_testdata() + +# Get the results after running Adam+L-BFGS +y_pred = model.predict(X) +f = model.predict(X, operator=pde) +print("Mean residual after Adam+L-BFGS:", np.mean(np.absolute(f))) +print("L2 relative error after Adam+L-BFGS:", dde.metrics.l2_relative_error(y_true, y_pred)) +np.savetxt("test_adam_lbfgs.dat", np.hstack((X, y_true, y_pred))) + +# Run NNCG after Adam+L-BFGS +dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) +model.compile("NNCG") +losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) +dde.saveplot(losshistory_nncg, train_state_nncg, issave=True, isplot=True) + +# Get the final results after running Adam+L-BFGS+NNCG +y_pred = model.predict(X) +f = model.predict(X, operator=pde) +print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f))) +print("L2 relative error after Adam+L-BFGS+NNCG:", + dde.metrics.l2_relative_error(y_true, y_pred)) +np.savetxt("test_adam_lbfgs_nncg.dat", np.hstack((X, y_true, y_pred))) From 8d0210ade7bf5ec79b2a2165105f49ee0e9d57b1 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 28 Oct 2024 22:53:27 -0700 Subject: [PATCH 12/26] refactor nncg integration --- deepxde/model.py | 58 +++++++------------- deepxde/optimizers/pytorch/nncg.py | 67 ++++++++++++++---------- deepxde/optimizers/pytorch/optimizers.py | 1 + 3 files changed, 59 insertions(+), 67 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index 8f7c9e7f7..33ddcac2f 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -353,39 +353,40 @@ def outputs_losses_test(inputs, targets, auxiliary_vars): "backend pytorch." ) - def train_step(inputs, targets, auxiliary_vars): + def train_step(inputs, targets, auxiliary_vars, perform_backward=True): def closure(): losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] total_loss = torch.sum(losses) self.opt.zero_grad() - total_loss.backward() + if perform_backward: + total_loss.backward() return total_loss self.opt.step(closure) if self.lr_scheduler is not None: self.lr_scheduler.step() - def train_step_nncg(inputs, targets, auxiliary_vars): - def closure(): - return get_loss_grad_nncg(inputs, targets, auxiliary_vars) + # def train_step_nncg(inputs, targets, auxiliary_vars): + # def closure(): + # return get_loss_grad_nncg(inputs, targets, auxiliary_vars) - self.opt.step(closure) + # self.opt.step(closure) - def get_loss_grad_nncg(inputs, targets, auxiliary_vars): - losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] - total_loss = torch.sum(losses) - self.opt.zero_grad() - grad_tuple = torch.autograd.grad(total_loss, trainable_variables, - create_graph=True) - return total_loss, grad_tuple + # def get_loss_grad_nncg(inputs, targets, auxiliary_vars): + # losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] + # total_loss = torch.sum(losses) + # self.opt.zero_grad() + # grad_tuple = torch.autograd.grad(total_loss, trainable_variables, + # create_graph=True) + # return total_loss, grad_tuple # Callables self.outputs = outputs self.outputs_losses_train = outputs_losses_train self.outputs_losses_test = outputs_losses_test self.train_step = train_step - self.train_step_nncg = train_step_nncg - self.get_loss_grad_nncg = get_loss_grad_nncg + # self.train_step_nncg = train_step_nncg + # self.get_loss_grad_nncg = get_loss_grad_nncg def _compile_jax(self, lr, loss_fn, decay): """jax""" @@ -819,47 +820,26 @@ def _train_pytorch_lbfgs(self): break def _train_pytorch_nncg(self, iterations, display_every): - # Loop over the iterations -- take inspiration from _train_pytorch_lbfgs and _train_sgd for i in range(iterations): - # 1. Perform appropriate begin callbacks self.callbacks.on_epoch_begin() self.callbacks.on_batch_begin() - # 2. Update the preconditioner (if applicable) - # 2.1. We can check if the preconditioner is updated by making an - # option in NNCG_options called update_freq. Do the usual modular arithmetic - # from there - if i % optimizers.NNCG_options["updatefreq"] == 0: - self.opt.zero_grad() - # 2.2. How do we actually do this? Get the sum of the losses as in - # train_step(), and use torch.autograd.grad to get a gradient - _, grad_tuple = self.get_loss_grad_nncg( - self.train_state.X_train, - self.train_state.y_train, - self.train_state.train_aux_vars, - ) - # 2.3. Plug the gradient into the NNCG update_preconditioner function - # to perform the update - self.opt.update_preconditioner(grad_tuple) - - # 3. Call the train step - self.train_step_nncg( + # The train step should only use full gradients, so we do not use self.train_state.set_data_train() + self.train_step( self.train_state.X_train, self.train_state.y_train, self.train_state.train_aux_vars, + perform_backward=False, ) - # 4. Use self._test() if needed self.train_state.epoch += 1 self.train_state.step += 1 if self.train_state.step % display_every == 0 or i + 1 == iterations: self._test() - # 5. Perform appropriate end callbacks self.callbacks.on_batch_end() self.callbacks.on_epoch_end() - # 6. Allow for training to stop (if self.stop_training) if self.stop_training: break diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py index a71e7da31..ba195bb51 100644 --- a/deepxde/optimizers/pytorch/nncg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -92,6 +92,7 @@ class NNCG(Optimizer): lr (float, optional): learning rate (default: 1.0) rank (int, optional): rank of the Nyström approximation (default: 10) mu (float, optional): damping parameter (default: 1e-4) + update_freq (int, optional): frequency of updating the preconditioner chunk_size (int, optional): number of Hessian-vector products to be computed in parallel (default: 1) cg_tol (float, optional): tolerance for PCG (default: 1e-16) @@ -106,6 +107,7 @@ def __init__( lr=1.0, rank=10, mu=1e-4, + update_freq=20, chunk_size=1, cg_tol=1e-16, cg_max_iters=1000, @@ -115,14 +117,16 @@ def __init__( defaults = dict( lr=lr, rank=rank, - chunk_size=chunk_size, mu=mu, + update_freq=update_freq, + chunk_size=chunk_size, cg_tol=cg_tol, cg_max_iters=cg_max_iters, line_search_fn=line_search_fn, ) self.rank = rank self.mu = mu + self.update_freq = update_freq self.chunk_size = chunk_size self.cg_tol = cg_tol self.cg_max_iters = cg_max_iters @@ -146,26 +150,25 @@ def __init__( self._params_list = list(self._params) self._numel_cache = None - def step(self, closure=None): + def step(self, closure): """Perform a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model - and returns (i) the loss and (ii) gradient w.r.t. the parameters. - The closure can compute the gradient w.r.t. the parameters by - calling torch.autograd.grad on the loss with create_graph=True. + closure (callable): A closure that reevaluates the model + and returns the loss w.r.t. the parameters. """ if self.n_iters == 0: # Store the previous direction for warm starting PCG self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) - # NOTE: The closure must return both the loss and the gradient - loss = None - if closure is not None: - with torch.enable_grad(): - loss, grad_tuple = closure() + loss = closure() + # g = self._gather_flat_grad() + # Compute gradient via torch.autograd.grad + g_tuple = torch.autograd.grad(loss, self._params_list, create_graph=True) + g = torch.cat([gi.view(-1) for gi in g_tuple if gi is not None]) - g = torch.cat([grad.view(-1) for grad in grad_tuple if grad is not None]) + if self.n_iters % self.update_freq == 0: + self._update_preconditioner(g) # One step update for group_idx, group in enumerate(self.param_groups): @@ -198,7 +201,7 @@ def hvp_temp(x): def obj_func(x, t, dx): self._add_grad(t, dx) - loss = float(closure()[0]) + loss = float(closure()) self._set_param(x) return loss @@ -219,29 +222,20 @@ def obj_func(x, t, dx): self.n_iters += 1 - return loss, g + return loss - def update_preconditioner(self, grad_tuple): + def _update_preconditioner(self, grad): """Update the Nystrom approximation of the Hessian. Args: - grad_tuple (tuple): tuple of Tensors containing the gradients - of the loss w.r.t. the parameters. - This tuple can be obtained by calling torch.autograd.grad - on the loss with create_graph=True. + grad (torch.Tensor): gradient of the loss w.r.t. the parameters. """ - - # Flatten and concatenate the gradients - gradsH = torch.cat( - [gradient.view(-1) for gradient in grad_tuple if gradient is not None] - ) - # Generate test matrix (NOTE: This is transposed test matrix) - p = gradsH.shape[0] - Phi = torch.randn((self.rank, p), device=gradsH.device) / (p**0.5) + p = grad.shape[0] + Phi = torch.randn((self.rank, p), device=grad.device) / (p**0.5) Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t() - Y = self._hvp_vmap(gradsH, self._params_list)(Phi) + Y = self._hvp_vmap(grad, self._params_list)(Phi) # Calculate shift shift = torch.finfo(Y.dtype).eps @@ -304,6 +298,23 @@ def _numel(self): ) return self._numel_cache + # def _gather_flat_grad(self): + # """Gathers the gradients of the parameters in a single vector. + # Copied from torch.optim.lbfgs (https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS). + # """ + # views = [] + # for p in self._params: + # if p.grad is None: + # view = p.new(p.numel()).zero_() + # elif p.grad.is_sparse: + # view = p.grad.to_dense().view(-1) + # else: + # view = p.grad.view(-1) + # if torch.is_complex(view): + # view = torch.view_as_real(view).view(-1) + # views.append(view) + # return torch.cat(views, 0) + def _add_grad(self, step_size, update): offset = 0 for p in self._params: diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py index 8bf795559..35ab88d24 100644 --- a/deepxde/optimizers/pytorch/optimizers.py +++ b/deepxde/optimizers/pytorch/optimizers.py @@ -40,6 +40,7 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): lr=NNCG_options["lr"], rank=NNCG_options["rank"], mu=NNCG_options["mu"], + update_freq=NNCG_options["updatefreq"], chunk_size=NNCG_options["chunksz"], cg_tol=NNCG_options["cgtol"], cg_max_iters=NNCG_options["cgmaxiter"], From 5c9beba85a472b0a97d3e46afd45e083528d1bcc Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 28 Oct 2024 22:57:02 -0700 Subject: [PATCH 13/26] clean up commented code --- deepxde/model.py | 16 ---------------- deepxde/optimizers/pytorch/nncg.py | 20 +------------------- 2 files changed, 1 insertion(+), 35 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index 33ddcac2f..901a8a994 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -366,27 +366,11 @@ def closure(): if self.lr_scheduler is not None: self.lr_scheduler.step() - # def train_step_nncg(inputs, targets, auxiliary_vars): - # def closure(): - # return get_loss_grad_nncg(inputs, targets, auxiliary_vars) - - # self.opt.step(closure) - - # def get_loss_grad_nncg(inputs, targets, auxiliary_vars): - # losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] - # total_loss = torch.sum(losses) - # self.opt.zero_grad() - # grad_tuple = torch.autograd.grad(total_loss, trainable_variables, - # create_graph=True) - # return total_loss, grad_tuple - # Callables self.outputs = outputs self.outputs_losses_train = outputs_losses_train self.outputs_losses_test = outputs_losses_test self.train_step = train_step - # self.train_step_nncg = train_step_nncg - # self.get_loss_grad_nncg = get_loss_grad_nncg def _compile_jax(self, lr, loss_fn, decay): """jax""" diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py index ba195bb51..b3eeccbbd 100644 --- a/deepxde/optimizers/pytorch/nncg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -162,7 +162,6 @@ def step(self, closure): self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) loss = closure() - # g = self._gather_flat_grad() # Compute gradient via torch.autograd.grad g_tuple = torch.autograd.grad(loss, self._params_list, create_graph=True) g = torch.cat([gi.view(-1) for gi in g_tuple if gi is not None]) @@ -225,7 +224,7 @@ def obj_func(x, t, dx): return loss def _update_preconditioner(self, grad): - """Update the Nystrom approximation of the Hessian. + """Update the Nyström approximation of the Hessian. Args: grad (torch.Tensor): gradient of the loss w.r.t. the parameters. @@ -298,23 +297,6 @@ def _numel(self): ) return self._numel_cache - # def _gather_flat_grad(self): - # """Gathers the gradients of the parameters in a single vector. - # Copied from torch.optim.lbfgs (https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS). - # """ - # views = [] - # for p in self._params: - # if p.grad is None: - # view = p.new(p.numel()).zero_() - # elif p.grad.is_sparse: - # view = p.grad.to_dense().view(-1) - # else: - # view = p.grad.view(-1) - # if torch.is_complex(view): - # view = torch.view_as_real(view).view(-1) - # views.append(view) - # return torch.cat(views, 0) - def _add_grad(self, step_size, update): offset = 0 for p in self._params: From ce1dcc89668f65e9a72973c4ad3df4f0a0d852dc Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 28 Oct 2024 23:09:20 -0700 Subject: [PATCH 14/26] format with black --- deepxde/model.py | 7 ++++--- deepxde/optimizers/config.py | 8 +++++--- deepxde/optimizers/pytorch/nncg.py | 4 ++-- examples/pinn_forward/Burgers_NNCG.py | 12 +++++++++--- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index 901a8a994..f23639f16 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -658,8 +658,10 @@ def train( elif self.opt_name == "NNCG": self._train_pytorch_nncg(iterations, display_every) else: - raise ValueError("Only 'L-BFGS' and 'NNCG' are supported as \ - external optimizers for PyTorch.") + raise ValueError( + "Only 'L-BFGS' and 'NNCG' are supported as \ + external optimizers for PyTorch." + ) elif backend_name == "paddle": self._train_paddle_lbfgs() else: @@ -827,7 +829,6 @@ def _train_pytorch_nncg(self, iterations, display_every): if self.stop_training: break - def _train_paddle_lbfgs(self): prev_n_iter = 0 diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 285c0fdc1..537d2f4c4 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -60,6 +60,7 @@ def set_LBFGS_options( LBFGS_options["maxfun"] = maxfun if maxfun is not None else int(maxiter * 1.25) LBFGS_options["maxls"] = maxls + def set_NNCG_options( lr=1, rank=10, @@ -69,7 +70,7 @@ def set_NNCG_options( cgtol=1e-16, cgmaxiter=1000, lsfun="armijo", - verbose=False + verbose=False, ): """Sets the hyperparameters of NysNewtonCG (NNCG). @@ -80,11 +81,11 @@ def set_NNCG_options( Rank of preconditioner matrix used in preconditioned conjugate gradient. mu (float): `mu` (torch). Hessian damping parameter. - updatefreq (int): How often the preconditioner matrix in preconditioned + updatefreq (int): How often the preconditioner matrix in preconditioned conjugate gradient is updated. This parameter is not directly used in NNCG, instead it is used in _train_pytorch_nncg in deepxde/model.py. chunksz (int): `chunk_size` (torch). - Number of Hessian-vector products to compute in parallel when constructing + Number of Hessian-vector products to compute in parallel when constructing preconditioner. If `chunk_size` is 1, the Hessian-vector products are computed serially. cgtol (float): `cg_tol` (torch). @@ -110,6 +111,7 @@ def set_NNCG_options( NNCG_options["lsfun"] = lsfun NNCG_options["verbose"] = verbose + def set_hvd_opt_options( compression=None, op=None, diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py index b3eeccbbd..76c56941c 100644 --- a/deepxde/optimizers/pytorch/nncg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -81,7 +81,7 @@ class NNCG(Optimizer): The parameters rank and mu will probably need to be tuned for your specific problem. If the optimizer is running very slowly, you can try one of the following: - - Increase the rank (this should increase the + - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG) - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution) - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations) @@ -155,7 +155,7 @@ def step(self, closure): Args: closure (callable): A closure that reevaluates the model - and returns the loss w.r.t. the parameters. + and returns the loss w.r.t. the parameters. """ if self.n_iters == 0: # Store the previous direction for warm starting PCG diff --git a/examples/pinn_forward/Burgers_NNCG.py b/examples/pinn_forward/Burgers_NNCG.py index d95eb1d22..ba9486b14 100644 --- a/examples/pinn_forward/Burgers_NNCG.py +++ b/examples/pinn_forward/Burgers_NNCG.py @@ -1,4 +1,5 @@ """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" + import deepxde as dde import numpy as np @@ -49,7 +50,10 @@ def pde(x, y): y_pred = model.predict(X) f = model.predict(X, operator=pde) print("Mean residual after Adam+L-BFGS:", np.mean(np.absolute(f))) -print("L2 relative error after Adam+L-BFGS:", dde.metrics.l2_relative_error(y_true, y_pred)) +print( + "L2 relative error after Adam+L-BFGS:", + dde.metrics.l2_relative_error(y_true, y_pred), +) np.savetxt("test_adam_lbfgs.dat", np.hstack((X, y_true, y_pred))) # Run NNCG after Adam+L-BFGS @@ -62,6 +66,8 @@ def pde(x, y): y_pred = model.predict(X) f = model.predict(X, operator=pde) print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f))) -print("L2 relative error after Adam+L-BFGS+NNCG:", - dde.metrics.l2_relative_error(y_true, y_pred)) +print( + "L2 relative error after Adam+L-BFGS+NNCG:", + dde.metrics.l2_relative_error(y_true, y_pred), +) np.savetxt("test_adam_lbfgs_nncg.dat", np.hstack((X, y_true, y_pred))) From d97ca16a40df5dc19d5211b388b346fcdc3d0e85 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Tue, 29 Oct 2024 15:30:45 -0700 Subject: [PATCH 15/26] remove unnecessary error checks --- deepxde/model.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index f23639f16..8f50b92f5 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -644,10 +644,6 @@ def train( self._test() self.callbacks.on_train_begin() if optimizers.is_external_optimizer(self.opt_name): - if self.opt_name == "NNCG" and backend_name != "pytorch": - raise ValueError( - "The optimizer 'NNCG' is only supported for the backend PyTorch." - ) if backend_name == "tensorflow.compat.v1": self._train_tensorflow_compat_v1_scipy(display_every) elif backend_name == "tensorflow": @@ -657,11 +653,6 @@ def train( self._train_pytorch_lbfgs() elif self.opt_name == "NNCG": self._train_pytorch_nncg(iterations, display_every) - else: - raise ValueError( - "Only 'L-BFGS' and 'NNCG' are supported as \ - external optimizers for PyTorch." - ) elif backend_name == "paddle": self._train_paddle_lbfgs() else: From 81c4452ad8d060d7effe81b5e96da8459a382bb5 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Tue, 29 Oct 2024 15:50:48 -0700 Subject: [PATCH 16/26] fix some codacy issues in nncg --- deepxde/optimizers/pytorch/nncg.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py index 76c56941c..961260986 100644 --- a/deepxde/optimizers/pytorch/nncg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -114,16 +114,16 @@ def __init__( line_search_fn=None, verbose=False, ): - defaults = dict( - lr=lr, - rank=rank, - mu=mu, - update_freq=update_freq, - chunk_size=chunk_size, - cg_tol=cg_tol, - cg_max_iters=cg_max_iters, - line_search_fn=line_search_fn, - ) + defaults = { + "lr": lr, + "rank": rank, + "mu": mu, + "update_freq": update_freq, + "chunk_size": chunk_size, + "cg_tol": cg_tol, + "cg_max_iters": cg_max_iters, + "line_search_fn": line_search_fn, + } self.rank = rank self.mu = mu self.update_freq = update_freq @@ -135,7 +135,7 @@ def __init__( self.U = None self.S = None self.n_iters = 0 - super(NNCG, self).__init__(params, defaults) + super().__init__(params, defaults) if len(self.param_groups) > 1: raise ValueError( From 60d3ff33f6bc0e2fa3383424780653b0d2b6cd8d Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Tue, 29 Oct 2024 21:08:48 -0700 Subject: [PATCH 17/26] further improvements to nncg integration --- deepxde/model.py | 4 +++- examples/pinn_forward/Burgers_NNCG.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index 8f50b92f5..feefe7a9f 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -801,7 +801,9 @@ def _train_pytorch_nncg(self, iterations, display_every): self.callbacks.on_epoch_begin() self.callbacks.on_batch_begin() - # The train step should only use full gradients, so we do not use self.train_state.set_data_train() + self.train_state.set_data_train( + *self.data.train_next_batch(self.batch_size) + ) self.train_step( self.train_state.X_train, self.train_state.y_train, diff --git a/examples/pinn_forward/Burgers_NNCG.py b/examples/pinn_forward/Burgers_NNCG.py index ba9486b14..cb0741e9b 100644 --- a/examples/pinn_forward/Burgers_NNCG.py +++ b/examples/pinn_forward/Burgers_NNCG.py @@ -1,4 +1,4 @@ -"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" +"""Backend supported: pytorch""" import deepxde as dde import numpy as np From 6e739e10f8509b03011f18d177d9cad574601f7d Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Thu, 31 Oct 2024 13:17:20 -0700 Subject: [PATCH 18/26] add train_step_nncg --- deepxde/model.py | 46 +++++++++++++++------------------------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/deepxde/model.py b/deepxde/model.py index feefe7a9f..34a58c027 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -353,13 +353,23 @@ def outputs_losses_test(inputs, targets, auxiliary_vars): "backend pytorch." ) - def train_step(inputs, targets, auxiliary_vars, perform_backward=True): + def train_step(inputs, targets, auxiliary_vars): + def closure(): + losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] + total_loss = torch.sum(losses) + self.opt.zero_grad() + total_loss.backward() + return total_loss + + self.opt.step(closure) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + def train_step_nncg(inputs, targets, auxiliary_vars): def closure(): losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] total_loss = torch.sum(losses) self.opt.zero_grad() - if perform_backward: - total_loss.backward() return total_loss self.opt.step(closure) @@ -370,7 +380,7 @@ def closure(): self.outputs = outputs self.outputs_losses_train = outputs_losses_train self.outputs_losses_test = outputs_losses_test - self.train_step = train_step + self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg def _compile_jax(self, lr, loss_fn, decay): """jax""" @@ -652,7 +662,7 @@ def train( if self.opt_name == "L-BFGS": self._train_pytorch_lbfgs() elif self.opt_name == "NNCG": - self._train_pytorch_nncg(iterations, display_every) + self._train_sgd(iterations, display_every) elif backend_name == "paddle": self._train_paddle_lbfgs() else: @@ -796,32 +806,6 @@ def _train_pytorch_lbfgs(self): if self.stop_training: break - def _train_pytorch_nncg(self, iterations, display_every): - for i in range(iterations): - self.callbacks.on_epoch_begin() - self.callbacks.on_batch_begin() - - self.train_state.set_data_train( - *self.data.train_next_batch(self.batch_size) - ) - self.train_step( - self.train_state.X_train, - self.train_state.y_train, - self.train_state.train_aux_vars, - perform_backward=False, - ) - - self.train_state.epoch += 1 - self.train_state.step += 1 - if self.train_state.step % display_every == 0 or i + 1 == iterations: - self._test() - - self.callbacks.on_batch_end() - self.callbacks.on_epoch_end() - - if self.stop_training: - break - def _train_paddle_lbfgs(self): prev_n_iter = 0 From f4e322e22fdf1444db0f3c8c2511d547b2d314d0 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 4 Nov 2024 11:27:01 -0800 Subject: [PATCH 19/26] improve documentation in nncg config --- deepxde/optimizers/config.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 537d2f4c4..7ffb6a8d1 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -74,30 +74,33 @@ def set_NNCG_options( ): """Sets the hyperparameters of NysNewtonCG (NNCG). + The NNCG optimizer only supports PyTorch. + Args: - lr (float): `lr` (torch). + lr (float): Learning rate (before line search). - rank (int): `rank` (torch). + rank (int): Rank of preconditioner matrix used in preconditioned conjugate gradient. - mu (float): `mu` (torch). + mu (float): Hessian damping parameter. - updatefreq (int): How often the preconditioner matrix in preconditioned + updatefreq (int): + How often the preconditioner matrix in preconditioned conjugate gradient is updated. This parameter is not directly used in NNCG, instead it is used in _train_pytorch_nncg in deepxde/model.py. - chunksz (int): `chunk_size` (torch). + chunksz (int): Number of Hessian-vector products to compute in parallel when constructing preconditioner. If `chunk_size` is 1, the Hessian-vector products are computed serially. - cgtol (float): `cg_tol` (torch). + cgtol (float): Convergence tolerance for the conjugate gradient method. The iteration stops when `||r||_2 <= cgtol`, where `r` is the residual. Note that this condition is based on the absolute tolerance, not the relative tolerance. - cgmaxiter (int): `cg_max_iters` (torch). + cgmaxiter (int): Maximum number of iterations for the conjugate gradient method. - lsfun (str): `line_search_fn` (torch). + lsfun (str): The line search function used to find the step size. The default value is "armijo". The other option is None. - verbose (bool): `verbose` (torch). + verbose (bool): If `True`, prints the eigenvalues of the Nyström approximation of the Hessian. """ From 6338f596cc15c1712aabc34e9dbdca11687908ba Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 4 Nov 2024 18:36:23 -0800 Subject: [PATCH 20/26] added doc for nncg demo --- deepxde/optimizers/config.py | 2 +- docs/demos/pinn_forward/burgers.nncg.rst | 105 +++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 docs/demos/pinn_forward/burgers.nncg.rst diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 7ffb6a8d1..85830f72c 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -83,7 +83,7 @@ def set_NNCG_options( Rank of preconditioner matrix used in preconditioned conjugate gradient. mu (float): Hessian damping parameter. - updatefreq (int): + updatefreq (int): How often the preconditioner matrix in preconditioned conjugate gradient is updated. This parameter is not directly used in NNCG, instead it is used in _train_pytorch_nncg in deepxde/model.py. diff --git a/docs/demos/pinn_forward/burgers.nncg.rst b/docs/demos/pinn_forward/burgers.nncg.rst new file mode 100644 index 000000000..a31499791 --- /dev/null +++ b/docs/demos/pinn_forward/burgers.nncg.rst @@ -0,0 +1,105 @@ +Burgers equation with NNCG optimizer (PyTorch only) +================ + +Problem setup +-------------- + +We will solve a Burgers equation: + +.. math:: \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} = \nu\frac{\partial^2u}{\partial x^2}, \qquad x \in [-1, 1], \quad t \in [0, 1] + +with the Dirichlet boundary conditions and initial conditions + +.. math:: u(-1,t)=u(1,t)=0, \quad u(x,0) = - \sin(\pi x). + +The reference solution is `here `_. + +Implementation +-------------- + +This description goes through the implementation of a solver for the above described Burgers equation step-by-step. + +First, the DeepXDE modules is imported: + +.. code-block:: python + + import deepxde as dde + +We begin by defining a computational geometry and time domain. We can use a built-in class ``Interval`` and ``TimeDomain`` and we combine both the domains using ``GeometryXTime`` as follows + +.. code-block:: python + + geom = dde.geometry.Interval(-1, 1) + timedomain = dde.geometry.TimeDomain(0, 0.99) + geomtime = dde.geometry.GeometryXTime(geom, timedomain) + +Next, we express the PDE residual of the Burgers equation: + +.. code-block:: python + + def pde(x, y): + dy_x = dde.grad.jacobian(y, x, i=0, j=0) + dy_t = dde.grad.jacobian(y, x, i=0, j=1) + dy_xx = dde.grad.hessian(y, x, i=0, j=0) + return dy_t + y * dy_x - 0.01 / np.pi * dy_xx + +The first argument to ``pde`` is 2-dimensional vector where the first component(``x[:,0]``) is :math:`x`-coordinate and the second componenet (``x[:,1]``) is the :math:`t`-coordinate. The second argument is the network output, i.e., the solution :math:`u(x,t)`, but here we use ``y`` as the name of the variable. + +Next, we consider the boundary/initial condition. ``on_boundary`` is chosen here to use the whole boundary of the computational domain in considered as the boundary condition. We include the ``geomtime`` space, time geometry created above and ``on_boundary`` as the BCs in the ``DirichletBC`` function of DeepXDE. We also define ``IC`` which is the inital condition for the burgers equation and we use the computational domain, initial function, and ``on_initial`` to specify the IC. + +.. code-block:: python + + bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary) + ic = dde.icbc.IC(geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial) + +Now, we have specified the geometry, PDE residual, and boundary/initial condition. We then define the ``TimePDE`` problem as + +.. code-block:: python + + data = dde.data.TimePDE(geomtime, pde, [bc, ic], + num_domain=2540, num_boundary=80, num_initial=160) + +The number 2540 is the number of training residual points sampled inside the domain, and the number 80 is the number of training points sampled on the boundary. We also include 160 initial residual points for the initial conditions. + +Next, we choose the network. Here, we use a fully connected neural network of depth 4 (i.e., 3 hidden layers) and width 20: + +.. code-block:: python + + net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal") + +Now, we have the PDE problem and the network. We build a ``Model`` and choose the optimizer and learning rate: + +.. code-block:: python + + model = dde.Model(data, net) + model.compile("adam", lr=1e-3) + + +We then train the model for 15000 iterations: + +.. code-block:: python + + losshistory, train_state = model.train(iterations=15000) + +After we train the network using Adam, we continue to train the network using L-BFGS to achieve a smaller loss: + +.. code-block:: python + + model.compile("L-BFGS-B") + losshistory, train_state = model.train() + +However, L-BFGS can stall out early in optimization if it is unable to find a step size satisfying the strong Wolfe conditions. In such cases, we can use the NNCG optimizer to continue decreasing the loss: + +.. code-block:: python + + dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) + model.compile("NNCG") + losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) + +Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer. + +Complete code +-------------- + +.. literalinclude:: ../../../examples/pinn_forward/Burgers_NNCG.py + :language: python From 838eaa8b218ed628354c9cb3dc973ab295dfa41f Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 4 Nov 2024 18:38:54 -0800 Subject: [PATCH 21/26] added demo file to pinn_forward.rst --- docs/demos/pinn_forward.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/demos/pinn_forward.rst b/docs/demos/pinn_forward.rst index 0205c5eaa..fd8ccb661 100644 --- a/docs/demos/pinn_forward.rst +++ b/docs/demos/pinn_forward.rst @@ -52,6 +52,7 @@ Time-dependent PDEs pinn_forward/burgers.rar pinn_forward/allen.cahn pinn_forward/klein.gordon + pinn_forward/burgers.nncg - `Beltrami flow `_ - `Wave propagation with spatio-temporal multi-scale Fourier feature architecture `_ From 7d583175f44bd27f4acb93e09f9612295107e370 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Wed, 13 Nov 2024 12:59:00 -0800 Subject: [PATCH 22/26] change ordering in pinn_forward.rst --- docs/demos/pinn_forward.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/demos/pinn_forward.rst b/docs/demos/pinn_forward.rst index fd8ccb661..791e1516b 100644 --- a/docs/demos/pinn_forward.rst +++ b/docs/demos/pinn_forward.rst @@ -43,6 +43,7 @@ Time-dependent PDEs :maxdepth: 1 pinn_forward/burgers + pinn_forward/burgers.nncg pinn_forward/heat pinn_forward/heat.resample pinn_forward/diffusion.1d @@ -52,7 +53,6 @@ Time-dependent PDEs pinn_forward/burgers.rar pinn_forward/allen.cahn pinn_forward/klein.gordon - pinn_forward/burgers.nncg - `Beltrami flow `_ - `Wave propagation with spatio-temporal multi-scale Fourier feature architecture `_ From 54675245268e7a930c94196ef772a5a4dd21e378 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 18 Nov 2024 12:56:08 -0800 Subject: [PATCH 23/26] merge burgers_nncg demo into burgers demo --- docs/demos/pinn_forward.rst | 1 - docs/demos/pinn_forward/burgers.nncg.rst | 105 ----------------------- docs/demos/pinn_forward/burgers.rst | 12 ++- examples/pinn_forward/Burgers.py | 18 ++++ examples/pinn_forward/Burgers_NNCG.py | 73 ---------------- 5 files changed, 29 insertions(+), 180 deletions(-) delete mode 100644 docs/demos/pinn_forward/burgers.nncg.rst delete mode 100644 examples/pinn_forward/Burgers_NNCG.py diff --git a/docs/demos/pinn_forward.rst b/docs/demos/pinn_forward.rst index 791e1516b..0205c5eaa 100644 --- a/docs/demos/pinn_forward.rst +++ b/docs/demos/pinn_forward.rst @@ -43,7 +43,6 @@ Time-dependent PDEs :maxdepth: 1 pinn_forward/burgers - pinn_forward/burgers.nncg pinn_forward/heat pinn_forward/heat.resample pinn_forward/diffusion.1d diff --git a/docs/demos/pinn_forward/burgers.nncg.rst b/docs/demos/pinn_forward/burgers.nncg.rst deleted file mode 100644 index a31499791..000000000 --- a/docs/demos/pinn_forward/burgers.nncg.rst +++ /dev/null @@ -1,105 +0,0 @@ -Burgers equation with NNCG optimizer (PyTorch only) -================ - -Problem setup --------------- - -We will solve a Burgers equation: - -.. math:: \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} = \nu\frac{\partial^2u}{\partial x^2}, \qquad x \in [-1, 1], \quad t \in [0, 1] - -with the Dirichlet boundary conditions and initial conditions - -.. math:: u(-1,t)=u(1,t)=0, \quad u(x,0) = - \sin(\pi x). - -The reference solution is `here `_. - -Implementation --------------- - -This description goes through the implementation of a solver for the above described Burgers equation step-by-step. - -First, the DeepXDE modules is imported: - -.. code-block:: python - - import deepxde as dde - -We begin by defining a computational geometry and time domain. We can use a built-in class ``Interval`` and ``TimeDomain`` and we combine both the domains using ``GeometryXTime`` as follows - -.. code-block:: python - - geom = dde.geometry.Interval(-1, 1) - timedomain = dde.geometry.TimeDomain(0, 0.99) - geomtime = dde.geometry.GeometryXTime(geom, timedomain) - -Next, we express the PDE residual of the Burgers equation: - -.. code-block:: python - - def pde(x, y): - dy_x = dde.grad.jacobian(y, x, i=0, j=0) - dy_t = dde.grad.jacobian(y, x, i=0, j=1) - dy_xx = dde.grad.hessian(y, x, i=0, j=0) - return dy_t + y * dy_x - 0.01 / np.pi * dy_xx - -The first argument to ``pde`` is 2-dimensional vector where the first component(``x[:,0]``) is :math:`x`-coordinate and the second componenet (``x[:,1]``) is the :math:`t`-coordinate. The second argument is the network output, i.e., the solution :math:`u(x,t)`, but here we use ``y`` as the name of the variable. - -Next, we consider the boundary/initial condition. ``on_boundary`` is chosen here to use the whole boundary of the computational domain in considered as the boundary condition. We include the ``geomtime`` space, time geometry created above and ``on_boundary`` as the BCs in the ``DirichletBC`` function of DeepXDE. We also define ``IC`` which is the inital condition for the burgers equation and we use the computational domain, initial function, and ``on_initial`` to specify the IC. - -.. code-block:: python - - bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary) - ic = dde.icbc.IC(geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial) - -Now, we have specified the geometry, PDE residual, and boundary/initial condition. We then define the ``TimePDE`` problem as - -.. code-block:: python - - data = dde.data.TimePDE(geomtime, pde, [bc, ic], - num_domain=2540, num_boundary=80, num_initial=160) - -The number 2540 is the number of training residual points sampled inside the domain, and the number 80 is the number of training points sampled on the boundary. We also include 160 initial residual points for the initial conditions. - -Next, we choose the network. Here, we use a fully connected neural network of depth 4 (i.e., 3 hidden layers) and width 20: - -.. code-block:: python - - net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal") - -Now, we have the PDE problem and the network. We build a ``Model`` and choose the optimizer and learning rate: - -.. code-block:: python - - model = dde.Model(data, net) - model.compile("adam", lr=1e-3) - - -We then train the model for 15000 iterations: - -.. code-block:: python - - losshistory, train_state = model.train(iterations=15000) - -After we train the network using Adam, we continue to train the network using L-BFGS to achieve a smaller loss: - -.. code-block:: python - - model.compile("L-BFGS-B") - losshistory, train_state = model.train() - -However, L-BFGS can stall out early in optimization if it is unable to find a step size satisfying the strong Wolfe conditions. In such cases, we can use the NNCG optimizer to continue decreasing the loss: - -.. code-block:: python - - dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) - model.compile("NNCG") - losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) - -Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer. - -Complete code --------------- - -.. literalinclude:: ../../../examples/pinn_forward/Burgers_NNCG.py - :language: python diff --git a/docs/demos/pinn_forward/burgers.rst b/docs/demos/pinn_forward/burgers.rst index 76ded3e31..3f931f866 100644 --- a/docs/demos/pinn_forward/burgers.rst +++ b/docs/demos/pinn_forward/burgers.rst @@ -87,7 +87,17 @@ After we train the network using Adam, we continue to train the network using L- .. code-block:: python model.compile("L-BFGS-B") - losshistory, train_state = model.train() + losshistory, train_state = model.train() + +However, L-BFGS can stall out early in optimization if it is unable to find a step size satisfying the strong Wolfe conditions. In such cases, we can use the NNCG optimizer (compatible with PyTorch only) to continue reducing the loss: + +.. code-block:: python + + dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) + model.compile("NNCG") + losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) + +By default, NNCG does not run in this demo. You will have to uncomment the NNCG code block at the end of the demo to have it run after Adam and L-BFGS. Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer. Complete code -------------- diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index 64ee46bb6..3ab31a05f 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -1,4 +1,5 @@ """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" + import deepxde as dde import numpy as np @@ -46,3 +47,20 @@ def pde(x, y): print("Mean residual:", np.mean(np.absolute(f))) print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred)) np.savetxt("test.dat", np.hstack((X, y_true, y_pred))) + +# """Backend supported: pytorch""" +# # Run NNCG after Adam and L-BFGS +# dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) +# model.compile("NNCG") +# losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) +# dde.saveplot(losshistory_nncg, train_state_nncg, issave=True, isplot=True) + +# # Get the final results after running Adam+L-BFGS+NNCG +# y_pred = model.predict(X) +# f = model.predict(X, operator=pde) +# print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f))) +# print( +# "L2 relative error after Adam+L-BFGS+NNCG:", +# dde.metrics.l2_relative_error(y_true, y_pred), +# ) +# np.savetxt("test_nncg.dat", np.hstack((X, y_true, y_pred))) diff --git a/examples/pinn_forward/Burgers_NNCG.py b/examples/pinn_forward/Burgers_NNCG.py deleted file mode 100644 index cb0741e9b..000000000 --- a/examples/pinn_forward/Burgers_NNCG.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Backend supported: pytorch""" - -import deepxde as dde -import numpy as np - - -def gen_testdata(): - data = np.load("../dataset/Burgers.npz") - t, x, exact = data["t"], data["x"], data["usol"].T - xx, tt = np.meshgrid(x, t) - X = np.vstack((np.ravel(xx), np.ravel(tt))).T - y = exact.flatten()[:, None] - return X, y - - -def pde(x, y): - dy_x = dde.grad.jacobian(y, x, i=0, j=0) - dy_t = dde.grad.jacobian(y, x, i=0, j=1) - dy_xx = dde.grad.hessian(y, x, i=0, j=0) - return dy_t + y * dy_x - 0.01 / np.pi * dy_xx - - -geom = dde.geometry.Interval(-1, 1) -timedomain = dde.geometry.TimeDomain(0, 0.99) -geomtime = dde.geometry.GeometryXTime(geom, timedomain) - -bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary) -ic = dde.icbc.IC( - geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial -) - -data = dde.data.TimePDE( - geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160 -) -net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal") -model = dde.Model(data, net) - -# Run Adam+L-BFGS -model.compile("adam", lr=1e-3) -model.train(iterations=15000) - -model.compile("L-BFGS") -losshistory, train_state = model.train() -dde.saveplot(losshistory, train_state, issave=True, isplot=True) - -# Get test data -X, y_true = gen_testdata() - -# Get the results after running Adam+L-BFGS -y_pred = model.predict(X) -f = model.predict(X, operator=pde) -print("Mean residual after Adam+L-BFGS:", np.mean(np.absolute(f))) -print( - "L2 relative error after Adam+L-BFGS:", - dde.metrics.l2_relative_error(y_true, y_pred), -) -np.savetxt("test_adam_lbfgs.dat", np.hstack((X, y_true, y_pred))) - -# Run NNCG after Adam+L-BFGS -dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) -model.compile("NNCG") -losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) -dde.saveplot(losshistory_nncg, train_state_nncg, issave=True, isplot=True) - -# Get the final results after running Adam+L-BFGS+NNCG -y_pred = model.predict(X) -f = model.predict(X, operator=pde) -print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f))) -print( - "L2 relative error after Adam+L-BFGS+NNCG:", - dde.metrics.l2_relative_error(y_true, y_pred), -) -np.savetxt("test_adam_lbfgs_nncg.dat", np.hstack((X, y_true, y_pred))) From f2d5c31fcb20561fa7d37861cc560f56b7b872ec Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Thu, 21 Nov 2024 17:21:56 -0800 Subject: [PATCH 24/26] more cleanup in demo --- deepxde/optimizers/config.py | 4 ++-- deepxde/optimizers/pytorch/nncg.py | 2 +- docs/demos/pinn_forward/burgers.rst | 4 ++-- examples/pinn_forward/Burgers.py | 24 ++++++------------------ 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py index 85830f72c..41bb19b86 100644 --- a/deepxde/optimizers/config.py +++ b/deepxde/optimizers/config.py @@ -63,8 +63,8 @@ def set_LBFGS_options( def set_NNCG_options( lr=1, - rank=10, - mu=1e-4, + rank=50, + mu=1e-1, updatefreq=20, chunksz=1, cgtol=1e-16, diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py index 961260986..93704ec9d 100644 --- a/deepxde/optimizers/pytorch/nncg.py +++ b/deepxde/optimizers/pytorch/nncg.py @@ -57,7 +57,7 @@ def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): if torch.norm(resid) > tol: print( "Warning: PCG did not converge to tolerance. " - "Tolerance was {tol} but norm of residual is {torch.norm(resid)}" + f"Tolerance was {tol} but norm of residual is {torch.norm(resid)}" ) return x diff --git a/docs/demos/pinn_forward/burgers.rst b/docs/demos/pinn_forward/burgers.rst index 3f931f866..08e56b38a 100644 --- a/docs/demos/pinn_forward/burgers.rst +++ b/docs/demos/pinn_forward/burgers.rst @@ -95,9 +95,9 @@ However, L-BFGS can stall out early in optimization if it is unable to find a st dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) model.compile("NNCG") - losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) + losshistory, train_state = model.train(iterations=1000, display_every=100) -By default, NNCG does not run in this demo. You will have to uncomment the NNCG code block at the end of the demo to have it run after Adam and L-BFGS. Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer. +By default, NNCG does not run in this demo. You will have to uncomment the NNCG code block in the demo to have it run after Adam and L-BFGS. Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer. Complete code -------------- diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index 3ab31a05f..ecd4449e2 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -1,5 +1,4 @@ """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" - import deepxde as dde import numpy as np @@ -40,6 +39,12 @@ def pde(x, y): model.compile("L-BFGS") losshistory, train_state = model.train() dde.saveplot(losshistory, train_state, issave=True, isplot=True) +"""Backend supported: pytorch""" +# Run NNCG after Adam and L-BFGS +dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) +model.compile("NNCG") +losshistory, train_state = model.train(iterations=1000, display_every=100) +dde.saveplot(losshistory, train_state, issave=True, isplot=True) X, y_true = gen_testdata() y_pred = model.predict(X) @@ -47,20 +52,3 @@ def pde(x, y): print("Mean residual:", np.mean(np.absolute(f))) print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred)) np.savetxt("test.dat", np.hstack((X, y_true, y_pred))) - -# """Backend supported: pytorch""" -# # Run NNCG after Adam and L-BFGS -# dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) -# model.compile("NNCG") -# losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100) -# dde.saveplot(losshistory_nncg, train_state_nncg, issave=True, isplot=True) - -# # Get the final results after running Adam+L-BFGS+NNCG -# y_pred = model.predict(X) -# f = model.predict(X, operator=pde) -# print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f))) -# print( -# "L2 relative error after Adam+L-BFGS+NNCG:", -# dde.metrics.l2_relative_error(y_true, y_pred), -# ) -# np.savetxt("test_nncg.dat", np.hstack((X, y_true, y_pred))) From 74ace2c64422f09913822e18b82040996d30ff91 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Thu, 21 Nov 2024 17:24:54 -0800 Subject: [PATCH 25/26] comment out nncg code block --- examples/pinn_forward/Burgers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index ecd4449e2..893e70789 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -39,12 +39,12 @@ def pde(x, y): model.compile("L-BFGS") losshistory, train_state = model.train() dde.saveplot(losshistory, train_state, issave=True, isplot=True) -"""Backend supported: pytorch""" -# Run NNCG after Adam and L-BFGS -dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) -model.compile("NNCG") -losshistory, train_state = model.train(iterations=1000, display_every=100) -dde.saveplot(losshistory, train_state, issave=True, isplot=True) +# """Backend supported: pytorch""" +# # Run NNCG after Adam and L-BFGS +# dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) +# model.compile("NNCG") +# losshistory, train_state = model.train(iterations=1000, display_every=100) +# dde.saveplot(losshistory, train_state, issave=True, isplot=True) X, y_true = gen_testdata() y_pred = model.predict(X) From c1a6365a1ef76667cb6ae3cbc08084e2fc6e78f7 Mon Sep 17 00:00:00 2001 From: Pratik Rathore Date: Mon, 25 Nov 2024 11:14:27 -0800 Subject: [PATCH 26/26] minor cleanup in burgers example --- examples/pinn_forward/Burgers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pinn_forward/Burgers.py b/examples/pinn_forward/Burgers.py index 893e70789..7b6883ed1 100644 --- a/examples/pinn_forward/Burgers.py +++ b/examples/pinn_forward/Burgers.py @@ -1,4 +1,5 @@ """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" + import deepxde as dde import numpy as np @@ -38,13 +39,12 @@ def pde(x, y): model.train(iterations=15000) model.compile("L-BFGS") losshistory, train_state = model.train() -dde.saveplot(losshistory, train_state, issave=True, isplot=True) # """Backend supported: pytorch""" # # Run NNCG after Adam and L-BFGS # dde.optimizers.set_NNCG_options(rank=50, mu=1e-1) # model.compile("NNCG") # losshistory, train_state = model.train(iterations=1000, display_every=100) -# dde.saveplot(losshistory, train_state, issave=True, isplot=True) +dde.saveplot(losshistory, train_state, issave=True, isplot=True) X, y_true = gen_testdata() y_pred = model.predict(X)