diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py
index ed10bb76a4..f02fe2fde3 100644
--- a/botorch/models/fully_bayesian.py
+++ b/botorch/models/fully_bayesian.py
@@ -32,7 +32,6 @@
 
 import math
-import warnings
 from abc import abstractmethod
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
 
@@ -56,7 +55,7 @@
 from gpytorch.means.constant_mean import ConstantMean
 from gpytorch.means.mean import Mean
 from gpytorch.models.exact_gp import ExactGP
-from linear_operator import settings
+from pyro.ops.integrator import register_exception_handler
 from torch import Tensor
 
 MIN_INFERRED_NOISE_LEVEL = 1e-6
@@ -64,6 +63,20 @@
 _sqrt5 = math.sqrt(5)
 
 
+def _handle_torch_linalg(exception: Exception) -> bool:
+    return type(exception) == torch.linalg.LinAlgError
+
+
+def _handle_valerr_in_dist_init(exception: Exception) -> bool:
+    if not type(exception) == ValueError:
+        return False
+    return "satisfy the constraint PositiveDefinite()" in str(exception)
+
+
+register_exception_handler("torch_linalg", _handle_torch_linalg)
+register_exception_handler("valerr_in_dist_init", _handle_valerr_in_dist_init)
+
+
 def matern52_kernel(X: Tensor, lengthscale: Tensor) -> Tensor:
     """Matern-5/2 kernel."""
     dist = compute_dists(X=X, lengthscale=lengthscale)
@@ -82,51 +95,6 @@ def reshape_and_detach(target: Tensor, new_value: Tensor) -> None:
     return new_value.detach().clone().view(target.shape).to(target)
 
 
-def _psd_safe_pyro_mvn_sample(
-    name: str, loc: Tensor, covariance_matrix: Tensor, obs: Tensor
-) -> None:
-    r"""Wraps the `pyro.sample` call in a loop to add an increasing series of jitter
-    to the covariance matrix each time we get a LinAlgError.
-
-    This is modelled after linear_operator's `psd_safe_cholesky`.
-    """
-    jitter = settings.cholesky_jitter.value(loc.dtype)
-    max_tries = settings.cholesky_max_tries.value()
-    for i in range(max_tries + 1):
-        jitter_matrix = (
-            torch.eye(
-                covariance_matrix.shape[-1],
-                device=covariance_matrix.device,
-                dtype=covariance_matrix.dtype,
-            )
-            * jitter
-        )
-        jittered_covar = (
-            covariance_matrix if i == 0 else covariance_matrix + jitter_matrix
-        )
-        try:
-            pyro.sample(
-                name,
-                pyro.distributions.MultivariateNormal(
-                    loc=loc,
-                    covariance_matrix=jittered_covar,
-                ),
-                obs=obs,
-            )
-            return
-        except (torch.linalg.LinAlgError, ValueError) as e:
-            if isinstance(e, ValueError) and "satisfy the constraint" not in str(e):
-                # Not-PSD can be also caught in Distribution.__init__ during parameter
-                # validation, which raises a ValueError. Only catch those errors.
-                raise e
-            jitter = jitter * 10
-            warnings.warn(
-                "Received a linear algebra error while sampling with Pyro. Adding a "
-                f"jitter of {jitter} to the covariance matrix and retrying.",
-                RuntimeWarning,
-            )
-
-
 class PyroModel:
     r"""
     Base class for a Pyro model; used to assist in learning hyperparameters.
@@ -208,12 +176,14 @@ def sample(self) -> None:
         mean = self.sample_mean(**tkwargs)
         noise = self.sample_noise(**tkwargs)
         lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
-        k = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
-        k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
-        _psd_safe_pyro_mvn_sample(
-            name="Y",
-            loc=mean.view(-1).expand(self.train_X.shape[0]),
-            covariance_matrix=k,
+        K = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
+        K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
+        pyro.sample(
+            "Y",
+            pyro.distributions.MultivariateNormal(
+                loc=mean.view(-1).expand(self.train_X.shape[0]),
+                covariance_matrix=K,
+            ),
             obs=self.train_Y.squeeze(-1),
         )
 
diff --git a/botorch/models/fully_bayesian_multitask.py b/botorch/models/fully_bayesian_multitask.py
index 0c0de0d89e..e84a2be344 100644
--- a/botorch/models/fully_bayesian_multitask.py
+++ b/botorch/models/fully_bayesian_multitask.py
@@ -14,7 +14,6 @@
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.fully_bayesian import (
-    _psd_safe_pyro_mvn_sample,
     matern52_kernel,
     MIN_INFERRED_NOISE_LEVEL,
     PyroModel,
@@ -94,7 +93,7 @@ def sample(self) -> None:
         noise = self.sample_noise(**tkwargs)
         lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
 
-        k = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
+        K = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
 
         # compute task covar matrix
         task_latent_features = self.sample_latent_features(**tkwargs)[task_indices]
@@ -102,12 +101,14 @@
         task_covar = matern52_kernel(
             X=task_latent_features, lengthscale=task_lengthscale
         )
-        k = k.mul(task_covar)
-        k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
-        _psd_safe_pyro_mvn_sample(
-            name="Y",
-            loc=mean.view(-1).expand(self.train_X.shape[0]),
-            covariance_matrix=k,
+        K = K.mul(task_covar)
+        K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
+        pyro.sample(
+            "Y",
+            pyro.distributions.MultivariateNormal(
+                loc=mean.view(-1).expand(self.train_X.shape[0]),
+                covariance_matrix=K,
+            ),
             obs=self.train_Y.squeeze(-1),
         )
 
diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py
index d32b842a16..3c1ace0656 100644
--- a/test/models/test_fully_bayesian.py
+++ b/test/models/test_fully_bayesian.py
@@ -6,7 +6,6 @@
 
 import itertools
-import warnings
 from unittest import mock
 
 import pyro
 
@@ -35,7 +34,6 @@
 from botorch.models import ModelList, ModelListGP
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.fully_bayesian import (
-    _psd_safe_pyro_mvn_sample,
     MCMC_DIM,
     MIN_INFERRED_NOISE_LEVEL,
     PyroModel,
@@ -55,6 +53,7 @@
 from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
 from gpytorch.means import ConstantMean
 from linear_operator.operators import to_linear_operator
+from pyro.ops.integrator import potential_grad, register_exception_handler
 
 
 EXPECTED_KEYS = [
@@ -665,61 +664,55 @@ def f(x):
             )
         )
 
-    def test_psd_safe_pyro_mvn_sample(self):
-        def mock_init(
-            batch_shape=torch.Size(),  # noqa
-            event_shape=torch.Size(),  # noqa
-            validate_args=None,
-        ):
-            self._batch_shape = batch_shape
-            self._event_shape = event_shape
-            self._validate_args = False
-
-        for dtype in (torch.float, torch.double):
-            tkwargs = {"dtype": dtype, "device": self.device}
-            loc = torch.rand(5, **tkwargs)
-            obs = torch.rand(5, **tkwargs)
-            psd_covar = torch.eye(5, **tkwargs)
-            not_psd_covar = torch.ones(5, 5, **tkwargs)
-            with warnings.catch_warnings(record=True) as ws:
-                warnings.simplefilter("always")
-                _psd_safe_pyro_mvn_sample(
-                    name="Y", loc=loc, covariance_matrix=psd_covar, obs=obs
-                )
-            self.assertFalse(any("linear algebra error" in str(w.message) for w in ws))
-            # With a PSD covar, it should only get called once.
-            # Raised as a ValueError:
-            with warnings.catch_warnings(record=True) as ws:
-                warnings.simplefilter("always")
-                _psd_safe_pyro_mvn_sample(
-                    name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
-                )
-            self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
-            # Raised as a LinAlgError:
-            with mock.patch(
-                "torch.distributions.multivariate_normal.Distribution.__init__",
-                wraps=mock_init,
-            ), mock.patch(
-                "pyro.distributions.MultivariateNormal",
-                wraps=pyro.distributions.MultivariateNormal,
-            ) as mock_mvn, warnings.catch_warnings(
-                record=True
-            ) as ws:
-                warnings.simplefilter("always")
-                _psd_safe_pyro_mvn_sample(
-                    name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
-                )
-                # Check that it added the jitter.
-                self.assertGreaterEqual(
-                    mock_mvn.call_args[-1]["covariance_matrix"][0, 0].item(), 1 + 1e-8
-                )
-            # With a not-PSD covar, it should get called multiple times.
-            self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
-            # We don't catch random Value errors.
-            with mock.patch(
-                "torch.distributions.multivariate_normal.Distribution.__init__",
-                side_effect=ValueError("dummy error"),
-            ), self.assertRaisesRegex(ValueError, "dummy"):
-                _psd_safe_pyro_mvn_sample(
-                    name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
-                )
+
+class TestPyroCatchNumericalErrors(BotorchTestCase):
+    def test_pyro_catch_error(self):
+        def potential_fn(z):
+            mvn = pyro.distributions.MultivariateNormal(
+                loc=torch.zeros(2),
+                covariance_matrix=z["K"],
+            )
+            return mvn.log_prob(torch.zeros(2))
+
+        # Test base case where everything is fine
+        z = {"K": torch.eye(2)}
+        grads, val = potential_grad(potential_fn, z)
+        self.assertTrue(torch.allclose(grads["K"], -0.5 * torch.eye(2)))
+        norm_mvn = torch.distributions.Normal(0, 1)
+        self.assertTrue(torch.allclose(val, 2 * norm_mvn.log_prob(torch.zeros(1))))
+
+        # Default behavior should catch the ValueError when trying to instantiate
+        # the MVN and return NaN instead
+        z = {"K": torch.ones(2, 2)}
+        _, val = potential_grad(potential_fn, z)
+        self.assertTrue(torch.isnan(val))
+
+        # Default behavior should catch the LinAlgError when performing a
+        # Cholesky decomposition and return NaN instead
+        def potential_fn_chol(z):
+            return torch.linalg.cholesky(z["K"])
+
+        _, val = potential_grad(potential_fn_chol, z)
+        self.assertTrue(torch.isnan(val))
+
+        # Default behavior should not catch other errors
+        def potential_fn_rterr_foo(z):
+            raise RuntimeError("foo")
+
+        with self.assertRaisesRegex(RuntimeError, "foo"):
+            potential_grad(potential_fn_rterr_foo, z)
+
+        # But once we register a handler for this specific error, it should
+        def catch_runtime_error(e):
+            return type(e) == RuntimeError and "foo" in str(e)
+
+        register_exception_handler("foo_runtime", catch_runtime_error)
+        _, val = potential_grad(potential_fn_rterr_foo, z)
+        self.assertTrue(torch.isnan(val))
+
+        # Unless the error message is different
+        def potential_fn_rterr_bar(z):
+            raise RuntimeError("bar")
+
+        with self.assertRaisesRegex(RuntimeError, "bar"):
+            potential_grad(potential_fn_rterr_bar, z)
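
Note on the mechanism: instead of retrying `pyro.sample` with growing jitter, the model code now registers exception handlers that Pyro consults inside `pyro.ops.integrator.potential_grad`. The sketch below is a simplified paraphrase of that behavior, not Pyro's actual source; in particular, the gradients returned on a handled failure may differ from what Pyro does internally.

```python
from typing import Callable, Dict

import torch

_HANDLERS: Dict[str, Callable[[Exception], bool]] = {}


def register_exception_handler(
    name: str, handler: Callable[[Exception], bool]
) -> None:
    _HANDLERS[name] = handler


def potential_grad(potential_fn, z):
    for node in z.values():
        node.requires_grad_(True)
    try:
        energy = potential_fn(z)
    except Exception as e:
        # If any registered handler recognizes the exception, report a NaN
        # potential energy instead of raising. NUTS rejects such a proposal
        # and keeps sampling, rather than aborting the entire MCMC run.
        if any(handler(e) for handler in _HANDLERS.values()):
            grads = {key: torch.zeros_like(value) for key, value in z.items()}
            return grads, torch.tensor(float("nan"))
        raise
    grad_list = torch.autograd.grad(energy, list(z.values()))
    return dict(zip(z.keys(), grad_list)), energy.detach()
```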
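The user-visible effect is that a numerically unlucky covariance during fully Bayesian fitting now surfaces as a rejected NUTS step rather than a crash. A minimal sketch of an end-to-end fit, using arbitrary toy data and BoTorch's existing `fit_fully_bayesian_model_nuts` entry point:

```python
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Arbitrary toy data.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))

# Importing botorch.models.fully_bayesian registers the handlers above, so a
# non-PSD covariance encountered inside NUTS becomes a rejected step instead
# of a LinAlgError/ValueError that kills the fit.
gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    gp, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)
```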