Skip to content

Commit

Permalink
improve speed and numerical stability of scale_tril to precision (#2264)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and fritzo committed Jan 17, 2020
1 parent 64407af commit afcd19c
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 10 deletions.
21 changes: 15 additions & 6 deletions pyro/distributions/multivariate_studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,29 @@ def __init__(self, df, loc, scale_tril, validate_args=None):
batch_shape = broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2])
event_shape = (dim,)
self.df = df.expand(batch_shape)
self.loc = loc
self.scale_tril = scale_tril
self.loc = loc.expand(batch_shape + event_shape)
self._unbroadcasted_scale_tril = scale_tril
self._chi2 = Chi2(self.df)
super(MultivariateStudentT, self).__init__(batch_shape, event_shape, validate_args=validate_args)

@lazy_property
def scale_tril(self):
return self._unbroadcasted_scale_tril.expand(
self._batch_shape + self._event_shape + self._event_shape)

@lazy_property
def covariance_matrix(self):
# NB: this is not covariance of this distribution;
# the actual covariance is df / (df - 2) * covariance_matrix
return torch.matmul(self.scale_tril, self.scale_tril.transpose(-1, -2))
return (torch.matmul(self._unbroadcasted_scale_tril,
self._unbroadcasted_scale_tril.transpose(-1, -2))
.expand(self._batch_shape + self._event_shape + self._event_shape))

@lazy_property
def precision_matrix(self):
identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype)
scale_inv = identity.triangular_solve(self.scale_tril, upper=False).solution.transpose(-1, -2)
return torch.matmul(scale_inv.transpose(-1, -2), scale_inv)
return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
self._batch_shape + self._event_shape + self._event_shape)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(MultivariateStudentT, _instance)
Expand All @@ -60,7 +67,9 @@ def expand(self, batch_shape, _instance=None):
scale_shape = loc_shape + self.event_shape
new.df = self.df.expand(batch_shape)
new.loc = self.loc.expand(loc_shape)
new.scale_tril = self.scale_tril.expand(scale_shape)
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
if 'scale_tril' in self.__dict__:
new.scale_tril = self.scale_tril.expand(scale_shape)
if 'covariance_matrix' in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(scale_shape)
if 'precision_matrix' in self.__dict__:
Expand Down
8 changes: 8 additions & 0 deletions pyro/distributions/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.constraints import IndependentConstraint
from pyro.distributions.torch_distribution import TorchDistributionMixin
Expand Down Expand Up @@ -40,6 +41,13 @@ def enumerate_support(self, expand=True):
class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin):
support = IndependentConstraint(constraints.real, 1) # TODO move upstream

# TODO: remove this in the PyTorch release > 1.4.0
@lazy_property
def precision_matrix(self):
identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype)
return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
self._batch_shape + self._event_shape + self._event_shape)


class Independent(torch.distributions.Independent, TorchDistributionMixin):
@constraints.dependent_property
Expand Down
10 changes: 6 additions & 4 deletions pyro/ops/gamma_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def _precision_to_scale_tril(P):
class Gamma:
"""
Non-normalized Gamma distribution.
Gamma(concentration, rate) ~ (concentration - 1) * log(s) - rate * s
"""
def __init__(self, log_normalizer, concentration, rate):
self.log_normalizer = log_normalizer
Expand Down Expand Up @@ -280,15 +282,15 @@ def compound(self):
Integrates out the latent multiplier `s`. The result will be a
Student-T distribution.
"""
alpha = self.alpha - 0.5 * self.dim() + 1
concentration = self.alpha - 0.5 * self.dim() + 1
scale_tril = _precision_to_scale_tril(self.precision)
scale_tril_t_u = scale_tril.transpose(-1, -2).matmul(self.info_vec.unsqueeze(-1)).squeeze(-1)
u_Pinv_u = scale_tril_t_u.pow(2).sum(-1)
beta = self.beta - 0.5 * u_Pinv_u
rate = self.beta - 0.5 * u_Pinv_u

loc = scale_tril.matmul(scale_tril_t_u.unsqueeze(-1)).squeeze(-1)
scale_tril = scale_tril * (beta / alpha).sqrt().unsqueeze(-1).unsqueeze(-1)
return MultivariateStudentT(2 * alpha, loc, scale_tril)
scale_tril = scale_tril * (rate / concentration).sqrt().unsqueeze(-1).unsqueeze(-1)
return MultivariateStudentT(2 * concentration, loc, scale_tril)

def event_logsumexp(self):
"""
Expand Down
41 changes: 41 additions & 0 deletions tests/distributions/test_mvn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.distributions import MultivariateNormal
from tests.common import assert_equal


def random_mvn(loc_shape, cov_shape, dim):
"""
Generate a random MultivariateNormal distribution for testing.
"""
rank = dim + dim
loc = torch.randn(loc_shape + (dim,), requires_grad=True)
cov = torch.randn(cov_shape + (dim, rank), requires_grad=True)
cov = cov.matmul(cov.transpose(-1, -2))
return MultivariateNormal(loc, cov)


@pytest.mark.parametrize('loc_shape', [
(), (2,), (3, 2),
])
@pytest.mark.parametrize('cov_shape', [
(), (2,), (3, 2),
])
@pytest.mark.parametrize('dim', [
1, 3, 5,
])
def test_shape(loc_shape, cov_shape, dim):
mvn = random_mvn(loc_shape, cov_shape, dim)
assert mvn.loc.shape == mvn.batch_shape + mvn.event_shape
assert mvn.covariance_matrix.shape == mvn.batch_shape + mvn.event_shape * 2
assert mvn.scale_tril.shape == mvn.covariance_matrix.shape
assert mvn.precision_matrix.shape == mvn.covariance_matrix.shape

assert_equal(mvn.precision_matrix, mvn.covariance_matrix.inverse())

# smoke test for precision/log_prob backward
(mvn.precision_matrix.sum() + mvn.log_prob(torch.zeros(dim)).sum()).backward()
39 changes: 39 additions & 0 deletions tests/distributions/test_mvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,45 @@
from tests.common import assert_equal


def random_mvt(df_shape, loc_shape, cov_shape, dim):
"""
Generate a random MultivariateStudentT distribution for testing.
"""
rank = dim + dim
df = torch.rand(df_shape, requires_grad=True).exp()
loc = torch.randn(loc_shape + (dim,), requires_grad=True)
cov = torch.randn(cov_shape + (dim, rank), requires_grad=True)
cov = cov.matmul(cov.transpose(-1, -2))
scale_tril = cov.cholesky()
return MultivariateStudentT(df, loc, scale_tril)


@pytest.mark.parametrize('df_shape', [
(), (2,), (3, 2),
])
@pytest.mark.parametrize('loc_shape', [
(), (2,), (3, 2),
])
@pytest.mark.parametrize('cov_shape', [
(), (2,), (3, 2),
])
@pytest.mark.parametrize('dim', [
1, 3, 5,
])
def test_shape(df_shape, loc_shape, cov_shape, dim):
mvt = random_mvt(df_shape, loc_shape, cov_shape, dim)
assert mvt.df.shape == mvt.batch_shape
assert mvt.loc.shape == mvt.batch_shape + mvt.event_shape
assert mvt.covariance_matrix.shape == mvt.batch_shape + mvt.event_shape * 2
assert mvt.scale_tril.shape == mvt.covariance_matrix.shape
assert mvt.precision_matrix.shape == mvt.covariance_matrix.shape

assert_equal(mvt.precision_matrix, mvt.covariance_matrix.inverse())

# smoke test for precision/log_prob backward
(mvt.precision_matrix.sum() + mvt.log_prob(torch.zeros(dim)).sum()).backward()


@pytest.mark.parametrize("batch_shape", [
(),
(3, 2),
Expand Down

0 comments on commit afcd19c

Please sign in to comment.