Use approximate log_beta() in .fit(), .predict() #2502

Merged · 4 commits · May 25, 2020
3 changes: 2 additions & 1 deletion examples/sir_hmc.py
@@ -26,7 +26,8 @@
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
from pyro.infer.autoguide import init_to_value
from pyro.ops.tensor_utils import convolve, safe_log
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import convolve
from pyro.util import warn_if_nan

logging.basicConfig(format='%(message)s', level=logging.INFO)
2 changes: 2 additions & 0 deletions pyro/contrib/epidemiology/compartmental.py
@@ -281,6 +281,7 @@ def generate(self, fixed={}):
self._concat_series(samples)
return samples

@set_approx_log_prob_tol(0.1)
def fit(self, **options):
r"""
Runs inference to generate posterior samples.
@@ -397,6 +398,7 @@ def heuristic():
return mcmc # E.g. so user can run mcmc.summary().

@torch.no_grad()
@set_approx_log_prob_tol(0.1)
@set_approx_sample_thresh(10000)
def predict(self, forecast=0):
"""
2 changes: 1 addition & 1 deletion pyro/contrib/epidemiology/distributions.py
@@ -54,7 +54,7 @@ def set_approx_log_prob_tol(tol):
:type tol: int or float.
"""
assert isinstance(tol, (float, int))
assert tol > 0
assert tol >= 0
old1 = dist.Binomial.approx_log_prob_tol
old2 = dist.BetaBinomial.approx_log_prob_tol
try:
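For context, the @set_approx_log_prob_tol(0.1) decorators added to fit() and predict() above and the with set_approx_log_prob_tol(tol): blocks in the tests below both rely on the save/patch/restore pattern visible in this hunk: the class-level approx_log_prob_tol attributes are temporarily overridden and then restored. The snippet below is only a sketch of that pattern, under the assumption that the helper is built on contextlib.contextmanager; the full body in pyro/contrib/epidemiology/distributions.py is not shown in this diff and may differ. Relaxing the assertion to tol >= 0 lets tol=0 request exact log_prob evaluation, which the new parametrized tests exercise.

from contextlib import contextmanager

import pyro.distributions as dist


@contextmanager
def set_approx_log_prob_tol_sketch(tol):
    # Sketch only: temporarily set the approximate log_prob tolerance.
    assert isinstance(tol, (float, int))
    assert tol >= 0  # tol == 0 now means "compute log_prob exactly"
    old1 = dist.Binomial.approx_log_prob_tol
    old2 = dist.BetaBinomial.approx_log_prob_tol
    try:
        dist.Binomial.approx_log_prob_tol = tol
        dist.BetaBinomial.approx_log_prob_tol = tol
        yield
    finally:
        # Restore the previous tolerances even if the wrapped code raises.
        dist.Binomial.approx_log_prob_tol = old1
        dist.BetaBinomial.approx_log_prob_tol = old2

If it is implemented this way, the same helper works both as a decorator on fit()/predict() and as a context manager in tests, since contextmanager-based context managers double as decorators.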
Empty file removed pyro/contrib/epidemiology/seir.py
Empty file.
2 changes: 1 addition & 1 deletion pyro/contrib/epidemiology/util.py
@@ -8,7 +8,7 @@
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log


def clamp(tensor, *, min=None, max=None):
2 changes: 1 addition & 1 deletion pyro/distributions/coalescent.py
@@ -9,7 +9,7 @@
from torch.distributions import constraints

from pyro.distributions.util import broadcast_shape, is_validation_enabled
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log

from .torch_distribution import TorchDistribution

9 changes: 6 additions & 3 deletions pyro/distributions/conjugate.py
@@ -32,11 +32,14 @@ class BetaBinomial(TorchDistribution):
is unknown and randomly drawn from a :class:`~pyro.distributions.Beta` distribution
prior to a certain number of Bernoulli trials given by ``total_count``.

:param float or torch.Tensor concentration1: 1st concentration parameter (alpha) for the
:param concentration1: 1st concentration parameter (alpha) for the
Beta distribution.
:param float or torch.Tensor concentration0: 2nd concentration parameter (beta) for the
:type concentration1: float or torch.Tensor
:param concentration0: 2nd concentration parameter (beta) for the
Beta distribution.
:param int or torch.Tensor total_count: number of Bernoulli trials.
:type concentration0: float or torch.Tensor
:param total_count: Number of Bernoulli trials.
:type total_count: float or torch.Tensor
"""
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive,
'total_count': constraints.nonnegative_integer}
18 changes: 15 additions & 3 deletions pyro/distributions/extended.py
@@ -23,7 +23,7 @@ class ExtendedBinomial(Binomial):

def log_prob(self, value):
result = super().log_prob(value)
invalid = ~super().support.check(value)
invalid = (value < 0) | (value > self.total_count)
Comment (Collaborator): Curious: why this change?

Reply (@fritzo, Member Author, May 25, 2020): The new version is a little cheaper. The old integer_interval version additionally checks value % 1 == 0, but that is already checked during validation by the line above, and it doesn't affect numerical stability.

return result.masked_fill(invalid, -math.inf)


@@ -40,6 +40,18 @@ class ExtendedBetaBinomial(BetaBinomial):
support = constraints.integer

def log_prob(self, value):
result = super().log_prob(value)
invalid = ~super().support.check(value)
if self._validate_args:
self._validate_sample(value)

total_count = self.total_count
invalid = (value < 0) | (value > total_count)
n = total_count.clamp(min=0)
k = value.masked_fill(invalid, 0)
Comment on lines +48 to +49 (@fritzo, Member Author, May 24, 2020): This is the crux of the NaN gradient fix.

try:
self.total_count = n
result = super().log_prob(k)
finally:
self.total_count = total_count

return result.masked_fill(invalid, -math.inf)
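
The fix flagged in the thread above follows a general autograd pattern: sanitize invalid inputs before the expensive computation and only then mask the outputs, because masking outputs after the fact does not stop NaN local derivatives from leaking into the backward pass. Below is a minimal, self-contained PyTorch illustration of that pattern; it is not Pyro code, and sqrt merely stands in for the beta-binomial log_prob.

import math

import torch

x = torch.tensor([4.0, -1.0], requires_grad=True)  # second entry is "invalid"
invalid = x < 0

# Compute-then-mask: sqrt(-1) is NaN and its local derivative is NaN too, so
# even though masked_fill overwrites the output, backprop sees 0 * NaN = NaN.
bad = x.sqrt().masked_fill(invalid, -math.inf)

# Mask-then-compute (analogous to total_count.clamp(min=0) and
# value.masked_fill(invalid, 0) in the diff): substitute a harmless
# placeholder before the computation, then mask the result.
good = x.masked_fill(invalid, 1.0).sqrt().masked_fill(invalid, -math.inf)

print(torch.autograd.grad(bad.sum(), x))   # NaN at the invalid entry
print(torch.autograd.grad(good.sum(), x))  # finite everywhere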
3 changes: 2 additions & 1 deletion pyro/distributions/hmm.py
@@ -10,7 +10,8 @@
from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot,
matrix_and_mvn_to_gamma_gaussian)
from pyro.ops.gaussian import Gaussian, gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian
from pyro.ops.tensor_utils import cholesky, cholesky_solve, safe_log
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import cholesky, cholesky_solve


@torch.jit.script
2 changes: 1 addition & 1 deletion pyro/ops/einsum/torch_log.py
@@ -4,7 +4,7 @@
import torch

from pyro.ops.einsum.util import Tensordot
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log


def transpose(a, axes):
20 changes: 20 additions & 0 deletions pyro/ops/special.py
@@ -8,6 +8,26 @@
import torch


class _SafeLog(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.log()

@staticmethod
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
"""
Like :func:`torch.log` but avoids infinite gradients at log(0)
by clamping them to at most ``1 / finfo.eps``.
"""
return _SafeLog.apply(x)


def log_beta(x, y, tol=0.):
"""
Computes log Beta function.
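As a quick usage note, mirroring the new test_safe_log in tests/ops/test_special.py further down in this diff: safe_log agrees with torch.log in value, but its gradient at zero is clamped to at most 1 / finfo.eps instead of being infinite.

import torch

from pyro.ops.special import safe_log  # new home after this PR

x = torch.tensor(0., requires_grad=True)
torch.autograd.grad(x.log(), [x])[0]      # inf, which poisons downstream gradients
torch.autograd.grad(safe_log(x), [x])[0]  # 1 / torch.finfo(x.dtype).eps, finite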
20 changes: 0 additions & 20 deletions pyro/ops/tensor_utils.py
@@ -9,26 +9,6 @@
_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)


class _SafeLog(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.log()

@staticmethod
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
"""
Like :func:`torch.log` but avoids infinite gradients at log(0)
by clamping them to at most ``1 / finfo.eps``.
"""
return _SafeLog.apply(x)


def block_diag_embed(mat):
"""
Takes a tensor of shape (..., B, M, N) and returns a block diagonal tensor
149 changes: 84 additions & 65 deletions tests/distributions/test_extended.py
@@ -5,73 +5,92 @@

import pytest
import torch
from torch.autograd import grad

import pyro.distributions as dist
from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol
from tests.common import assert_equal


def test_extended_binomial():
total_count = torch.tensor([1., 2., 10.])
probs = torch.tensor([0.5, 0.4, 0.2])

d1 = dist.Binomial(total_count, probs)
d2 = dist.ExtendedBinomial(total_count, probs)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBinomial(total_count, 0.5)
assert (d.log_prob(data) == -math.inf).all()


def test_extended_beta_binomial():
concentration1 = torch.tensor([1.0, 2.0, 1.0])
concentration0 = torch.tensor([0.5, 1.0, 2.0])
total_count = torch.tensor([1., 2., 10.])

d1 = dist.BetaBinomial(concentration1, concentration0, total_count)
d2 = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBetaBinomial(1.5, 1.5, total_count)
assert (d.log_prob(data) == -math.inf).all()
def check_grad(value, *params):
grads = grad(value.sum(), params, create_graph=True)
assert all(torch.isfinite(g).all() for g in grads)


@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1])
def test_extended_binomial(tol):
with set_approx_log_prob_tol(tol):
total_count = torch.tensor([0., 1., 2., 10.])
probs = torch.tensor([0.5, 0.5, 0.4, 0.2]).requires_grad_()

d1 = dist.Binomial(total_count, probs)
d2 = dist.ExtendedBinomial(total_count, probs)
# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()
check_grad(log_prob, probs)

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
probs = torch.tensor(0.5).requires_grad_()
d = dist.ExtendedBinomial(total_count, probs)
log_prob = d.log_prob(data)
assert (log_prob == -math.inf).all()
check_grad(log_prob, probs)


@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1])
def test_extended_beta_binomial(tol):
with set_approx_log_prob_tol(tol):
concentration1 = torch.tensor([0.2, 1.0, 2.0, 1.0]).requires_grad_()
concentration0 = torch.tensor([0.2, 0.5, 1.0, 2.0]).requires_grad_()
total_count = torch.tensor([0., 1., 2., 10.])

d1 = dist.BetaBinomial(concentration1, concentration0, total_count)
d2 = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()
check_grad(log_prob, concentration1, concentration0)

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
concentration1 = torch.tensor(1.5).requires_grad_()
concentration0 = torch.tensor(1.5).requires_grad_()
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)
log_prob = d.log_prob(data)
assert (log_prob == -math.inf).all()
check_grad(log_prob, concentration1, concentration0)
18 changes: 17 additions & 1 deletion tests/ops/test_special.py
@@ -3,8 +3,24 @@

import pytest
import torch
from torch.autograd import grad

from pyro.ops.special import log_beta, log_binomial
from pyro.ops.special import log_beta, log_binomial, safe_log
from tests.common import assert_equal


def test_safe_log():
# Test values.
x = torch.randn(1000).exp().requires_grad_()
expected = x.log()
actual = safe_log(x)
assert_equal(actual, expected)
assert_equal(grad(actual.sum(), [x])[0], grad(expected.sum(), [x])[0])

# Test gradients.
x = torch.tensor(0., requires_grad=True)
assert not torch.isfinite(grad(x.log(), [x])[0])
assert torch.isfinite(grad(safe_log(x), [x])[0])


@pytest.mark.parametrize("tol", [
17 changes: 1 addition & 16 deletions tests/ops/test_tensor_utils.py
@@ -7,31 +7,16 @@
import pytest
import scipy.fftpack as fftpack
import torch
from torch.autograd import grad

import pyro
from pyro.ops.tensor_utils import (block_diag_embed, block_diagonal, convolve, dct, idct, next_fast_len,
periodic_cumsum, periodic_features, periodic_repeat, precision_to_scale_tril,
repeated_matmul, safe_log)
repeated_matmul)
from tests.common import assert_close, assert_equal

pytestmark = pytest.mark.stage('unit')


def test_safe_log():
# Test values.
x = torch.randn(1000).exp().requires_grad_()
expected = x.log()
actual = safe_log(x)
assert_equal(actual, expected)
assert_equal(grad(actual.sum(), [x])[0], grad(expected.sum(), [x])[0])

# Test gradients.
x = torch.tensor(0., requires_grad=True)
assert not torch.isfinite(grad(x.log(), [x])[0])
assert torch.isfinite(grad(safe_log(x), [x])[0])


@pytest.mark.parametrize('batch_size', [1, 2, 3])
@pytest.mark.parametrize('block_size', [torch.Size([2, 2]), torch.Size([3, 1]), torch.Size([4, 2])])
def test_block_diag_embed(batch_size, block_size):