Use approximate log_beta() in .fit(), .predict() (#2502)
* Fix ExtendedBetaBinomial gradient issue

* Use approximate log_beta in CompartmentalModel

* Revive change that had been lost in merge conflict

* Fix another merge conflict error
fritzo authored May 25, 2020
1 parent c00ff3b commit aa40beb
Showing 15 changed files with 153 additions and 114 deletions.
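Editor's note: this commit routes an approximate evaluation of the log-Beta function into CompartmentalModel.fit() and .predict(). For orientation, the exact quantity being approximated is log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y); pyro.ops.special.log_beta(x, y, tol=0.) (see the pyro/ops/special.py hunk below) computes it exactly when tol is 0 and presumably switches to a cheaper approximation when a positive tolerance such as the 0.1 used on .fit() and .predict() is set. A minimal sketch of the exact identity only:

import torch

def exact_log_beta(x, y):
    # log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
    return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

x = torch.tensor([2., 5.])
y = torch.tensor([3., 7.])
print(exact_log_beta(x, y))  # should agree with pyro.ops.special.log_beta(x, y, tol=0.)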
examples/sir_hmc.py (3 changes: 2 additions & 1 deletion)
@@ -26,7 +26,8 @@
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
from pyro.infer.autoguide import init_to_value
from pyro.ops.tensor_utils import convolve, safe_log
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import convolve
from pyro.util import warn_if_nan

logging.basicConfig(format='%(message)s', level=logging.INFO)
pyro/contrib/epidemiology/compartmental.py (2 changes: 2 additions & 0 deletions)
@@ -281,6 +281,7 @@ def generate(self, fixed={}):
self._concat_series(samples)
return samples

@set_approx_log_prob_tol(0.1)
def fit(self, **options):
r"""
Runs inference to generate posterior samples.
@@ -397,6 +398,7 @@ def heuristic():
return mcmc # E.g. so user can run mcmc.summary().

@torch.no_grad()
@set_approx_log_prob_tol(0.1)
@set_approx_sample_thresh(10000)
def predict(self, forecast=0):
"""
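Editor's note: the tolerance helper is applied above as a decorator on .fit() and .predict(), and is used as a context manager in the updated tests at the bottom of this commit. A hedged usage sketch, taking only those two call styles from the diff as given:

import torch
import pyro.distributions as dist
from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol

# Context-manager style, as in tests/distributions/test_extended.py below.
with set_approx_log_prob_tol(0.1):
    d = dist.BetaBinomial(torch.tensor(2.), torch.tensor(1.), total_count=torch.tensor(10.))
    lp = d.log_prob(torch.tensor(3.))

# Decorator style, as on CompartmentalModel.fit() and .predict() above.
@set_approx_log_prob_tol(0.1)
def evaluate():
    return dist.Binomial(torch.tensor(10.), torch.tensor(0.3)).log_prob(torch.tensor(2.))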
pyro/contrib/epidemiology/distributions.py (2 changes: 1 addition & 1 deletion)
@@ -54,7 +54,7 @@ def set_approx_log_prob_tol(tol):
:type tol: int or float.
"""
assert isinstance(tol, (float, int))
assert tol > 0
assert tol >= 0
old1 = dist.Binomial.approx_log_prob_tol
old2 = dist.BetaBinomial.approx_log_prob_tol
try:
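Editor's note: the body of set_approx_log_prob_tol is truncated in the hunk above. A plausible completion, sketched under the assumption that the helper is a contextlib-based context manager that patches the class-level tolerance and restores it on exit (the old1/old2 bookkeeping and the try: are visible; the rest is inferred):

from contextlib import contextmanager
import pyro.distributions as dist

@contextmanager
def set_approx_log_prob_tol(tol):
    # Sketch only: temporarily loosen the approximate log_prob tolerance.
    assert isinstance(tol, (float, int))
    assert tol >= 0
    old1 = dist.Binomial.approx_log_prob_tol
    old2 = dist.BetaBinomial.approx_log_prob_tol
    try:
        dist.Binomial.approx_log_prob_tol = tol
        dist.BetaBinomial.approx_log_prob_tol = tol
        yield
    finally:
        dist.Binomial.approx_log_prob_tol = old1
        dist.BetaBinomial.approx_log_prob_tol = old2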
Empty file removed: pyro/contrib/epidemiology/seir.py
pyro/contrib/epidemiology/util.py (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log


def clamp(tensor, *, min=None, max=None):
pyro/distributions/coalescent.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@
from torch.distributions import constraints

from pyro.distributions.util import broadcast_shape, is_validation_enabled
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log

from .torch_distribution import TorchDistribution

pyro/distributions/conjugate.py (9 changes: 6 additions & 3 deletions)
@@ -32,11 +32,14 @@ class BetaBinomial(TorchDistribution):
is unknown and randomly drawn from a :class:`~pyro.distributions.Beta` distribution
prior to a certain number of Bernoulli trials given by ``total_count``.
:param float or torch.Tensor concentration1: 1st concentration parameter (alpha) for the
:param concentration1: 1st concentration parameter (alpha) for the
Beta distribution.
:param float or torch.Tensor concentration0: 2nd concentration parameter (beta) for the
:type concentration1: float or torch.Tensor
:param concentration0: 2nd concentration parameter (beta) for the
Beta distribution.
:param int or torch.Tensor total_count: number of Bernoulli trials.
:type concentration0: float or torch.Tensor
:param total_count: Number of Bernoulli trials.
:type total_count: float or torch.Tensor
"""
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive,
'total_count': constraints.nonnegative_integer}
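Editor's note: the conjugate.py change only reshapes the docstring into separate :param:/:type: fields. For reference, a minimal use of the documented parameters (values are illustrative and mirror the updated tests below):

import torch
import pyro.distributions as dist

d = dist.BetaBinomial(concentration1=torch.tensor([1., 2.]),   # alpha of the Beta prior
                      concentration0=torch.tensor([.5, 1.]),   # beta of the Beta prior
                      total_count=torch.tensor([1., 2.]))      # number of Bernoulli trials
x = d.sample()
print(d.log_prob(x))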
pyro/distributions/extended.py (18 changes: 15 additions & 3 deletions)
@@ -23,7 +23,7 @@ class ExtendedBinomial(Binomial):

def log_prob(self, value):
result = super().log_prob(value)
invalid = ~super().support.check(value)
invalid = (value < 0) | (value > self.total_count)
return result.masked_fill(invalid, -math.inf)


@@ -40,6 +40,18 @@ class ExtendedBetaBinomial(BetaBinomial):
support = constraints.integer

def log_prob(self, value):
result = super().log_prob(value)
invalid = ~super().support.check(value)
if self._validate_args:
self._validate_sample(value)

total_count = self.total_count
invalid = (value < 0) | (value > total_count)
n = total_count.clamp(min=0)
k = value.masked_fill(invalid, 0)

try:
self.total_count = n
result = super().log_prob(k)
finally:
self.total_count = total_count

return result.masked_fill(invalid, -math.inf)
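Editor's note: the rewritten ExtendedBetaBinomial.log_prob evaluates the base BetaBinomial only on sanitized inputs (total_count clamped to be nonnegative, out-of-support values replaced by 0) and fills -inf into the invalid positions afterwards. Per the "Fix ExtendedBetaBinomial gradient issue" bullet and the check_grad assertions in the tests, the point is presumably that evaluating the base log_prob directly at invalid values produces non-finite intermediates whose local derivatives poison the backward pass even though the outputs are masked (0 * inf = nan). A generic PyTorch illustration of that failure mode, not the distribution's actual internals:

import math
import torch

x = torch.tensor([0., 2.], requires_grad=True)

# Masking only the output does not help: the masked entry still contributes
# grad_out (0) * local derivative (inf at x=0) = nan to the gradient.
bad = x.log().masked_fill(x <= 0, -math.inf)
bad.sum().backward()
print(x.grad)  # tensor([nan, 0.5000])

# Evaluating on a clamped value first, then masking, keeps gradients finite.
x.grad = None
safe = x.clamp(min=1.).log().masked_fill(x <= 0, -math.inf)
safe.sum().backward()
print(x.grad)  # tensor([0.0000, 0.5000])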
pyro/distributions/hmm.py (3 changes: 2 additions & 1 deletion)
@@ -10,7 +10,8 @@
from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot,
matrix_and_mvn_to_gamma_gaussian)
from pyro.ops.gaussian import Gaussian, gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian
from pyro.ops.tensor_utils import cholesky, cholesky_solve, safe_log
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import cholesky, cholesky_solve


@torch.jit.script
pyro/ops/einsum/torch_log.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
import torch

from pyro.ops.einsum.util import Tensordot
from pyro.ops.tensor_utils import safe_log
from pyro.ops.special import safe_log


def transpose(a, axes):
pyro/ops/special.py (20 changes: 20 additions & 0 deletions)
@@ -8,6 +8,26 @@
import torch


class _SafeLog(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.log()

@staticmethod
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
"""
Like :func:`torch.log` but avoids infinite gradients at log(0)
by clamping them to at most ``1 / finfo.eps``.
"""
return _SafeLog.apply(x)


def log_beta(x, y, tol=0.):
"""
Computes log Beta function.
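Editor's note: most of the remaining hunks are the mechanical import move, with safe_log (and its callers) now living in pyro.ops.special rather than pyro.ops.tensor_utils. A short sketch of the updated surface; the meaning of a positive tol is taken only from this commit's title and defaults, not from the (truncated) log_beta body:

import torch
from pyro.ops.special import log_beta, safe_log

x = torch.tensor([0., 1., 2.], requires_grad=True)
y = safe_log(x)   # same forward values as x.log(), but a finite (clamped) gradient at 0

a, b = torch.tensor(2.), torch.tensor(3.)
exact = log_beta(a, b)            # tol defaults to 0.: exact log Beta
loose = log_beta(a, b, tol=0.1)   # positive tol: the approximate evaluation this commit enables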
pyro/ops/tensor_utils.py (20 changes: 0 additions & 20 deletions)
@@ -9,26 +9,6 @@
_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)


class _SafeLog(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.log()

@staticmethod
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
"""
Like :func:`torch.log` but avoids infinite gradients at log(0)
by clamping them to at most ``1 / finfo.eps``.
"""
return _SafeLog.apply(x)


def block_diag_embed(mat):
"""
Takes a tensor of shape (..., B, M, N) and returns a block diagonal tensor
tests/distributions/test_extended.py (149 changes: 84 additions & 65 deletions)
@@ -5,73 +5,92 @@

import pytest
import torch
from torch.autograd import grad

import pyro.distributions as dist
from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol
from tests.common import assert_equal


def test_extended_binomial():
total_count = torch.tensor([1., 2., 10.])
probs = torch.tensor([0.5, 0.4, 0.2])

d1 = dist.Binomial(total_count, probs)
d2 = dist.ExtendedBinomial(total_count, probs)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBinomial(total_count, 0.5)
assert (d.log_prob(data) == -math.inf).all()


def test_extended_beta_binomial():
concentration1 = torch.tensor([1.0, 2.0, 1.0])
concentration0 = torch.tensor([0.5, 1.0, 2.0])
total_count = torch.tensor([1., 2., 10.])

d1 = dist.BetaBinomial(concentration1, concentration0, total_count)
d2 = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBetaBinomial(1.5, 1.5, total_count)
assert (d.log_prob(data) == -math.inf).all()
def check_grad(value, *params):
grads = grad(value.sum(), params, create_graph=True)
assert all(torch.isfinite(g).all() for g in grads)


@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1])
def test_extended_binomial(tol):
with set_approx_log_prob_tol(tol):
total_count = torch.tensor([0., 1., 2., 10.])
probs = torch.tensor([0.5, 0.5, 0.4, 0.2]).requires_grad_()

d1 = dist.Binomial(total_count, probs)
d2 = dist.ExtendedBinomial(total_count, probs)
# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()
check_grad(log_prob, probs)

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
probs = torch.tensor(0.5).requires_grad_()
d = dist.ExtendedBinomial(total_count, probs)
log_prob = d.log_prob(data)
assert (log_prob == -math.inf).all()
check_grad(log_prob, probs)


@pytest.mark.parametrize("tol", [0., 0.02, 0.05, 0.1])
def test_extended_beta_binomial(tol):
with set_approx_log_prob_tol(tol):
concentration1 = torch.tensor([0.2, 1.0, 2.0, 1.0]).requires_grad_()
concentration0 = torch.tensor([0.2, 0.5, 1.0, 2.0]).requires_grad_()
total_count = torch.tensor([0., 1., 2., 10.])

d1 = dist.BetaBinomial(concentration1, concentration0, total_count)
d2 = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)

# Check on good data.
data = d1.sample((100,))
assert_equal(d1.log_prob(data), d2.log_prob(data))

# Check on extended data.
data = torch.arange(-10., 20.).unsqueeze(-1)
with pytest.raises(ValueError):
d1.log_prob(data)
log_prob = d2.log_prob(data)
valid = d1.support.check(data)
assert ((log_prob > -math.inf) == valid).all()
check_grad(log_prob, concentration1, concentration0)

# Check on shape error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor([0., 0.]))

# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
concentration1 = torch.tensor(1.5).requires_grad_()
concentration0 = torch.tensor(1.5).requires_grad_()
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)
log_prob = d.log_prob(data)
assert (log_prob == -math.inf).all()
check_grad(log_prob, concentration1, concentration0)
tests/ops/test_special.py (18 changes: 17 additions & 1 deletion)
@@ -3,8 +3,24 @@

import pytest
import torch
from torch.autograd import grad

from pyro.ops.special import log_beta, log_binomial
from pyro.ops.special import log_beta, log_binomial, safe_log
from tests.common import assert_equal


def test_safe_log():
# Test values.
x = torch.randn(1000).exp().requires_grad_()
expected = x.log()
actual = safe_log(x)
assert_equal(actual, expected)
assert_equal(grad(actual.sum(), [x])[0], grad(expected.sum(), [x])[0])

# Test gradients.
x = torch.tensor(0., requires_grad=True)
assert not torch.isfinite(grad(x.log(), [x])[0])
assert torch.isfinite(grad(safe_log(x), [x])[0])


@pytest.mark.parametrize("tol", [
tests/ops/test_tensor_utils.py (17 changes: 1 addition & 16 deletions)
@@ -7,31 +7,16 @@
import pytest
import scipy.fftpack as fftpack
import torch
from torch.autograd import grad

import pyro
from pyro.ops.tensor_utils import (block_diag_embed, block_diagonal, convolve, dct, idct, next_fast_len,
periodic_cumsum, periodic_features, periodic_repeat, precision_to_scale_tril,
repeated_matmul, safe_log)
repeated_matmul)
from tests.common import assert_close, assert_equal

pytestmark = pytest.mark.stage('unit')


def test_safe_log():
# Test values.
x = torch.randn(1000).exp().requires_grad_()
expected = x.log()
actual = safe_log(x)
assert_equal(actual, expected)
assert_equal(grad(actual.sum(), [x])[0], grad(expected.sum(), [x])[0])

# Test gradients.
x = torch.tensor(0., requires_grad=True)
assert not torch.isfinite(grad(x.log(), [x])[0])
assert torch.isfinite(grad(safe_log(x), [x])[0])


@pytest.mark.parametrize('batch_size', [1, 2, 3])
@pytest.mark.parametrize('block_size', [torch.Size([2, 2]), torch.Size([3, 1]), torch.Size([4, 2])])
def test_block_diag_embed(batch_size, block_size):
