Add jitter to Cholesky factorization in Gaussian ops #3151

Merged (11 commits) on Oct 30, 2022
10 changes: 5 additions & 5 deletions pyro/distributions/hmm.py
@@ -20,7 +20,7 @@
 )
 from pyro.ops.indexing import Vindex
 from pyro.ops.special import safe_log
-from pyro.ops.tensor_utils import cholesky, cholesky_solve
+from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky

 from . import constraints
 from .torch import Categorical, Gamma, Independent, MultivariateNormal
@@ -628,9 +628,9 @@ def filter(self, value):

         # Convert to a distribution
         precision = logp.precision
-        loc = cholesky_solve(logp.info_vec.unsqueeze(-1), cholesky(precision)).squeeze(
-            -1
-        )
+        loc = cholesky_solve(
+            logp.info_vec.unsqueeze(-1), safe_cholesky(precision)
+        ).squeeze(-1)
         return MultivariateNormal(
             loc, precision_matrix=precision, validate_args=self._validate_args
         )
@@ -928,7 +928,7 @@ def filter(self, value):
             gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args
         )
         # Conditional of last state on unit scale
-        scale_tril = cholesky(logp.precision)
+        scale_tril = safe_cholesky(logp.precision)
         loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1)
         mvn = MultivariateNormal(
             loc, scale_tril=scale_tril, validate_args=self._validate_args
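Both hunks above recover a moment-form MultivariateNormal from a Gaussian stored in information form (precision matrix plus info_vec): the location solves precision @ loc = info_vec against a Cholesky factor rather than an explicit inverse. A minimal standalone sketch of that identity in plain torch; the toy numbers below are made up purely for illustration:

```python
import torch

# Information-form parameters of a toy 2-d Gaussian (illustrative values only).
precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
info_vec = torch.tensor([1.0, -0.3])

# loc = precision^{-1} @ info_vec via a Cholesky solve, as in the hunks above.
L = torch.linalg.cholesky(precision)
loc = torch.cholesky_solve(info_vec.unsqueeze(-1), L).squeeze(-1)

assert torch.allclose(loc, torch.linalg.solve(precision, info_vec))
```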
4 changes: 2 additions & 2 deletions pyro/distributions/transforms/cholesky.py
@@ -89,7 +89,7 @@ def log_abs_det_jacobian(self, x, y):

 class CholeskyTransform(Transform):
     r"""
-    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+    Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
     positive definite matrix.
     """
     bijective = True
@@ -116,7 +116,7 @@ def log_abs_det_jacobian(self, x, y):

 class CorrMatrixCholeskyTransform(CholeskyTransform):
     r"""
-    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
+    Transform via the mapping :math:`y = safe_cholesky(x)`, where `x` is a
     correlation matrix.
     """
     bijective = True
10 changes: 5 additions & 5 deletions pyro/ops/gaussian.py
@@ -9,7 +9,7 @@
 from torch.nn.functional import pad

 from pyro.distributions.util import broadcast_shape
-from pyro.ops.tensor_utils import cholesky, matmul, matvecmul, triangular_solve
+from pyro.ops.tensor_utils import matmul, matvecmul, safe_cholesky, triangular_solve


 class Gaussian:
@@ -154,7 +154,7 @@ def rsample(
         """
         Reparameterized sampler.
         """
-        P_chol = cholesky(self.precision)
+        P_chol = safe_cholesky(self.precision)
         loc = self.info_vec.unsqueeze(-1).cholesky_solve(P_chol).squeeze(-1)
         shape = sample_shape + self.batch_shape + (self.dim(), 1)
         if noise is None:
@@ -254,7 +254,7 @@ def marginalize(self, left=0, right=0) -> "Gaussian":
         P_aa = self.precision[..., a, a]
         P_ba = self.precision[..., b, a]
         P_bb = self.precision[..., b, b]
-        P_b = cholesky(P_bb)
+        P_b = safe_cholesky(P_bb)
         P_a = triangular_solve(P_ba, P_b, upper=False)
         P_at = P_a.transpose(-1, -2)
         precision = P_aa - matmul(P_at, P_a)
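The marginalize hunk relies on the identity that integrating block `b` out of a Gaussian with joint precision P leaves block `a` with the Schur complement P_aa - P_ab P_bb^{-1} P_ba, computed above via a triangular solve against the Cholesky factor of P_bb. A quick numerical check of that identity in plain torch; the size and seed are arbitrary and chosen only for illustration:

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
P = A @ A.t() + 5 * torch.eye(5, dtype=torch.float64)  # positive definite joint precision
a, b = slice(0, 2), slice(2, 5)

# Schur complement via a triangular solve, mirroring the diff above.
P_b = torch.linalg.cholesky(P[b, b])
P_a = torch.linalg.solve_triangular(P_b, P[b, a], upper=False)  # = P_b^{-1} P_ba
schur = P[a, a] - P_a.t() @ P_a

# Direct computation of P_aa - P_ab P_bb^{-1} P_ba for comparison.
expected = P[a, a] - P[a, b] @ torch.linalg.solve(P[b, b], P[b, a])
assert torch.allclose(schur, expected)
```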
@@ -277,7 +277,7 @@ def event_logsumexp(self) -> torch.Tensor:
         Integrates out all latent state (i.e. operating on event dimensions).
         """
         n = self.dim()
-        chol_P = cholesky(self.precision)
+        chol_P = safe_cholesky(self.precision)
         chol_P_u = triangular_solve(
             self.info_vec.unsqueeze(-1), chol_P, upper=False
         ).squeeze(-1)
@@ -550,7 +550,7 @@ def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gaussian:
     b = xb + yb

     # Pbb + Qbb needs to be positive definite, so that we can marginalize out `b` (to have a finite integral)
-    L = cholesky(Pbb + Qbb)
+    L = safe_cholesky(Pbb + Qbb)
     LinvB = triangular_solve(B, L, upper=False)
     LinvBt = LinvB.transpose(-2, -1)
     Linvb = triangular_solve(b.unsqueeze(-1), L, upper=False)
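gaussian_tensordot contracts the trailing dims of one Gaussian against the leading dims of another and then marginalizes the shared block, which is where the safe_cholesky(Pbb + Qbb) call above is exercised. A rough usage sketch, assuming only helpers that already appear elsewhere in this diff (mvn_to_gaussian, gaussian_tensordot); the standard-normal inputs are illustrative only:

```python
import torch
from pyro.distributions import MultivariateNormal
from pyro.ops.gaussian import gaussian_tensordot, mvn_to_gaussian

# Two 2-dimensional Gaussians; contracting one dimension marginalizes the
# shared variable and leaves a Gaussian over the remaining 1 + 1 dimensions.
x = mvn_to_gaussian(MultivariateNormal(torch.zeros(2), torch.eye(2)))
y = mvn_to_gaussian(MultivariateNormal(torch.ones(2), torch.eye(2)))
z = gaussian_tensordot(x, y, dims=1)
print(z.dim())  # 2
```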
13 changes: 12 additions & 1 deletion pyro/ops/tensor_utils.py
@@ -7,6 +7,7 @@
 from torch.fft import irfft, rfft

 _ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)
+CHOLESKY_RELATIVE_JITTER = 4.0  # in units of finfo.eps


 def as_complex(x):
@@ -393,9 +394,19 @@ def inverse_haar_transform(x):
     return x


-def cholesky(x):
+def safe_cholesky(x):
     if x.size(-1) == 1:
+        if CHOLESKY_RELATIVE_JITTER:
+            x = x.clamp(min=torch.finfo(x.dtype).tiny)
         return x.sqrt()
+
+    if CHOLESKY_RELATIVE_JITTER:
+        # Add adaptive jitter.
+        x = x.clone()
+        x_max = x.data.abs().max(-1).values
+        jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max
+        x.data.diagonal(dim1=-1, dim2=-2).add_(jitter)
+
     return torch.linalg.cholesky(x)


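The effect of the new helper can be seen on an exactly singular (rank-deficient) matrix: a plain Cholesky factorization raises, while safe_cholesky's small diagonal bump, a few machine epsilons scaled by the largest absolute entry in each row, lets the factorization go through. A minimal sketch; the matrix is made up, and the exact error type may vary across PyTorch versions, hence the broad except:

```python
import torch

from pyro.ops.tensor_utils import safe_cholesky

v = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)
x = v.t() @ v  # rank-1, so only positive semi-definite

try:
    torch.linalg.cholesky(x)
except Exception as e:  # torch raises a LinAlgError/RuntimeError here
    print("plain Cholesky failed:", type(e).__name__)

L = safe_cholesky(x)  # the jittered factorization goes through here
print(torch.allclose(L @ L.t(), x, atol=1e-6))  # True, up to the tiny diagonal bump
```

Since the new code only applies the jitter under `if CHOLESKY_RELATIVE_JITTER:`, setting that module-level constant to 0 disables the jitter and restores the previous behaviour.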
55 changes: 54 additions & 1 deletion tests/ops/test_gaussian.py
@@ -15,6 +15,7 @@
     AffineNormal,
     Gaussian,
     gaussian_tensordot,
+    matrix_and_gaussian_to_gaussian,
     matrix_and_mvn_to_gaussian,
     mvn_to_gaussian,
     sequential_gaussian_filter_sample,
@@ -378,7 +379,7 @@ def test_gaussian_tensordot(
     nc = y_dim - dot_dims
     try:
         torch.linalg.cholesky(x.precision[..., na:, na:] + y.precision[..., :nb, :nb])
-    except RuntimeError:
+    except Exception:
         pytest.skip("Cannot marginalize the common variables of two Gaussians.")

     z = gaussian_tensordot(x, y, dot_dims)
@@ -557,3 +558,55 @@ def test_sequential_gaussian_filter_sample_antithetic(
     )
     expected = torch.stack([sample, mean, 2 * mean - sample])
     assert torch.allclose(sample3, expected)
+
+
+@pytest.mark.filterwarnings("ignore:Singular matrix in cholesky")
+@pytest.mark.parametrize("num_steps", [10, 100, 1000, 10000, 100000, 1000000])
+def test_sequential_gaussian_filter_sample_stability(num_steps):
+    # This tests long-chain filtering at low precision.
+    zero = torch.zeros((), dtype=torch.float)
+    eye = torch.eye(4, dtype=torch.float)
+    noise = torch.randn(num_steps, 4, dtype=torch.float, requires_grad=True)
+    trans_matrix = torch.tensor(
+        [
+            [
+                0.8571434617042542,
+                -0.23285813629627228,
+                0.05360094830393791,
+                -0.017088839784264565,
+            ],
+            [
+                0.7609677314758301,
+                0.6596274971961975,
+                -0.022656921297311783,
+                0.05166701227426529,
+            ],
+            [
+                3.0979342460632324,
+                5.446939945220947,
+                -0.3425334692001343,
+                0.01096670888364315,
+            ],
+            [
+                -1.8180007934570312,
+                -0.4965082108974457,
+                -0.006048532668501139,
+                -0.08525419235229492,
+            ],
+        ],
+        dtype=torch.float,
+        requires_grad=True,
+    )
+
+    init = Gaussian(zero, zero.expand(4), eye)
+    trans = matrix_and_gaussian_to_gaussian(
+        trans_matrix, Gaussian(zero, zero.expand(4), eye)
+    ).expand((num_steps - 1,))
+
+    # Check numerically stabilized value.
+    x = sequential_gaussian_filter_sample(init, trans, (), noise)
+    assert torch.isfinite(x).all()
+
+    # Check gradients.
+    grads = torch.autograd.grad(x.sum(), [trans_matrix, noise])
+    assert all(torch.isfinite(g).all() for g in grads)