Skip to content

Commit

Permalink
Register torch.linalg.LinAlgError to pyro exception handling (#1607)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1607

Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4 pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and the `ValueError` that can be raised in the torch distribution's `__init__()`

Reviewed By: saitcakmak

Differential Revision: D42159791

fbshipit-source-id: 3bbe2433b83bd114edd277e42f0017010ac9199f
  • Loading branch information
Balandat authored and facebook-github-bot committed Jan 4, 2023
1 parent 9cd4dea commit cbbbf11
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 120 deletions.
76 changes: 23 additions & 53 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@


import math
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Tuple

Expand All @@ -56,14 +55,28 @@
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.means.mean import Mean
from gpytorch.models.exact_gp import ExactGP
from linear_operator import settings
from pyro.ops.integrator import register_exception_handler
from torch import Tensor

MIN_INFERRED_NOISE_LEVEL = 1e-6

_sqrt5 = math.sqrt(5)


def _handle_torch_linalg(exception: Exception) -> bool:
    """Report whether ``exception`` is a torch linear-algebra error.

    This predicate is registered with Pyro's ``register_exception_handler``
    under the name ``"torch_linalg"``.

    Args:
        exception: The exception to classify.

    Returns:
        ``True`` if ``exception`` is a ``torch.linalg.LinAlgError`` (or a
        subclass of it), ``False`` otherwise.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # LinAlgError are recognized as well.
    return isinstance(exception, torch.linalg.LinAlgError)


def _handle_valerr_in_dist_init(exception: Exception) -> bool:
    """Report whether ``exception`` is a positive-definiteness ``ValueError``.

    This predicate is registered with Pyro's ``register_exception_handler``
    under the name ``"valerr_in_dist_init"``. Only ``ValueError``s whose
    message mentions the ``PositiveDefinite()`` constraint are matched, so
    unrelated ``ValueError``s are not matched.

    Args:
        exception: The exception to classify.

    Returns:
        ``True`` if ``exception`` is a ``ValueError`` (or a subclass of it)
        whose message contains
        ``"satisfy the constraint PositiveDefinite()"``, ``False`` otherwise.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # ValueError carrying the same message are recognized as well.
    if not isinstance(exception, ValueError):
        return False
    return "satisfy the constraint PositiveDefinite()" in str(exception)


register_exception_handler("torch_linalg", _handle_torch_linalg)
register_exception_handler("valerr_in_dist_init", _handle_valerr_in_dist_init)


def matern52_kernel(X: Tensor, lengthscale: Tensor) -> Tensor:
"""Matern-5/2 kernel."""
dist = compute_dists(X=X, lengthscale=lengthscale)
Expand All @@ -82,51 +95,6 @@ def reshape_and_detach(target: Tensor, new_value: Tensor) -> None:
return new_value.detach().clone().view(target.shape).to(target)


def _psd_safe_pyro_mvn_sample(
    name: str, loc: Tensor, covariance_matrix: Tensor, obs: Tensor
) -> None:
    r"""Wraps the `pyro.sample` call in a loop to add an increasing series of jitter
    to the covariance matrix each time we get a LinAlgError.

    This is modelled after linear_operator's `psd_safe_cholesky`.

    The first attempt uses `covariance_matrix` unchanged. Every failed attempt
    multiplies the jitter by 10 before the next one. At most
    `settings.cholesky_max_tries.value() + 1` attempts are made.

    Args:
        name: Name of the Pyro sample site.
        loc: Mean of the multivariate normal; its dtype selects the base jitter.
        covariance_matrix: Covariance matrix of the multivariate normal.
        obs: Observed values passed to `pyro.sample`.

    Raises:
        torch.linalg.LinAlgError: If sampling still fails on the final attempt.
        ValueError: If a `ValueError` unrelated to a constraint violation is
            raised, or if a constraint-violation `ValueError` is still raised
            on the final attempt.
    """
    jitter = settings.cholesky_jitter.value(loc.dtype)
    max_tries = settings.cholesky_max_tries.value()
    # The identity matrix does not change between attempts; only its scale does.
    identity = torch.eye(
        covariance_matrix.shape[-1],
        device=covariance_matrix.device,
        dtype=covariance_matrix.dtype,
    )
    for i in range(max_tries + 1):
        jittered_covar = (
            covariance_matrix if i == 0 else covariance_matrix + identity * jitter
        )
        try:
            pyro.sample(
                name,
                pyro.distributions.MultivariateNormal(
                    loc=loc,
                    covariance_matrix=jittered_covar,
                ),
                obs=obs,
            )
            return
        except (torch.linalg.LinAlgError, ValueError) as e:
            if isinstance(e, ValueError) and "satisfy the constraint" not in str(e):
                # Not-PSD can be also caught in Distribution.__init__ during parameter
                # validation, which raises a ValueError. Only catch those errors.
                raise
            if i == max_tries:
                # No attempts left: surface the failure instead of silently
                # returning without ever having sampled the site.
                raise
            jitter = jitter * 10
            warnings.warn(
                "Received a linear algebra error while sampling with Pyro. Adding a "
                f"jitter of {jitter} to the covariance matrix and retrying.",
                RuntimeWarning,
            )


class PyroModel:
r"""
Base class for a Pyro model; used to assist in learning hyperparameters.
Expand Down Expand Up @@ -208,12 +176,14 @@ def sample(self) -> None:
mean = self.sample_mean(**tkwargs)
noise = self.sample_noise(**tkwargs)
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
k = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
_psd_safe_pyro_mvn_sample(
name="Y",
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=k,
K = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
)

Expand Down
17 changes: 9 additions & 8 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import (
_psd_safe_pyro_mvn_sample,
matern52_kernel,
MIN_INFERRED_NOISE_LEVEL,
PyroModel,
Expand Down Expand Up @@ -94,20 +93,22 @@ def sample(self) -> None:
noise = self.sample_noise(**tkwargs)

lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
k = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
K = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)

# compute task covar matrix
task_latent_features = self.sample_latent_features(**tkwargs)[task_indices]
task_lengthscale = self.sample_task_lengthscale(**tkwargs)
task_covar = matern52_kernel(
X=task_latent_features, lengthscale=task_lengthscale
)
k = k.mul(task_covar)
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
_psd_safe_pyro_mvn_sample(
name="Y",
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=k,
K = K.mul(task_covar)
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
)

Expand Down
111 changes: 52 additions & 59 deletions test/models/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


import itertools
import warnings
from unittest import mock

import pyro
Expand Down Expand Up @@ -35,7 +34,6 @@
from botorch.models import ModelList, ModelListGP
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.fully_bayesian import (
_psd_safe_pyro_mvn_sample,
MCMC_DIM,
MIN_INFERRED_NOISE_LEVEL,
PyroModel,
Expand All @@ -55,6 +53,7 @@
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
from gpytorch.means import ConstantMean
from linear_operator.operators import to_linear_operator
from pyro.ops.integrator import potential_grad, register_exception_handler


EXPECTED_KEYS = [
Expand Down Expand Up @@ -665,61 +664,55 @@ def f(x):
)
)

def test_psd_safe_pyro_mvn_sample(self):
def mock_init(
batch_shape=torch.Size(), # noqa
event_shape=torch.Size(), # noqa
validate_args=None,
):
self._batch_shape = batch_shape
self._event_shape = event_shape
self._validate_args = False

for dtype in (torch.float, torch.double):
tkwargs = {"dtype": dtype, "device": self.device}
loc = torch.rand(5, **tkwargs)
obs = torch.rand(5, **tkwargs)
psd_covar = torch.eye(5, **tkwargs)
not_psd_covar = torch.ones(5, 5, **tkwargs)
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
_psd_safe_pyro_mvn_sample(
name="Y", loc=loc, covariance_matrix=psd_covar, obs=obs
)
self.assertFalse(any("linear algebra error" in str(w.message) for w in ws))
# With a PSD covar, it should only get called once.
# Raised as a ValueError:
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
_psd_safe_pyro_mvn_sample(
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
)
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
# Raised as a LinAlgError:
with mock.patch(
"torch.distributions.multivariate_normal.Distribution.__init__",
wraps=mock_init,
), mock.patch(
"pyro.distributions.MultivariateNormal",
wraps=pyro.distributions.MultivariateNormal,
) as mock_mvn, warnings.catch_warnings(
record=True
) as ws:
warnings.simplefilter("always")
_psd_safe_pyro_mvn_sample(
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
)
# Check that it added the jitter.
self.assertGreaterEqual(
mock_mvn.call_args[-1]["covariance_matrix"][0, 0].item(), 1 + 1e-8

class TestPyroCatchNumericalErrors(BotorchTestCase):
def test_pyro_catch_error(self):
def potential_fn(z):
mvn = pyro.distributions.MultivariateNormal(
loc=torch.zeros(2),
covariance_matrix=z["K"],
)
# With a not-PSD covar, it should get called multiple times.
self.assertTrue(any("linear algebra error" in str(w.message) for w in ws))
# We don't catch random Value errors.
with mock.patch(
"torch.distributions.multivariate_normal.Distribution.__init__",
side_effect=ValueError("dummy error"),
), self.assertRaisesRegex(ValueError, "dummy"):
_psd_safe_pyro_mvn_sample(
name="Y", loc=loc, covariance_matrix=not_psd_covar, obs=obs
)
return mvn.log_prob(torch.zeros(2))

# Test base case where everything is fine
z = {"K": torch.eye(2)}
grads, val = potential_grad(potential_fn, z)
self.assertTrue(torch.allclose(grads["K"], -0.5 * torch.eye(2)))
norm_mvn = torch.distributions.Normal(0, 1)
self.assertTrue(torch.allclose(val, 2 * norm_mvn.log_prob(torch.zeros(1))))

# Default behavior should catch the ValueError when trying to instantiate
# the MVN and return NaN instead
z = {"K": torch.ones(2, 2)}
_, val = potential_grad(potential_fn, z)
self.assertTrue(torch.isnan(val))

# Default behavior should catch the LinAlgError when performing a
# Cholesky decomposition and return NaN instead
def potential_fn_chol(z):
return torch.linalg.cholesky(z["K"])

_, val = potential_grad(potential_fn_chol, z)
self.assertTrue(torch.isnan(val))

# Default behavior should not catch other errors
def potential_fn_rterr_foo(z):
raise RuntimeError("foo")

with self.assertRaisesRegex(RuntimeError, "foo"):
potential_grad(potential_fn_rterr_foo, z)

# But once we register this specific error then it should
def catch_runtime_error(e):
return type(e) == RuntimeError and "foo" in str(e)

register_exception_handler("foo_runtime", catch_runtime_error)
_, val = potential_grad(potential_fn_rterr_foo, z)
self.assertTrue(torch.isnan(val))

# Unless the error message is different
def potential_fn_rterr_bar(z):
raise RuntimeError("bar")

with self.assertRaisesRegex(RuntimeError, "bar"):
potential_grad(potential_fn_rterr_bar, z)

0 comments on commit cbbbf11

Please sign in to comment.