diff --git a/.travis.yml b/.travis.yml index 9883d7cde3..df8943a823 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,11 +8,8 @@ env: install: - pip install -U pip - - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then - pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl; - else - pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl; - fi + - pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torchvision --no-dependencies - pip install .[test] - pip freeze diff --git a/Makefile b/Makefile index c0e2dc9692..8117eb3fe1 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,12 @@ test-cuda: lint FORCE CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda +test-jit: FORCE + @echo See jit.log + pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log + pytest -v -n auto --tb=short --runxfail tests/infer/mcmc/test_hmc.py tests/infer/mcmc/test_nuts.py \ + -k JIT=True | tee -a jit.log + clean: FORCE git clean -dfx -e pyro-egg.info diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 8bcf8f2aa7..b38bd7f633 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -56,14 +56,6 @@ AVFMultivariateNormal :undoc-members: :show-inheritance: -Binomial --------- - -.. autoclass:: pyro.distributions.Binomial - :members: - :undoc-members: - :show-inheritance: - Delta ----- .. autoclass:: pyro.distributions.Delta @@ -85,13 +77,6 @@ GaussianScaleMixture :undoc-members: :show-inheritance: -HalfCauchy ----------- -.. autoclass:: pyro.distributions.HalfCauchy - :members: - :undoc-members: - :show-inheritance: - LowRankMultivariateNormal ------------------------- .. autoclass:: pyro.distributions.LowRankMultivariateNormal diff --git a/docs/source/primitives.rst b/docs/source/primitives.rst index fd47624400..bc7100304f 100644 --- a/docs/source/primitives.rst +++ b/docs/source/primitives.rst @@ -20,4 +20,4 @@ Primitives .. autofunction:: pyro.validation_enabled .. autofunction:: pyro.enable_validation -.. autofunction:: pyro.ops.jit.compile +.. 
autofunction:: pyro.ops.jit.trace diff --git a/examples/baseball.py b/examples/baseball.py index 4aee43956e..d3aeb7ac7a 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -218,7 +218,7 @@ def main(args): baseball_dataset = pd.read_csv(DATA_URL, "\t") train, _, player_names = train_test_split(baseball_dataset) at_bats, hits = train[:, 0], train[:, 1] - nuts_kernel = NUTS(conditioned_model, adapt_step_size=True) + nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit) logging.info("Original Dataset:") logging.info(baseball_dataset) @@ -270,5 +270,7 @@ def main(args): parser.add_argument("-n", "--num-samples", nargs="?", default=1200, type=int) parser.add_argument("--warmup-steps", nargs='?', default=300, type=int) parser.add_argument("--rng_seed", nargs='?', default=0, type=int) + parser.add_argument('--jit', action='store_true', default=False, + help='use PyTorch jit') args = parser.parse_args() main(args) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index b8a577e14b..fe29f96e6d 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -34,7 +34,7 @@ def conditioned_model(model, sigma, y): def main(args): - nuts_kernel = NUTS(conditioned_model, adapt_step_size=True) + nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit) posterior = MCMC(nuts_kernel, num_samples=args.num_samples, warmup_steps=args.warmup_steps)\ .run(model, data.sigma, data.y) marginal_mu_tau = EmpiricalMarginal(posterior, sites=["mu", "tau"])\ @@ -54,6 +54,7 @@ def main(args): help='number of MCMC samples (default: 1000)') parser.add_argument('--warmup-steps', type=int, default=1000, help='number of MCMC samples for warmup (default: 1000)') + parser.add_argument('--jit', action='store_true', default=False) args = parser.parse_args() main(args) diff --git a/pyro/contrib/gp/models/gplvm.py b/pyro/contrib/gp/models/gplvm.py index 51260d99e8..ac68cb162f 100644 --- a/pyro/contrib/gp/models/gplvm.py +++ b/pyro/contrib/gp/models/gplvm.py @@ -1,14 +1,14 @@ from __future__ import absolute_import, division, print_function -import torch from torch.distributions import constraints from torch.nn import Parameter import pyro -from pyro.contrib.gp.util import Parameterized import pyro.distributions as dist import pyro.infer as infer import pyro.optim as optim +from pyro.contrib.gp.util import Parameterized +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -74,7 +74,7 @@ def __init__(self, base_model, name="GPLVM"): C = self.X_loc.shape[1] X_scale_tril_shape = self.X_loc.shape + (C,) - Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) + Id = eye_like(self.X_loc, C) X_scale_tril = Id.expand(X_scale_tril_shape) self.X_scale_tril = Parameter(X_scale_tril) self.set_constraint("X_scale_tril", constraints.lower_cholesky) @@ -87,7 +87,7 @@ def model(self): # sample X from unit multivariate normal distribution zero_loc = self.X_loc.new_zeros(self.X_loc.shape) C = self.X_loc.shape[1] - Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) + Id = eye_like(self.X_loc, C) X_name = param_with_module_name(self.name, "X") X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim()-1)) diff --git a/pyro/contrib/gp/models/vgp.py b/pyro/contrib/gp/models/vgp.py index 2acfed9ef9..4a8f8b0a23 100644 --- a/pyro/contrib/gp/models/vgp.py +++ b/pyro/contrib/gp/models/vgp.py @@ -8,6 +8,7 @@ import pyro.distributions as dist from pyro.contrib.gp.models.model 
import GPModel from pyro.contrib.gp.util import conditional +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -74,7 +75,7 @@ def __init__(self, X, y, kernel, likelihood, mean_function=None, self.f_loc = Parameter(f_loc) f_scale_tril_shape = self.latent_shape + (N, N) - Id = torch.eye(N, out=self.X.new_empty(N, N)) + Id = eye_like(self.X, N) f_scale_tril = Id.expand(f_scale_tril_shape) self.f_scale_tril = Parameter(f_scale_tril) self.set_constraint("f_scale_tril", constraints.lower_cholesky) @@ -96,7 +97,7 @@ def model(self): f_name = param_with_module_name(self.name, "f") if self.whiten: - Id = torch.eye(N, out=self.X.new_empty(N, N)) + Id = eye_like(self.X, N) pyro.sample(f_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) diff --git a/pyro/contrib/gp/models/vsgp.py b/pyro/contrib/gp/models/vsgp.py index f0c67e084c..2d6586bad6 100644 --- a/pyro/contrib/gp/models/vsgp.py +++ b/pyro/contrib/gp/models/vsgp.py @@ -9,6 +9,7 @@ import pyro.poutine as poutine from pyro.contrib.gp.models.model import GPModel from pyro.contrib.gp.util import conditional +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -98,7 +99,7 @@ def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None, self.u_loc = Parameter(u_loc) u_scale_tril_shape = self.latent_shape + (M, M) - Id = torch.eye(M, out=self.Xu.new_empty(M, M)) + Id = eye_like(self.Xu, M) u_scale_tril = Id.expand(u_scale_tril_shape) self.u_scale_tril = Parameter(u_scale_tril) self.set_constraint("u_scale_tril", constraints.lower_cholesky) @@ -120,7 +121,7 @@ def model(self): zero_loc = Xu.new_zeros(u_loc.shape) u_name = param_with_module_name(self.name, "u") if self.whiten: - Id = torch.eye(M, out=Xu.new_empty(M, M)) + Id = eye_like(Xu, M) pyro.sample(u_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py index bc624c4ccc..43f0d80f1c 100644 --- a/pyro/contrib/oed/eig.py +++ b/pyro/contrib/oed/eig.py @@ -126,7 +126,7 @@ def naive_rainforth_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M_prime, N) retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - conditional_lp = logsumexp(sum(retrace.nodes[l]["log_prob"] for l in observation_labels), 0) \ + conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - np.log(M_prime) else: # This assumes that y are independent conditional on theta @@ -141,7 +141,7 @@ def naive_rainforth_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M, 1) retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - marginal_lp = logsumexp(sum(retrace.nodes[l]["log_prob"] for l in observation_labels), 0) \ + marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - np.log(M) return (conditional_lp - marginal_lp).sum(0)/N @@ -334,30 +334,6 @@ def loss_fn(design, num_particles): return loss_fn -def logsumexp(inputs, dim=None, keepdim=False): - """Numerically stable logsumexp. - - Args: - inputs: A Variable with any shape. - dim: An integer. - keepdim: A boolean. - - Returns: - Equivalent of `log(sum(exp(inputs), dim=dim, keepdim=keepdim))`. 
- """ - # For a 1-D array x (any array along a single dimension), - # log sum exp(x) = s + log sum exp(x - s) - # with s = max(x) being a common choice. - if dim is None: - inputs = inputs.view(-1) - dim = 0 - s, _ = torch.max(inputs, dim=dim, keepdim=True) - outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() - if not keepdim: - outputs = outputs.squeeze(dim) - return outputs - - class EwmaLog(torch.autograd.Function): """Logarithm function with exponentially weighted moving average for gradients. diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 280f434de7..2d30709ba7 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,16 +2,13 @@ import pyro.distributions.torch_patch # noqa F403 from pyro.distributions.avf_mvn import AVFMultivariateNormal -from pyro.distributions.binomial import Binomial from pyro.distributions.delta import Delta from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals from pyro.distributions.distribution import Distribution from pyro.distributions.empirical import Empirical from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture -from pyro.distributions.half_cauchy import HalfCauchy from pyro.distributions.iaf import InverseAutoregressiveFlow -from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal from pyro.distributions.rejector import Rejector @@ -30,14 +27,11 @@ "is_validation_enabled", "validation_enabled", "AVFMultivariateNormal", - "Binomial", "Delta", "Distribution", "Empirical", "GaussianScaleMixture", - "HalfCauchy", "InverseAutoregressiveFlow", - "LowRankMultivariateNormal", "MaskedMixture", "MixtureOfDiagNormalsSharedCovariance", "MixtureOfDiagNormals", diff --git a/pyro/distributions/binomial.py b/pyro/distributions/binomial.py deleted file mode 100644 index 13d164a3e5..0000000000 --- a/pyro/distributions/binomial.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import absolute_import, division, print_function - -from numbers import Number - -import torch -from torch.distributions import constraints -from torch.distributions.utils import broadcast_all, lazy_property, logits_to_probs, probs_to_logits - -from pyro.distributions.torch_distribution import TorchDistributionMixin - - -class Binomial(torch.distributions.Distribution, TorchDistributionMixin): - r""" - Creates a Binomial distribution parameterized by `total_count` and - either `probs` or `logits` (but not both). `total_count` must be - broadcastable with `probs`/`logits`. - - This is adapted from :class:`torch.distributions.binomial.Binomial`, - with the important difference that `total_count` is not limited to - being a single `int`, but can be a `torch.Tensor`. 
- - Example:: - - >>> m = Binomial(100, torch.Tensor([0 , .2, .8, 1])) - >>> m.sample() # doctest: +SKIP - 0 - 22 - 71 - 100 - [torch.FloatTensor of size 4]] - - >>> m = Binomial(torch.Tensor([[5.], [10.]]), torch.Tensor([0.5, 0.8])) - >>> m.sample() # doctest: +SKIP - 4 5 - 7 6 - [torch.FloatTensor of size (2,2)] - - :param (Tensor) total_count: number of Bernoulli trials - :param (Tensor) probs: Event probabilities - :param (Tensor) logits: Event log-odds - """ - arg_constraints = {'total_count': constraints.nonnegative_integer, - 'probs': constraints.unit_interval} - has_enumerate_support = True - - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): - if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") - if probs is not None: - self.total_count, self.probs, = broadcast_all(total_count, probs) - is_scalar = isinstance(self.probs, Number) - else: - self.total_count, self.logits, = broadcast_all(total_count, logits) - is_scalar = isinstance(self.logits, Number) - - self._param = self.probs if probs is not None else self.logits - if is_scalar: - batch_shape = torch.Size() - else: - batch_shape = self._param.shape - super(Binomial, self).__init__(batch_shape, validate_args=validate_args) - - def _new(self, *args, **kwargs): - return self._param.new(*args, **kwargs) - - @constraints.dependent_property - def support(self): - return constraints.integer_interval(0, self.total_count) - - @property - def mean(self): - return self.total_count * self.probs - - @property - def variance(self): - return self.total_count * self.probs * (1 - self.probs) - - @lazy_property - def logits(self): - return probs_to_logits(self.probs, is_binary=True) - - @lazy_property - def probs(self): - return logits_to_probs(self.logits, is_binary=True) - - @property - def param_shape(self): - return self._param.shape - - def sample(self, sample_shape=torch.Size()): - with torch.no_grad(): - max_count = max(int(self.total_count.max()), 1) - shape = self._extended_shape(sample_shape) + (max_count,) - bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)) - if self.total_count.min() != max_count: - arange = torch.arange(max_count, out=self.total_count.new_empty(max_count)) - mask = arange >= self.total_count.unsqueeze(-1) - bernoullis.masked_fill_(mask, 0.) 
- return bernoullis.sum(dim=-1) - - def log_prob(self, value): - if self._validate_args: - self._validate_sample(value) - log_factorial_n = torch.lgamma(self.total_count + 1) - log_factorial_k = torch.lgamma(value + 1) - log_factorial_nmk = torch.lgamma(self.total_count - value + 1) - max_val = (-self.logits).clamp(min=0.0) - # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp())) - return (log_factorial_n - log_factorial_k - log_factorial_nmk + - value * self.logits + self.total_count * max_val - - self.total_count * torch.log1p((self.logits + 2 * max_val).exp())) - - def enumerate_support(self, expand=True): - total_count = int(self.total_count.max()) - if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.") - values = self._new(1 + total_count,) - torch.arange(1 + total_count, out=values) - values = values.view((-1,) + (1,) * len(self._batch_shape)) - if expand: - values = values.expand((-1,) + self._batch_shape) - return values - - def expand(self, batch_shape): - try: - return super(Binomial, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - total_count = self.total_count.expand(batch_shape) - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(total_count, probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(total_count, logits=logits, validate_args=validate_args) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index a4d1c7bea7..774977cae5 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -42,12 +42,14 @@ def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): self.log_density = log_density super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args) - def expand(self, batch_shape): - validate_args = self.__dict__.get('_validate_args') + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Delta, _instance) batch_shape = torch.Size(batch_shape) - v = self.v.expand(batch_shape + self.event_shape) - log_density = self.log_density.expand(batch_shape) - return Delta(v, log_density, self.event_dim, validate_args=validate_args) + new.v = self.v.expand(batch_shape + self.event_shape) + new.log_density = self.log_density.expand(batch_shape) + super(Delta, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new def rsample(self, sample_shape=torch.Size()): shape = sample_shape + self.v.shape diff --git a/pyro/distributions/diag_normal_mixture.py b/pyro/distributions/diag_normal_mixture.py index 0ce30a1e8a..521905a679 100644 --- a/pyro/distributions/diag_normal_mixture.py +++ b/pyro/distributions/diag_normal_mixture.py @@ -67,7 +67,22 @@ def __init__(self, locs, coord_scale, component_logits): self.dim = locs.size(-1) self.categorical = Categorical(logits=component_logits) self.probs = self.categorical.probs - super(MixtureOfDiagNormals, self).__init__(batch_shape=batch_shape, event_shape=(self.dim,)) + super(MixtureOfDiagNormals, self).__init__(batch_shape=torch.Size(batch_shape), + event_shape=torch.Size((self.dim,))) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MixtureOfDiagNormals, _instance) + new.batch_mode = True + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + 
new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) + new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[-2:]) + new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.categorical = self.categorical.expand(batch_shape) + new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) + super(MixtureOfDiagNormals, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new def log_prob(self, value): epsilon = (value.unsqueeze(-2) - self.locs) / self.coord_scale # L B K D diff --git a/pyro/distributions/diag_normal_mixture_shared_cov.py b/pyro/distributions/diag_normal_mixture_shared_cov.py index 4361d5790c..fd1b44a8b9 100644 --- a/pyro/distributions/diag_normal_mixture_shared_cov.py +++ b/pyro/distributions/diag_normal_mixture_shared_cov.py @@ -68,6 +68,21 @@ def __init__(self, locs, coord_scale, component_logits): self.probs = self.categorical.probs super(MixtureOfDiagNormalsSharedCovariance, self).__init__(batch_shape=batch_shape, event_shape=(self.dim,)) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MixtureOfDiagNormalsSharedCovariance, _instance) + new.batch_mode = True + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) + coord_scale_shape = -1 if self.batch_mode else -2 + new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[coord_scale_shape:]) + new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.categorical = self.categorical.expand(batch_shape) + new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) + super(MixtureOfDiagNormalsSharedCovariance, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + def log_prob(self, value): # TODO: use torch.logsumexp once it's in PyTorch release coord_scale = self.coord_scale.unsqueeze(-2) if self.batch_mode else self.coord_scale diff --git a/pyro/distributions/half_cauchy.py b/pyro/distributions/half_cauchy.py deleted file mode 100644 index a4b8f8f0eb..0000000000 --- a/pyro/distributions/half_cauchy.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import math - -from torch.distributions import constraints -from torch.distributions.transforms import AbsTransform, AffineTransform -from torch.distributions.utils import broadcast_all - -from pyro.distributions.torch import Cauchy, TransformedDistribution - - -class HalfCauchy(TransformedDistribution): - r""" - Half-Cauchy distribution. - - This is a continuous distribution with lower-bounded domain (`x > loc`). - See also the :class:`~pyro.distributions.torch.Cauchy` distribution. - - :param torch.Tensor loc: lower bound of the distribution. - :param torch.Tensor scale: half width at half maximum. 
- """ - arg_constraints = Cauchy.arg_constraints - support = Cauchy.support - - def __init__(self, loc=0, scale=1): - loc, scale = broadcast_all(loc, scale) - base_dist = Cauchy(0, scale) - transforms = [AbsTransform(), AffineTransform(loc, 1)] - super(HalfCauchy, self).__init__(base_dist, transforms) - - @property - def loc(self): - return self.transforms[1].loc - - @property - def scale(self): - return self.base_dist.scale - - @constraints.dependent_property - def support(self): - return constraints.greater_than(self.loc) - - def log_prob(self, value): - log_prob = self.base_dist.log_prob(value - self.loc) + math.log(2) - log_prob[value < self.loc] = -float('inf') - return log_prob - - def entropy(self): - return self.base_dist.entropy() - math.log(2) - - def expand(self, batch_shape): - try: - return super(HalfCauchy, self).expand(batch_shape) - except NotImplementedError: - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale) diff --git a/pyro/distributions/lowrank_mvn.py b/pyro/distributions/lowrank_mvn.py deleted file mode 100644 index e88b2d84f6..0000000000 --- a/pyro/distributions/lowrank_mvn.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import math - -import torch -from torch.distributions import constraints -from torch.distributions.utils import lazy_property - -from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistribution - - -def _matrix_triangular_solve_compat(b, A, upper=True): - """ - Computes the solution to the linear equation AX = b, - where A is a triangular matrix. - - :param b: A 1D or 2D tensor of size N or N x C. - :param A: A 2D tensor of size N X N. - :param upper: A flag if A is a upper triangular matrix or not. - """ - return b.view(b.shape[0], -1).trtrs(A, upper=upper)[0].view(b.shape) - - -class LowRankMultivariateNormal(TorchDistribution): - """ - Low Rank Multivariate Normal distribution. - - Implements fast computation for log probability of Multivariate Normal distribution - when the covariance matrix has the form:: - - covariance_matrix = W @ W.T + D. - - Here D is a diagonal vector and ``W`` is a matrix of size ``N x M``. The - computation will be beneficial when ``M << N``. - - :param torch.Tensor loc: Mean. - Must be a 1D or 2D tensor with the last dimension of size N. - :param torch.Tensor W_term: W term of covariance matrix. - Must be in 2 dimensional of size N x M. - :param torch.Tensor D_term: D term of covariance matrix. - Must be in 1 dimensional of size N. - :param float trace_term: A optional term to be added into Mahalabonis term - according to p(y) = N(y|loc, cov).exp(-1/2 * trace_term). 
- """ - arg_constraints = {"loc": constraints.real, - "covariance_matrix_D_term": constraints.positive, - "scale_tril": constraints.lower_triangular} - support = IndependentConstraint(constraints.real, 1) - has_rsample = True - - def __init__(self, loc, W_term, D_term, trace_term=None): - W_term = W_term.t() - if loc.shape[-1] != D_term.shape[0]: - raise ValueError("Expected loc.shape == D_term.shape, but got {} vs {}".format( - loc.shape, D_term.shape)) - if D_term.shape[0] != W_term.shape[1]: - raise ValueError("The dimension of D_term must match the first dimension of W_term.") - if D_term.dim() != 1 or W_term.dim() != 2 or loc.dim() > 2: - raise ValueError("D_term, W_term must be 1D, 2D tensors respectively and " - "loc must be a 1D or 2D tensor.") - - self.loc = loc - self.covariance_matrix_D_term = D_term - self.covariance_matrix_W_term = W_term - self.trace_term = trace_term if trace_term is not None else 0 - - batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:] - super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape) - - @property - def mean(self): - return self.loc - - @property - def variance(self): - return self.covariance_matrix_D_term + (self.covariance_matrix_W_term ** 2).sum(0) - - @lazy_property - def scale_tril(self): - # We use the following formula to increase the numerically computation stability - # when using Cholesky decomposition (see GPML section 3.4.3): - # D + W.T @ W = D1/2 @ (I + D-1/2 @ W.T @ W @ D-1/2) @ D1/2 - Dsqrt = self.covariance_matrix_D_term.sqrt() - A = self.covariance_matrix_W_term / Dsqrt - At_A = A.t().matmul(A) - N = A.shape[1] - Id = torch.eye(N, N, out=A.new_empty(N, N)) - K = Id + At_A - L = K.potrf(upper=False) - return Dsqrt.unsqueeze(1) * L - - def rsample(self, sample_shape=torch.Size()): - white = self.loc.new_empty(sample_shape + self.loc.shape).normal_() - return self.loc + torch.matmul(white, self.scale_tril.t()) - - def log_prob(self, value): - delta = value - self.loc - logdet, mahalanobis_squared = self._compute_logdet_and_mahalanobis( - self.covariance_matrix_D_term, self.covariance_matrix_W_term, delta, self.trace_term) - normalization_const = 0.5 * (self.event_shape[-1] * math.log(2 * math.pi) + logdet) - return -(normalization_const + 0.5 * mahalanobis_squared) - - def _compute_logdet_and_mahalanobis(self, D, W, y, trace_term=0): - """ - Calculates log determinant and (squared) Mahalanobis term of covariance - matrix ``(D + Wt.W)``, where ``D`` is a diagonal matrix, based on the - "Woodbury matrix identity" and "matrix determinant lemma":: - - inv(D + Wt.W) = inv(D) - inv(D).Wt.inv(I + W.inv(D).Wt).W.inv(D) - log|D + Wt.W| = log|Id + Wt.inv(D).W| + log|D| - """ - W_Dinv = W / D - M = W.shape[0] - Id = torch.eye(M, M, out=W.new_empty(M, M)) - K = Id + W_Dinv.matmul(W.t()) - L = K.potrf(upper=False) - if y.dim() == 1: - W_Dinv_y = W_Dinv.matmul(y) - elif y.dim() == 2: - W_Dinv_y = W_Dinv.matmul(y.t()) - else: - raise NotImplementedError("SparseMultivariateNormal distribution does not support " - "computing log_prob for a tensor with more than 2 dimensionals.") - Linv_W_Dinv_y = _matrix_triangular_solve_compat(W_Dinv_y, L, upper=False) - if y.dim() == 2: - Linv_W_Dinv_y = Linv_W_Dinv_y.t() - - logdet = 2 * L.diag().log().sum() + D.log().sum() - - mahalanobis1 = (y * y / D).sum(-1) - mahalanobis2 = (Linv_W_Dinv_y * Linv_W_Dinv_y).sum(-1) - mahalanobis_squared = mahalanobis1 - mahalanobis2 + trace_term - - return logdet, mahalanobis_squared diff --git a/pyro/distributions/omt_mvn.py 
b/pyro/distributions/omt_mvn.py index 8265e52dd5..c3a393bc85 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -6,7 +6,7 @@ from torch.distributions import constraints from pyro.distributions.torch import MultivariateNormal -from pyro.distributions.util import sum_leftmost +from pyro.distributions.util import eye_like, sum_leftmost class OMTMultivariateNormal(MultivariateNormal): @@ -51,7 +51,7 @@ def backward(ctx, grad_output): g = grad_output loc_grad = sum_leftmost(grad_output, -1) - identity = torch.eye(dim, out=g.new_empty(dim, dim)) + identity = eye_like(g, dim) R_inv = torch.trtrs(identity, L.t(), transpose=False, upper=True)[0] z_ja = z.unsqueeze(-1) diff --git a/pyro/distributions/relaxed_straight_through.py b/pyro/distributions/relaxed_straight_through.py index f0f7739b06..5d57cdf1f5 100644 --- a/pyro/distributions/relaxed_straight_through.py +++ b/pyro/distributions/relaxed_straight_through.py @@ -29,10 +29,6 @@ class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical): [2] Categorical Reparameterization with Gumbel-Softmax, Eric Jang, Shixiang Gu, Ben Poole """ - def __init__(self, temperature, probs=None, logits=None, validate_args=None): - super(RelaxedOneHotCategoricalStraightThrough, self).__init__(temperature=temperature, probs=probs, - logits=logits, validate_args=validate_args) - def rsample(self, sample_shape=torch.Size()): soft_sample = super(RelaxedOneHotCategoricalStraightThrough, self).rsample(sample_shape) soft_sample = clamp_probs(soft_sample) @@ -81,10 +77,6 @@ class RelaxedBernoulliStraightThrough(RelaxedBernoulli): [2] Categorical Reparameterization with Gumbel-Softmax, Eric Jang, Shixiang Gu, Ben Poole """ - def __init__(self, temperature, probs=None, logits=None, validate_args=None): - super(RelaxedBernoulliStraightThrough, self).__init__(temperature=temperature, probs=probs, - logits=logits, validate_args=validate_args) - def rsample(self, sample_shape=torch.Size()): soft_sample = super(RelaxedBernoulliStraightThrough, self).rsample(sample_shape) soft_sample = clamp_probs(soft_sample) diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index f815e0e323..137c39334a 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -18,7 +18,7 @@ def __init__(self, concentration): if concentration.data.min() < 1: raise NotImplementedError('concentration < 1 is not supported') self.concentration = concentration - self._standard_gamma = Gamma(concentration, concentration.new_tensor([1.]).squeeze().expand_as(concentration)) + self._standard_gamma = Gamma(concentration, concentration.new([1.]).squeeze().expand_as(concentration)) # The following are Marsaglia & Tsang's variable names. self._d = self.concentration - 1.0 / 3.0 self._c = 1.0 / torch.sqrt(9.0 * self._d) @@ -27,6 +27,20 @@ def __init__(self, concentration): log_scale = self.propose_log_prob(x) + self.log_prob_accept(x) - self.log_prob(x) super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept, log_scale) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RejectionStandardGamma, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new._standard_gamma = self._standard_gamma.expand(batch_shape) + new._d = self._d.expand(batch_shape) + new._c = self._c.expand(batch_shape) + # Compute log scale using Gamma.log_prob(). 
+ x = new._d.detach() # just an arbitrary x. + log_scale = new.propose_log_prob(x) + new.log_prob_accept(x) - new.log_prob(x) + super(RejectionStandardGamma, new).__init__(new.propose, new.log_prob_accept, log_scale) + new._validate_args = self._validate_args + return new + def propose(self, sample_shape=torch.Size()): # Marsaglia & Tsang's x == Naesseth's epsilon x = self.concentration.new_empty(sample_shape + self.concentration.shape).normal_() @@ -65,6 +79,13 @@ def __init__(self, concentration, rate, validate_args=None): self._standard_gamma = RejectionStandardGamma(concentration) self.rate = rate + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RejectionGamma, _instance) + new = super(RejectionGamma, self).expand(batch_shape, new) + new._standard_gamma = self._standard_gamma.expand(batch_shape) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): return self._standard_gamma.rsample(sample_shape) / self.rate @@ -94,6 +115,16 @@ def __init__(self, concentration, rate, boost=1, validate_args=None): self._rejection_gamma = RejectionGamma(concentration + boost, rate) self._unboost_x_cache = None, None + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedGamma, _instance) + new = super(ShapeAugmentedGamma, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new._boost = self._boost + new._rejection_gamma = self._rejection_gamma.expand(batch_shape) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): x = self._rejection_gamma.rsample(sample_shape) boosted_x = x.clone() @@ -124,6 +155,14 @@ def __init__(self, concentration, boost=1, validate_args=None): super(ShapeAugmentedDirichlet, self).__init__(concentration, validate_args=validate_args) self._gamma = ShapeAugmentedGamma(concentration, torch.ones_like(concentration), boost) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedDirichlet, _instance) + new = super(ShapeAugmentedDirichlet, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): gammas = self._gamma.rsample(sample_shape) return gammas / gammas.sum(-1, True) @@ -142,6 +181,14 @@ def __init__(self, concentration1, concentration0, boost=1, validate_args=None): alpha_beta = torch.stack([concentration1, concentration0], -1) self._gamma = ShapeAugmentedGamma(alpha_beta, torch.ones_like(alpha_beta), boost) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedBeta, _instance) + new = super(ShapeAugmentedBeta, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): gammas = self._gamma.rsample(sample_shape) probs = gammas / gammas.sum(-1, True) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index ea82e5504b..c2c8b593f7 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -6,139 +6,8 @@ from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistributionMixin -class 
Bernoulli(torch.distributions.Bernoulli, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Bernoulli, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - values = self._param.new_tensor([0., 1.]) - values = values.reshape((2,) + (1,) * len(self.batch_shape)) - if expand: - values = values.expand((2,) + self.batch_shape) - return values - - -class Beta(torch.distributions.Beta, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Beta, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - concentration1 = self.concentration1.expand(batch_shape) - concentration0 = self.concentration0.expand(batch_shape) - return type(self)(concentration1, concentration0, validate_args=validate_args) - - -class Categorical(torch.distributions.Categorical, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Categorical, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape + self.logits.shape[-1:]) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - num_events = self._num_events - values = torch.arange(num_events, dtype=torch.long) - values = values.view((-1,) + (1,) * len(self._batch_shape)) - if expand: - values = values.expand((-1,) + self._batch_shape) - if self._param.is_cuda: - values = values.cuda(self._param.get_device()) - return values - - -class Cauchy(torch.distributions.Cauchy, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Cauchy, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class Chi2(torch.distributions.Chi2, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Chi2, self).expand_by(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - df = self.df.expand(batch_shape) - return type(self)(df, validate_args=validate_args) - - -class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Dirichlet, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - concentration = self.concentration.expand(batch_shape + self.event_shape) - return type(self)(concentration, validate_args=validate_args) - - -class Exponential(torch.distributions.Exponential, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Exponential, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - rate = self.rate.expand(batch_shape) - 
return type(self)(rate, validate_args=validate_args) - - -class Gamma(torch.distributions.Gamma, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Gamma, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - concentration = self.concentration.expand(batch_shape) - rate = self.rate.expand(batch_shape) - return type(self)(concentration, rate, validate_args=validate_args) - - -class Geometric(torch.distributions.Geometric, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Geometric, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(logits=logits, validate_args=validate_args) - - -class Gumbel(torch.distributions.Gumbel, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Gumbel, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) +class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): + support = IndependentConstraint(constraints.real, 1) # TODO move upstream class Independent(torch.distributions.Independent, TorchDistributionMixin): @@ -154,155 +23,10 @@ def _validate_args(self): def _validate_args(self, value): self.base_dist._validate_args = value - def expand(self, batch_shape): - batch_shape = torch.Size(batch_shape) - base_shape = self.base_dist.batch_shape - reinterpreted_shape = base_shape[len(base_shape) - self.reinterpreted_batch_ndims:] - base_dist = self.base_dist.expand(batch_shape + reinterpreted_shape) - return type(self)(base_dist, self.reinterpreted_batch_ndims) - - def enumerate_support(self, expand=expand): - if self.reinterpreted_batch_ndims: - raise NotImplementedError("Pyro does not enumerate over cartesian products") - return self.base_dist.enumerate_support(expand=expand) - - -class Laplace(torch.distributions.Laplace, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Laplace, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(LogNormal, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Multinomial, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.event_shape) - return type(self)(self.total_count, probs=probs, validate_args=validate_args) - else: - logits = 
self.logits.expand(batch_shape + self.event_shape) - return type(self)(self.total_count, logits=logits, validate_args=validate_args) - - -class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): - support = IndependentConstraint(constraints.real, 1) # TODO move upstream - - def expand(self, batch_shape): - try: - return super(MultivariateNormal, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape + self.event_shape) - if 'scale_tril' in self.__dict__: - scale_tril = self.scale_tril.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, scale_tril=scale_tril, validate_args=validate_args) - elif 'covariance_matrix' in self.__dict__: - covariance_matrix = self.covariance_matrix.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, covariance_matrix=covariance_matrix, validate_args=validate_args) - else: - precision_matrix = self.precision_matrix.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, precision_matrix=precision_matrix, validate_args=validate_args) - - -class Normal(torch.distributions.Normal, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Normal, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(OneHotCategorical, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.event_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape + self.event_shape) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - n = self.event_shape[0] - values = self._new((n, n)) - torch.eye(n, out=values) - values = values.view((n,) + (1,) * len(self.batch_shape) + (n,)) - if expand: - values = values.expand((n,) + self.batch_shape + (n,)) - return values - - -class Poisson(torch.distributions.Poisson, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Poisson, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - rate = self.rate.expand(batch_shape) - return type(self)(rate, validate_args=validate_args) - - -class StudentT(torch.distributions.StudentT, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(StudentT, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - df = self.df.expand(batch_shape) - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(df, loc, scale, validate_args=validate_args) - - -class TransformedDistribution(torch.distributions.TransformedDistribution, TorchDistributionMixin): - def expand(self, batch_shape): - return super(TransformedDistribution, self).expand(batch_shape) - - -class Uniform(torch.distributions.Uniform, TorchDistributionMixin): - def 
expand(self, batch_shape): - try: - return super(Uniform, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - low = self.low.expand(batch_shape) - high = self.high.expand(batch_shape) - return type(self)(low, high, validate_args=validate_args) - # Programmatically load all distributions from PyTorch. __all__ = [] for _name, _Dist in torch.distributions.__dict__.items(): - if _name == 'Binomial': - continue if not isinstance(_Dist, type): continue if not issubclass(_Dist, torch.distributions.Distribution): diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 7276478b14..5509e5fe2d 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -1,14 +1,11 @@ from __future__ import absolute_import, division, print_function -import numbers - import torch from torch.distributions import biject_to, constraints, transform_to import pyro.distributions.torch from pyro.distributions.distribution import Distribution -from pyro.distributions.score_parts import ScoreParts -from pyro.distributions.util import broadcast_shape, scale_and_mask, sum_rightmost +from pyro.distributions.util import broadcast_shape, scale_and_mask class TorchDistributionMixin(Distribution): @@ -65,35 +62,6 @@ def shape(self, sample_shape=torch.Size()): """ return sample_shape + self.batch_shape + self.event_shape - def expand(self, batch_shape): - """ - Expands a distribution to a desired - :attr:`~torch.distributions.distribution.Distribution.batch_shape`. - - Note that this is more general than :meth:`expand_by` because - ``d.expand_by(sample_shape)`` can be reduced to - ``d.expand(sample_shape + d.batch_shape)``. - - :param torch.Size batch_shape: The target ``batch_shape``. This must - compatible with ``self.batch_shape`` similar to the requirements - of :func:`torch.Tensor.expand`: the target ``batch_shape`` must - be at least as long as ``self.batch_shape``, and for each - non-singleton dim of ``self.batch_shape``, ``batch_shape`` must - either agree or be set to ``-1``. - :return: An expanded version of this distribution. - :rtype: :class:`ReshapedDistribution` - """ - batch_shape = torch.Size(batch_shape) - cut = len(batch_shape) - len(self.batch_shape) - left, right = batch_shape[:cut], batch_shape[cut:] - if right == self.batch_shape: - return self.expand_by(left) - else: - raise NotImplementedError("`TorchDistributionMixin.expand()` cannot expand " - "distribution's existing batch shape. Consider " - "overriding the default implementation for the " - "distribution class.") - def expand_by(self, sample_shape): """ Expands a distribution by adding ``sample_shape`` to the left side of @@ -107,9 +75,7 @@ def expand_by(self, sample_shape): :return: An expanded version of this distribution. :rtype: :class:`ReshapedDistribution` """ - if not sample_shape: - return self - return ReshapedDistribution(self, sample_shape=sample_shape) + return self.expand(torch.Size(sample_shape) + self.batch_shape) def reshape(self, sample_shape=None, extra_event_dims=None): raise Exception(''' @@ -254,133 +220,6 @@ def check(self, value): transform_to.register(IndependentConstraint, lambda c: transform_to(c.base_constraint)) -class ReshapedDistribution(TorchDistribution): - """ - Reshapes a distribution by adding ``sample_shape`` to its total shape - and adding ``reinterpreted_batch_ndims`` to its - :attr:`~torch.distributions.distribution.Distribution.event_shape`. 
- - :param torch.Size sample_shape: The size of the iid batch to be drawn from - the distribution. - :param int reinterpreted_batch_ndims: The number of extra event dimensions that will - be considered dependent. - """ - arg_constraints = {} - - def __init__(self, base_dist, sample_shape=torch.Size(), reinterpreted_batch_ndims=0): - sample_shape = torch.Size(sample_shape) - if reinterpreted_batch_ndims > len(sample_shape + base_dist.batch_shape): - raise ValueError('Expected reinterpreted_batch_ndims <= len(sample_shape + base_dist.batch_shape), ' - 'actual {} vs {}'.format(reinterpreted_batch_ndims, - len(sample_shape + base_dist.batch_shape))) - self.base_dist = base_dist - self.sample_shape = sample_shape - self.reinterpreted_batch_ndims = reinterpreted_batch_ndims - shape = sample_shape + base_dist.batch_shape + base_dist.event_shape - batch_dim = len(shape) - reinterpreted_batch_ndims - len(base_dist.event_shape) - batch_shape, event_shape = shape[:batch_dim], shape[batch_dim:] - super(ReshapedDistribution, self).__init__(batch_shape, event_shape) - - def expand(self, batch_shape): - batch_shape = torch.Size(batch_shape) - # Raise error if existing batch shape is being shrunk. - # e.g. (2, 4) -> (2, 1) - proposed_shape = broadcast_shape(self.batch_shape, batch_shape) - if tuple(reversed(proposed_shape)) > tuple(reversed(batch_shape)): - raise ValueError("Existing batch shape {} cannot be expanded " - "to the new batch shape {}." - .format(self.batch_shape, batch_shape)) - # Adjust existing sample shape if possible. - base_dist = self.base_dist - base_batch_shape = batch_shape + self.event_shape[:self.reinterpreted_batch_ndims] - cut = len(base_batch_shape) - len(base_dist.batch_shape) - left, right = base_batch_shape[:cut], base_batch_shape[cut:] - if right == base_dist.batch_shape: - sample_shape = left - # Modify the base distribution's batch shape, - # if existing sample shape cannot be adjusted. 
- else: - base_dist = self.base_dist.expand(base_batch_shape) - assert not isinstance(base_dist, ReshapedDistribution) - sample_shape = torch.Size(()) - return ReshapedDistribution(base_dist, sample_shape, self.reinterpreted_batch_ndims) - - def expand_by(self, sample_shape): - base_dist = self.base_dist - sample_shape = torch.Size(sample_shape) + self.sample_shape - reinterpreted_batch_ndims = self.reinterpreted_batch_ndims - return ReshapedDistribution(base_dist, sample_shape, reinterpreted_batch_ndims) - - def independent(self, reinterpreted_batch_ndims=None): - if reinterpreted_batch_ndims is None: - reinterpreted_batch_ndims = len(self.batch_shape) - base_dist = self.base_dist - sample_shape = self.sample_shape - reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + reinterpreted_batch_ndims - return ReshapedDistribution(base_dist, sample_shape, reinterpreted_batch_ndims) - - @property - def has_rsample(self): - return self.base_dist.has_rsample - - @property - def has_enumerate_support(self): - return self.base_dist.has_enumerate_support - - @constraints.dependent_property - def support(self): - return IndependentConstraint(self.base_dist.support, self.reinterpreted_batch_ndims) - - @property - def _validate_args(self): - return self.base_dist._validate_args - - @_validate_args.setter - def _validate_args(self, value): - self.base_dist._validate_args = value - - def sample(self, sample_shape=torch.Size()): - return self.base_dist.sample(sample_shape + self.sample_shape) - - def rsample(self, sample_shape=torch.Size()): - return self.base_dist.rsample(sample_shape + self.sample_shape) - - def log_prob(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) - return sum_rightmost(self.base_dist.log_prob(value), self.reinterpreted_batch_ndims).expand(shape) - - def score_parts(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) - log_prob, score_function, entropy_term = self.base_dist.score_parts(value) - log_prob = sum_rightmost(log_prob, self.reinterpreted_batch_ndims).expand(shape) - if not isinstance(score_function, numbers.Number): - score_function = sum_rightmost(score_function, self.reinterpreted_batch_ndims).expand(shape) - if not isinstance(entropy_term, numbers.Number): - entropy_term = sum_rightmost(entropy_term, self.reinterpreted_batch_ndims).expand(shape) - return ScoreParts(log_prob, score_function, entropy_term) - - def enumerate_support(self, expand=True): - if self.reinterpreted_batch_ndims: - raise NotImplementedError("Pyro does not enumerate over cartesian products") - - samples = self.base_dist.enumerate_support(expand=False) - samples = samples.reshape(samples.shape[:1] + (1,) * len(self.batch_shape) + self.event_shape) - if expand: - samples = samples.expand(samples.shape[:1] + self.batch_shape + self.event_shape) - return samples - - @property - def mean(self): - return self.base_dist.mean.expand(self.batch_shape + self.event_shape) - - @property - def variance(self): - return self.base_dist.variance.expand(self.batch_shape + self.event_shape) - - def entropy(self): - return sum_rightmost(self.base_dist.entropy(), self.reinterpreted_batch_ndims) - - class MaskedDistribution(TorchDistribution): """ Masks a distribution by a zero-one tensor that is broadcastable to the @@ -398,6 +237,15 @@ def __init__(self, base_dist, mask): self._mask = mask.byte() super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape) + def expand(self, 
batch_shape, _instance=None): + new = self._get_checked_instance(MaskedDistribution, _instance) + batch_shape = torch.Size(batch_shape) + new.base_dist = self.base_dist.expand(batch_shape) + new._mask = self._mask.expand(batch_shape) + super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + @property def has_rsample(self): return self.base_dist.has_rsample diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index b7b3f6938c..a37efe6558 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -22,21 +22,6 @@ def decorator(new_fn): return decorator -@_patch('torch._standard_gamma') -def _torch_standard_gamma(concentration): - unpatched_fn = _torch_standard_gamma._pyro_unpatched - if concentration.is_cuda: - return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) - return unpatched_fn(concentration) - - -@_patch('torch.distributions.gamma._standard_gamma') -def _standard_gamma(concentration): - if concentration.is_cuda: - return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) - return concentration._standard_gamma() - - @_patch('torch._dirichlet_grad') def _torch_dirichlet_grad(x, concentration, total): unpatched_fn = _torch_dirichlet_grad._pyro_unpatched @@ -72,10 +57,6 @@ def _einsum(equation, operands): y, x = operands return (x.unsqueeze(1) * y).sum(0).transpose(0, 1) - # this workaround can be deleted after this issue is fixed in release: - # https://github.com/pytorch/pytorch/issues/7763 - operands = [t.clone() for t in operands] - return _einsum._pyro_unpatched(equation, operands) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index 3c26d5a135..adc9509aeb 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -3,11 +3,16 @@ import numbers from contextlib import contextmanager +import torch import torch.distributions as torch_dist +from torch import logsumexp from torch.distributions.utils import broadcast_all + _VALIDATION_ENABLED = False +log_sum_exp = logsumexp # DEPRECATED + def copy_docs_from(source_class, full_text=False): """ @@ -52,7 +57,11 @@ def is_identically_zero(x): Check if argument is exactly the number zero. True for the number zero; false for other numbers; false for :class:`~torch.Tensor`s. """ - return isinstance(x, numbers.Number) and x == 0 + if isinstance(x, numbers.Number): + return x == 0 + elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape: + return x.item() == 0 + return False def is_identically_one(x): @@ -60,7 +69,11 @@ def is_identically_one(x): Check if argument is exactly the number one. True for the number one; false for other numbers; false for :class:`~torch.Tensor`s. 
""" - return isinstance(x, numbers.Number) and x == 1 + if isinstance(x, numbers.Number): + return x == 1 + elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape: + return x.item() == 1 + return False def broadcast_shape(*shapes, **kwargs): @@ -157,15 +170,18 @@ def scale_and_mask(tensor, scale=1.0, mask=None): :param mask: an optional masking tensor :type mask: torch.ByteTensor or None """ - if is_identically_zero(tensor): - return tensor - if mask is None: - if is_identically_one(scale): + if not torch._C._get_tracing_state(): + if is_identically_zero(tensor) or (mask is None and is_identically_one(scale)): return tensor + if mask is None: return tensor * scale tensor, mask = broadcast_all(tensor, mask) - tensor = tensor * scale # triggers a copy, avoiding in-place op errors - tensor.masked_fill_(~mask, 0.) + # TODO: Remove .contiguous once https://github.com/pytorch/pytorch/issues/12230 is fixed. + tensor = (tensor * scale).contiguous() + if torch._C._get_tracing_state(): + tensor[~mask] = 0. + else: + tensor.masked_fill_(~mask, 0.) return tensor @@ -178,27 +194,6 @@ def eye_like(value, m, n=None): return eye -try: - from torch import logsumexp # for pytorch 0.4.1 and later -except ImportError: - def logsumexp(tensor, dim=-1, keepdim=False): - """ - Numerically stable implementation for the `LogSumExp` operation. The - summing is done along the dimension specified by ``dim``. - - :param torch.Tensor tensor: Input tensor. - :param dim: Dimension to be summed out. - :param keepdim: Whether to retain the dimension - that is summed out. - """ - max_val = tensor.max(dim, keepdim=True)[0] - log_sum_exp = max_val + (tensor - max_val).exp().sum(dim=dim, keepdim=True).log() - return log_sum_exp if keepdim else log_sum_exp.squeeze(dim) - - -log_sum_exp = logsumexp # DEPRECATED - - def enable_validation(is_validate): global _VALIDATION_ENABLED _VALIDATION_ENABLED = is_validate diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py index 5717a66137..cdb15cb8fc 100644 --- a/pyro/infer/elbo.py +++ b/pyro/infer/elbo.py @@ -36,6 +36,8 @@ class ELBO(object): misuse of enumeration, i.e. that :class:`pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there are enumerated sample sites. + :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT + tracer, when . All :class:`torch.jit.TracerWarning` will be ignored. 
References @@ -50,7 +52,8 @@ def __init__(self, num_particles=1, max_iarange_nesting=float('inf'), vectorize_particles=False, - strict_enumeration_warning=True): + strict_enumeration_warning=True, + ignore_jit_warnings=False): self.num_particles = num_particles self.max_iarange_nesting = max_iarange_nesting self.vectorize_particles = vectorize_particles @@ -61,6 +64,7 @@ def __init__(self, "a finite value for `max_iarange_nesting` arg.") self.max_iarange_nesting += 1 self.strict_enumeration_warning = strict_enumeration_warning + self.ignore_jit_warnings = ignore_jit_warnings def _vectorized_num_particles(self, fn): """ diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 5185eadd54..c220450529 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -8,7 +8,7 @@ from pyro.infer.util import is_validation_enabled from pyro.poutine import Trace from pyro.poutine.util import prune_subsample_sites -from pyro.util import check_model_guide_match, check_site_shape +from pyro.util import check_model_guide_match, check_site_shape, ignore_jit_warnings def iter_discrete_escape(trace, msg): @@ -20,10 +20,14 @@ def iter_discrete_escape(trace, msg): def iter_discrete_extend(trace, site, **ignored): values = site["fn"].enumerate_support(expand=site["infer"].get("expand", False)) + enum_total = values.shape[0] + with ignore_jit_warnings(["Converting a tensor to a Python index", + ("Iterating over a tensor", RuntimeWarning)]): + values = iter(values) for i, value in enumerate(values): extended_site = site.copy() extended_site["infer"] = site["infer"].copy() - extended_site["infer"]["_enum_total"] = len(values) + extended_site["infer"]["_enum_total"] = enum_total extended_site["value"] = value extended_trace = trace.copy() extended_trace.add_node(site["name"], **extended_site) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 83d70e25f8..6e549137ec 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -15,7 +15,7 @@ from pyro.ops.dual_averaging import DualAveraging from pyro.ops.integrator import single_step_velocity_verlet, velocity_verlet from pyro.primitives import _Subsample -from pyro.util import torch_isinf, torch_isnan, optional +from pyro.util import torch_isinf, torch_isnan, optional, ignore_jit_warnings class HMC(TraceKernel): @@ -50,6 +50,11 @@ class HMC(TraceKernel): :param int max_iarange_nesting: Optional bound on max number of nested :func:`pyro.iarange` contexts. This is required if model contains discrete sample sites that can be enumerated over in parallel. + :param bool jit_compile: Optional parameter denoting whether to use + the PyTorch JIT to trace the log density computation, and use this + optimized executable trace in the integrator. + :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT + tracer when ``jit_compile=True``. Default is False. :param bool experimental_use_einsum: Whether to use an einsum operation to evaluate log pdf for the model trace. No-op unless the trace has discrete sample sites. This flag is experimental and will most likely @@ -83,6 +88,8 @@ def __init__(self, adapt_step_size=False, transforms=None, max_iarange_nesting=float("inf"), + jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): # Wrap model in `poutine.enum` to enumerate over discrete latent sites. # No-op if model does not have any discrete latents. 
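As a usage note for the two kernel arguments documented in the hunk above, here is a minimal sketch of driving HMC with the new flags. The model, data, and sample counts are illustrative placeholders rather than part of the patch; only the ``jit_compile``/``ignore_jit_warnings`` keywords and the ``pyro.infer.mcmc`` module paths come from this diff.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.mcmc import MCMC

def toy_model(data):
    # hypothetical model, used only to exercise the kernel
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(loc, 1.).expand_by(data.shape).independent(1), obs=data)

data = torch.randn(100)
# jit_compile=True makes the kernel trace the potential-energy computation once with
# torch.jit.trace and replay the compiled trace inside the integrator;
# ignore_jit_warnings=True silences the TracerWarnings emitted while that trace is built.
kernel = HMC(toy_model, step_size=0.02, num_steps=4, jit_compile=True, ignore_jit_warnings=True)
MCMC(kernel, num_samples=100, warmup_steps=50).run(data)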
@@ -99,6 +106,8 @@ def __init__(self, self.trajectory_length = 2 * math.pi # from Stan self.num_steps = max(1, int(self.trajectory_length / self.step_size)) self.adapt_step_size = adapt_step_size + self._jit_compile = jit_compile + self._ignore_jit_warnings = ignore_jit_warnings self.use_einsum = experimental_use_einsum self._target_accept_prob = 0.8 # from Stan @@ -129,6 +138,8 @@ def _kinetic_energy(self, r): return 0.5 * sum(x.pow(2).sum() for x in r.values()) def _potential_energy(self, z): + if self._jit_compile: + return self._potential_energy_jit(z) # Since the model is specified in the constrained space, transform the # unconstrained R.V.s `z` to the constrained space. z_constrained = z.copy() @@ -141,6 +152,32 @@ def _potential_energy(self, z): potential_energy += transform.log_abs_det_jacobian(z_constrained[name], z[name]).sum() return potential_energy + def _potential_energy_jit(self, z): + names, vals = zip(*sorted(z.items())) + if self._compiled_potential_fn: + return self._compiled_potential_fn(*vals) + + def compiled(*zi): + z_constrained = list(zi) + # transform to constrained space. + for i, name in enumerate(names): + if name in self.transforms: + transform = self.transforms[name] + z_constrained[i] = transform.inv(z_constrained[i]) + z_constrained = dict(zip(names, z_constrained)) + trace = self._get_trace(z_constrained) + potential_energy = -self._compute_trace_log_prob(trace) + # adjust by the jacobian for this transformation. + for i, name in enumerate(names): + if name in self.transforms: + transform = self.transforms[name] + potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum() + return potential_energy + + with pyro.validation_enabled(False), optional(ignore_jit_warnings(), self._ignore_jit_warnings): + self._compiled_potential_fn = torch.jit.trace(compiled, vals, check_trace=False) + return self._compiled_potential_fn(*vals) + def _energy(self, z, r): return self._kinetic_energy(r) + self._potential_energy(z) @@ -149,6 +186,7 @@ def _reset(self): self._accept_cnt = 0 self._r_dist = OrderedDict() self._args = None + self._compiled_potential_fn = None self._kwargs = None self._prototype_trace = None self._adapt_phase = False diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index c87eba2dd7..75a1ded9ee 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -56,6 +56,9 @@ class NUTS(HMC): :param int max_iarange_nesting: Optional bound on max number of nested :func:`pyro.iarange` contexts. This is required if model contains discrete sample sites that can be enumerated over in parallel. + :param bool jit_compile: Optional parameter denoting whether to use + the PyTorch JIT to trace the log density computation, and use this + optimized executable trace in the integrator. :param bool experimental_use_einsum: Whether to use an einsum operation to evaluat log pdf for the model trace. No-op unless the trace has discrete sample sites. 
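The ``_potential_energy_jit`` method added above follows a small, reusable pattern: fix an ordering of the latent names, trace a function of the flattened tensor values once, and replay the compiled trace on later calls. Below is a self-contained sketch of that pattern with a stand-in potential function; it is not Pyro's real log-density computation, only an illustration of the trace-and-cache mechanics.

import torch

def make_traced_potential(example_z):
    # fix an ordering of the latent sites so the positional signature is stable
    names = sorted(example_z)
    cache = {}

    def potential(z):
        vals = tuple(z[name] for name in names)
        if "fn" not in cache:
            def flat_potential(*flat_vals):
                # stand-in for the model's negative log joint density
                return sum(v.pow(2).sum() for v in flat_vals)
            # check_trace=False mirrors the patch: the trace is meant to be replayed
            # on new tensor values, not re-checked against the example inputs
            cache["fn"] = torch.jit.trace(flat_potential, vals, check_trace=False)
        return cache["fn"](*vals)

    return potential

z = {"loc": torch.randn(3), "log_scale": torch.randn(3)}
potential_fn = make_traced_potential(z)
potential_fn(z)                                                     # first call builds the trace
potential_fn({"loc": torch.zeros(3), "log_scale": torch.ones(3)})   # later calls reuse it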
This flag is experimental and will most likely @@ -87,12 +90,16 @@ def __init__(self, adapt_step_size=False, transforms=None, max_iarange_nesting=float("inf"), + jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): super(NUTS, self).__init__(model, step_size, adapt_step_size=adapt_step_size, transforms=transforms, max_iarange_nesting=max_iarange_nesting, + jit_compile=jit_compile, + ignore_jit_warnings=ignore_jit_warnings, experimental_use_einsum=experimental_use_einsum) self._max_tree_depth = 10 # from Stan @@ -139,7 +146,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): else: diverging = (sliced_energy >= self._max_sliced_energy) delta_energy = energy_new - energy_current - accept_prob = (-delta_energy).exp().clamp(max=1) + accept_prob = (-delta_energy).exp().clamp(max=1.0) return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, tree_size, False, diverging, accept_prob, 1) diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py index f26537acbd..a0c05bd9c0 100644 --- a/pyro/infer/trace_elbo.py +++ b/pyro/infer/trace_elbo.py @@ -155,12 +155,13 @@ class JitTrace_ELBO(Trace_ELBO): .. warning:: Experimental. Interface subject to change. """ + def loss_and_grads(self, model, guide, *args, **kwargs): if getattr(self, '_loss_and_surrogate_loss', None) is None: # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 @@ -200,7 +201,7 @@ def loss_and_surrogate_loss(*args): # invoke _loss_and_surrogate_loss loss, surrogate_loss = self._loss_and_surrogate_loss(*args) - surrogate_loss.backward() # this line triggers jit compilation + surrogate_loss.backward() loss = loss.item() warn_if_nan(loss, "loss") diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 305b7d1d18..2066b4619e 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -13,7 +13,6 @@ import pyro.distributions as dist import pyro.ops.jit import pyro.poutine as poutine -from pyro.distributions.torch_distribution import ReshapedDistribution from pyro.distributions.util import is_identically_zero, scale_and_mask from pyro.ops.contract import contract_tensor_tree, contract_to_tensor from pyro.infer.elbo import ELBO @@ -130,8 +129,6 @@ def _make_dist(dist_, logits): # Reshape for Bernoulli vs Categorical, OneHotCategorical, etc.. if isinstance(dist_, dist.Bernoulli): logits = logits[..., 1] - logits[..., 0] - elif isinstance(dist_, ReshapedDistribution): - return _make_dist(dist_.base_dist, logits=logits) return type(dist_)(logits=logits) @@ -413,12 +410,13 @@ class JitTraceEnum_ELBO(TraceEnum_ELBO): .. warning:: Experimental. Interface subject to change. 
""" + def loss_and_grads(self, model, guide, *args, **kwargs): if getattr(self, '_differentiable_loss', None) is None: weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def differentiable_loss(*args): self = weakself() elbo = 0.0 diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index 4e6d28e695..a1f1f5317a 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -255,7 +255,7 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace): class JitTraceGraph_ELBO(TraceGraph_ELBO): """ - Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.compile` to + Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.trace` to compile :meth:`loss_and_grads`. This works only for a limited set of models: @@ -275,7 +275,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 @@ -303,7 +303,7 @@ def loss_and_surrogate_loss(*args): self._loss_and_surrogate_loss = loss_and_surrogate_loss loss, surrogate_loss = self._loss_and_surrogate_loss(*args) - surrogate_loss.backward() # this line triggers jit compilation + surrogate_loss.backward() loss = loss.item() warn_if_nan(loss, "loss") diff --git a/pyro/infer/util.py b/pyro/infer/util.py index da47278f1f..9183d4a39f 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -259,10 +259,11 @@ def compute_expectation(self, costs): for cost in cost_terms: prob = sumproduct(factors, cost.shape, device=cost.device) mask = prob > 0 - if torch.is_tensor(mask) and not mask.all(): - cost, prob, mask = broadcast_all(cost, prob, mask) - prob = prob[mask] - cost = cost[mask] + if torch.is_tensor(mask): + if torch._C._get_tracing_state() or not mask.all(): + cost, prob, mask = broadcast_all(cost, prob, mask) + prob = prob[mask] + cost = cost[mask] expected_cost = expected_cost + (prob * cost).sum() LAST_CACHE_SIZE[0] = count_cached_ops(cache) return expected_cost diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index 062a49ee8a..9db5926adf 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -49,12 +49,9 @@ def einsum(equation, *operands): # This function is copied and adapted from: # https://github.com/dgasmith/opt_einsum/blob/a6dd686/opt_einsum/backends/torch.py def tensordot(x, y, axes=2): - xnd = x.ndimension() - ynd = y.ndimension() - # convert int argument to (list[int], list[int]) if isinstance(axes, int): - axes = range(xnd - axes, xnd), range(axes) + axes = list(range(x.dim() - axes, x.dim())), list(range(axes)) # convert (int, int) to (list[int], list[int]) if isinstance(axes[0], int): @@ -62,30 +59,22 @@ def tensordot(x, y, axes=2): if isinstance(axes[1], int): axes = axes[0], (axes[1],) - # initialize empty indices - x_ix = [None] * xnd - y_ix = [None] * ynd - out_ix = [] - - # fill in repeated indices - available_ix = iter(EINSUM_SYMBOLS_BASE) - for ax1, ax2 in zip(*axes): - repeat = next(available_ix) - x_ix[ax1] = repeat - y_ix[ax2] = repeat - - # fill in the rest, and maintain output order - for i in range(xnd): - if x_ix[i] is None: - leave = next(available_ix) - x_ix[i] = leave - out_ix.append(leave) - for i in range(ynd): - if y_ix[i] is None: - leave = next(available_ix) - y_ix[i] = leave - out_ix.append(leave) - - # form 
full string and contract! - einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix))) - return einsum(einsum_str, x, y) + # compute shifts + assert all(dim >= 0 for axis in axes for dim in axis) + x_shift = x + y_shift = y + for dim in axes[0]: + x_shift = x_shift.max(dim, keepdim=True)[0] + for dim in axes[1]: + y_shift = y_shift.max(dim, keepdim=True)[0] + + result = torch.tensordot((x - x_shift).exp(), (y - y_shift).exp(), axes).log() + + # apply shifts to result + x_part = x.dim() - len(axes[0]) + y_part = y.dim() - len(axes[1]) + assert result.dim() == x_part + y_part + result += x_shift.reshape(result.shape[:x_part] + (1,) * y_part) + result += y_shift.reshape(result.shape[x_part:]) + + return result diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 8143b3c563..3a91083cea 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -1,78 +1,83 @@ +from __future__ import absolute_import, division, print_function + import weakref + import torch import pyro import pyro.poutine as poutine +from pyro.util import ignore_jit_warnings, optional class CompiledFunction(object): """ - Output type of :func:`pyro.ops.jit.compile`. + Output type of :func:`pyro.ops.jit.trace`. - Wrapper around the output of :func:`torch.jit.compile` + Wrapper around the output of :func:`torch.jit.trace` that handles parameter plumbing. The actual PyTorch compilation artifact is stored in :attr:`compiled`. Call diagnostic methods on this attribute. """ - def __init__(self, fn, **jit_options): + def __init__(self, fn, ignore_warnings=False): self.fn = fn - self._jit_options = jit_options - self.compiled = None + self.compiled = {} # len(args) -> callable + self.ignore_warnings = ignore_warnings self._param_names = None def __call__(self, *args, **kwargs): + argc = len(args) # if first time - if self.compiled is None: + if argc not in self.compiled: # param capture with poutine.block(): with poutine.trace(param_only=True) as first_param_capture: self.fn(*args, **kwargs) self._param_names = list(set(first_param_capture.trace.nodes.keys())) - + unconstrained_params = tuple(pyro.param(name).unconstrained() + for name in self._param_names) + params_and_args = unconstrained_params + args weakself = weakref.ref(self) - @torch.jit.compile(**self._jit_options) - def compiled(unconstrained_params, *args): + def compiled(*params_and_args): self = weakself() + unconstrained_params = params_and_args[:len(self._param_names)] + args = params_and_args[len(self._param_names):] constrained_params = {} for name, unconstrained_param in zip(self._param_names, unconstrained_params): constrained_param = pyro.param(name) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param constrained_params[name] = constrained_param + return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) - return poutine.replay( - self.fn, params=constrained_params)(*args, **kwargs) - - self.compiled = compiled - - param_list = [pyro.param(name).unconstrained() - for name in self._param_names] + with pyro.validation_enabled(False), optional(ignore_jit_warnings(), self.ignore_warnings): + self.compiled[argc] = torch.jit.trace(compiled, params_and_args, check_trace=False) + else: + unconstrained_params = [pyro.param(name).unconstrained() + for name in self._param_names] + params_and_args = unconstrained_params + list(args) with poutine.block(hide=self._param_names): with poutine.trace(param_only=True) as param_capture: - ret = self.compiled(param_list, *args, **kwargs) - - new_params = filter(lambda 
name: name not in self._param_names, - param_capture.trace.nodes.keys()) + ret = self.compiled[argc](*params_and_args) - for name in new_params: - # enforce uniqueness + for name in param_capture.trace.nodes.keys(): if name not in self._param_names: - self._param_names.append(name) + raise NotImplementedError('pyro.ops.jit.trace assumes all params are created on ' + 'first invocation, but found new param: {}'.format(name)) return ret -def compile(fn=None, **jit_options): +def trace(fn=None, ignore_warnings=False): """ - Drop-in replacement for :func:`torch.jit.compile` that works with + Lazy replacement for :func:`torch.jit.trace` that works with Pyro functions that call :func:`pyro.param`. - The actual compilation artifact is stored in the ``compiled`` attribute of the output. - Call diagnostic methods on this attribute. + The actual compilation artifact is stored in the ``compiled`` attribute of + the output. Call diagnostic methods on this attribute. Example:: @@ -80,12 +85,12 @@ def model(x): scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive) return pyro.sample("y", dist.Normal(x, scale)) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace def model_log_prob_fn(x, y): cond_model = pyro.condition(model, data={"y": y}) tr = pyro.poutine.trace(cond_model).get_trace(x) return tr.log_prob_sum() """ if fn is None: - return lambda fn: compile(fn, **jit_options) - return CompiledFunction(fn, **jit_options) + return lambda fn: trace(fn, ignore_warnings=ignore_warnings) + return CompiledFunction(fn, ignore_warnings=ignore_warnings) diff --git a/pyro/ops/sumproduct.py b/pyro/ops/sumproduct.py index 2e63c80315..510abfa7ff 100644 --- a/pyro/ops/sumproduct.py +++ b/pyro/ops/sumproduct.py @@ -9,6 +9,7 @@ from pyro.distributions.util import broadcast_shape from pyro.ops.einsum import contract +from pyro.util import ignore_jit_warnings def zip_align_right(xs, ys): @@ -47,8 +48,9 @@ def sumproduct(factors, target_shape=(), optimize=True, device=None): for t in factors: (numbers if isinstance(t, Number) else tensors).append(t) if not tensors: - return torch.tensor(float(reduce(operator.mul, numbers, 1.)), - device=device).expand(target_shape) + with ignore_jit_warnings(["torch.tensor results are registered as constants"]): + return torch.tensor(float(reduce(operator.mul, numbers, 1.)), + device=device).expand(target_shape) if numbers: number_part = reduce(operator.mul, numbers, 1.) tensor_part = sumproduct(tensors, target_shape, optimize=optimize) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index e4e04d28b6..20ba466f23 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +from pyro.util import ignore_jit_warnings from .messenger import Messenger @@ -11,6 +12,7 @@ class BroadcastMessenger(Messenger): broadcastable with the size of the :class:`~pyro.iarange` contexts installed in the `cond_indep_stack`. """ + @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) def _pyro_sample(self, msg): """ :param msg: current message at a trace site. 
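The ``ignore_jit_warnings`` helper leaned on in the last few hunks (its definition appears later in this diff, in ``pyro/util.py``) is a generator-based context manager, so it can also be applied as a decorator, which is how ``BroadcastMessenger._pyro_sample`` uses it above. A brief usage sketch follows; the two helper functions are hypothetical, only the import path and the filter message come from this patch.

import torch
from pyro.util import ignore_jit_warnings

def size_as_int(size):
    # As a context manager: suppress only TracerWarnings whose message matches the filter.
    with ignore_jit_warnings(["Converting a tensor to a Python number"]):
        return size.item() if isinstance(size, torch.Tensor) else size

@ignore_jit_warnings()  # no filter: every torch.jit.TracerWarning in the body is suppressed
def first_as_float(x):
    return float(x[0])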
@@ -21,13 +23,14 @@ def _pyro_sample(self, msg): dist = msg["fn"] actual_batch_shape = getattr(dist, "batch_shape", None) if actual_batch_shape is not None: - target_batch_shape = [None if size == 1 else size for size in actual_batch_shape] + target_batch_shape = [None if size == 1 else size + for size in actual_batch_shape] for f in msg["cond_indep_stack"]: if f.dim is None or f.size == -1: continue assert f.dim < 0 target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape - if target_batch_shape[f.dim] not in (None, f.size): + if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size: raise ValueError("Shape mismatch inside iarange('{}') at site {} dim {}, {} vs {}".format( f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim])) target_batch_shape[f.dim] = f.size diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index ee8a3d832f..73449bb8d3 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -4,6 +4,7 @@ import torch +from pyro.util import ignore_jit_warnings from .messenger import Messenger @@ -13,8 +14,9 @@ def vectorized(self): return self.dim is not None def _key(self): - size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size - return self.name, self.dim, size, self.counter + with ignore_jit_warnings(["Converting a tensor to a Python number"]): + size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size + return self.name, self.dim, size, self.counter def __eq__(self, other): return type(self) == type(other) and self._key() == other._key() diff --git a/pyro/primitives.py b/pyro/primitives.py index 2a10fdcc85..963a0e2a9f 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function import copy -import numbers import warnings from collections import OrderedDict from contextlib import contextmanager @@ -15,7 +14,7 @@ from pyro.distributions.distribution import Distribution from pyro.params import param_with_module_name from pyro.poutine.runtime import _DIM_ALLOCATOR, _MODULE_NAMESPACE_DIVIDER, _PYRO_PARAM_STORE, am_i_wrapped, apply_stack -from pyro.util import deep_getattr, set_rng_seed # noqa: F401 +from pyro.util import deep_getattr, ignore_jit_warnings, torch_float, jit_compatible_arange # noqa: F401 def get_param_store(): @@ -108,7 +107,8 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None): if self.use_cuda ^ (device != "cpu"): raise ValueError("Incompatible arg values use_cuda={}, device={}." 
.format(use_cuda, device)) - self.device = torch.Tensor().device if not device else device + with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): + self.device = torch.Tensor().device if not device else device def sample(self, sample_shape=torch.Size()): """ @@ -118,10 +118,8 @@ def sample(self, sample_shape=torch.Size()): if sample_shape: raise NotImplementedError subsample_size = self.subsample_size - if subsample_size is None or subsample_size > self.size: - subsample_size = self.size - if subsample_size >= self.size: - result = torch.arange(self.size, dtype=torch.long).to(self.device) + if subsample_size is None or subsample_size >= self.size: + result = jit_compatible_arange(self.size) else: result = torch.multinomial(torch.ones(self.size), self.subsample_size, replacement=False).to(self.device) @@ -146,12 +144,13 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No elif subsample is None: subsample = sample(name, _Subsample(size, subsample_size, use_cuda=use_cuda, device=device)) - if subsample_size is None: - subsample_size = len(subsample) - elif subsample is not None and subsample_size != len(subsample): - raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + with ignore_jit_warnings(): + if subsample_size is None: + subsample_size = subsample.shape[0] if torch._C._get_tracing_state() else len(subsample) + elif subsample is not None and subsample_size != len(subsample): + raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( + subsample_size, len(subsample)) + + " Did you accidentally use different subsample_size in the model and guide?") return size, subsample_size, subsample @@ -256,7 +255,7 @@ def __enter__(self): self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) if self._wrapped: try: - self._scale_messenger = poutine.scale(scale=self.size / self.subsample_size) + self._scale_messenger = poutine.scale(scale=torch_float(self.size) / self.subsample_size) self._indep_messenger = poutine.indep(name=self.name, size=self.subsample_size, dim=self.dim) self._scale_messenger.__enter__() self._indep_messenger.__enter__() @@ -314,18 +313,19 @@ def __init__(self, name, size, subsample_size=None, subsample=None, use_cuda=Non use_cuda=use_cuda, device=device) def __iter__(self): + with ignore_jit_warnings(["Converting a tensor to a Python index", + ("Iterating over a tensor", RuntimeWarning)]): + subsample = iter(self.subsample) if not am_i_wrapped(): - for i in self.subsample: - yield i if isinstance(i, numbers.Number) else i.item() + for i in subsample: + yield i else: indep_context = poutine.indep(name=self.name, size=self.subsample_size) - with poutine.scale(scale=self.size / self.subsample_size): - for i in self.subsample: + with poutine.scale(scale=torch_float(self.size) / self.subsample_size): + for i in subsample: indep_context.next_context() with indep_context: - # convert to python numeric type as functions like torch.ones(*args) - # do not work with dim 0 torch.Tensor instances. 
- yield i if isinstance(i, numbers.Number) else i.item() + yield i # XXX this should have the same call signature as torch.Tensor constructors diff --git a/pyro/util.py b/pyro/util.py index 6b49e0d2b5..1c6a7722fc 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -5,7 +5,7 @@ import random import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib2 import contextmanager import graphviz import torch @@ -324,6 +324,30 @@ def check_if_enumerated(guide_trace): 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) +@contextmanager +def ignore_jit_warnings(filter=None): + """ + Ignore JIT tracer warnings with messages that match `filter`. If + `filter` is not specified all tracer warnings are ignored. + + :param filter: A list containing either warning message (str), + or tuple consisting of (warning message (str), Warning class). + """ + with warnings.catch_warnings(): + if filter is None: + warnings.filterwarnings("ignore", + category=torch.jit.TracerWarning) + else: + for msg in filter: + category = torch.jit.TracerWarning + if isinstance(msg, tuple): + msg, category = msg + warnings.filterwarnings("ignore", + category=category, + message=msg) + yield + + @contextmanager def optional(context_manager, condition): """ @@ -342,3 +366,13 @@ def deep_getattr(obj, name): Throws an AttributeError if bad attribute """ return functools.reduce(getattr, name.split("."), obj) + + +# work around https://github.com/pytorch/pytorch/issues/11829 +def jit_compatible_arange(end, dtype=None, device=None): + dtype = torch.long if dtype is None else dtype + return torch.cumsum(torch.ones(end, dtype=dtype, device=device), dim=0) - 1 + + +def torch_float(x): + return x.float() if isinstance(x, torch.Tensor) else float(x) diff --git a/setup.py b/setup.py index 50ec0766dc..d8f51f3261 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,8 @@ 'matplotlib>=1.3', 'observations>=0.1.4', 'pillow', - 'torchvision', + # TODO: uncomment on release; using torch-nightly build + # 'torchvision', 'visdom>=0.1.4', 'pandas', 'wget', @@ -83,7 +84,8 @@ 'numpy>=1.7', 'opt_einsum>=2.2.0', 'six>=1.10.0', - 'torch==0.4.0', + # TODO: uncomment on release; using torch-nightly build + # 'torch>=0.4.1', 'tqdm>=4.25', ], extras_require={ diff --git a/tests/contrib/oed/test_eig.py b/tests/contrib/oed/test_eig.py index 922283a914..cd2df30061 100644 --- a/tests/contrib/oed/test_eig.py +++ b/tests/contrib/oed/test_eig.py @@ -21,6 +21,7 @@ ) from pyro.contrib.oed.util import linear_model_ground_truth from pyro.infer import TraceEnum_ELBO +from tests.common import xfail_param logger = logging.getLogger(__name__) @@ -147,7 +148,7 @@ def h(p): False, 0.3 ), - T( + xfail_param(*T( basic_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, "y", @@ -157,7 +158,7 @@ def h(p): optim.Adam({"lr": 0.025}), False, None, 500], True, 0.3 - ), + ), reason="https://github.com/uber/pyro/issues/1418"), T( basic_2p_linear_model_sds_10_2pt5, AB_test_2d_10n_2p, @@ -203,7 +204,7 @@ def h(p): 0.3, marks=pytest.mark.xfail ), - T( + xfail_param(*T( group_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, "y", @@ -213,7 +214,7 @@ def h(p): optim.Adam({"lr": 0.025}), False, None, 500], True, 0.3 - ), + ), reason="https://github.com/uber/pyro/issues/1418"), T( group_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py index ef606e22e7..514c94ddad 100644 --- a/tests/contrib/oed/test_ewma.py +++ b/tests/contrib/oed/test_ewma.py @@ -8,6 +8,7 @@ 
@pytest.mark.parametrize("alpha", [0.5, 0.9, 0.99]) +@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418") def test_ewma(alpha, NS=10000, D=1): ewma_log = EwmaLog(alpha=alpha) sigma = torch.tensor(1.0, requires_grad=True) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 38e8bf8c4f..2e74ecca08 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -153,10 +153,10 @@ Fixture(pyro_dist=dist.LowRankMultivariateNormal, scipy_dist=sp.multivariate_normal, examples=[ - {'loc': [2.0, 1.0], 'D_term': [0.5, 0.5], 'W_term': [[1.0], [0.5]], + {'loc': [2.0, 1.0], 'cov_diag': [0.5, 0.5], 'cov_factor': [[1.0], [0.5]], 'test_data': [[2.0, 1.0], [9.0, 3.4]]}, ], - scipy_arg_fn=lambda loc, D_term=None, W_term=None: + scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None: ((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}), prec=0.01, min_samples=500000), diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index d997ed207f..b893afed17 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -6,7 +6,6 @@ import pyro import pyro.distributions as dist -from pyro.distributions.torch_distribution import ReshapedDistribution from pyro.distributions.util import broadcast_shape from tests.common import assert_equal, xfail_if_not_implemented @@ -130,8 +129,7 @@ def test_distribution_validate_args(dist_class, args, validate_args): def check_sample_shapes(small, large): - dist_instance = small.base_dist if isinstance(small, ReshapedDistribution) \ - else small + dist_instance = small if isinstance(dist_instance, (dist.LogNormal, dist.LowRankMultivariateNormal, dist.VonMises)): # Ignore broadcasting bug in LogNormal: # https://github.com/pytorch/pytorch/pull/7269 @@ -147,9 +145,10 @@ def check_sample_shapes(small, large): def test_expand_by(dist, sample_shape, shape_type): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand_by(shape_type(sample_shape)) - assert large.batch_shape == sample_shape + small.batch_shape - check_sample_shapes(small, large) + with xfail_if_not_implemented(): + large = small.expand_by(shape_type(sample_shape)) + assert large.batch_shape == sample_shape + small.batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)]) @@ -157,9 +156,10 @@ def test_expand_by(dist, sample_shape, shape_type): def test_expand_new_dim(dist, sample_shape, shape_type): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand(shape_type(sample_shape + small.batch_shape)) - assert large.batch_shape == sample_shape + small.batch_shape - check_sample_shapes(small, large) + with xfail_if_not_implemented(): + large = small.expand(shape_type(sample_shape + small.batch_shape)) + assert large.batch_shape == sample_shape + small.batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize('shape_type', [torch.Size, tuple, list]) @@ -174,8 +174,8 @@ def test_expand_existing_dim(dist, shape_type): batch_shape = torch.Size(batch_shape) with xfail_if_not_implemented(): large = small.expand(shape_type(batch_shape)) - assert large.batch_shape == batch_shape - check_sample_shapes(small, large) + assert large.batch_shape == batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize("sample_shapes", [ @@ -203,10 +203,11 @@ def 
test_subsequent_expands_ok(dist, sample_shapes): def test_expand_error(dist, initial_shape, proposed_shape): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand(torch.Size(initial_shape) + small.batch_shape) - proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape - with pytest.raises(ValueError): - large.expand(proposed_batch_shape) + with xfail_if_not_implemented(): + large = small.expand(torch.Size(initial_shape) + small.batch_shape) + proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape + with pytest.raises(RuntimeError): + large.expand(proposed_batch_shape) @pytest.mark.parametrize("extra_event_dims,expand_shape", [ @@ -228,19 +229,10 @@ def test_expand_reshaped_distribution(extra_event_dims, expand_shape): assert large.batch_shape == torch.Size(expand_shape) assert large.event_shape == torch.Size(event_shape) - # Change base_dist only if sample_shape cannot be adjusted. - if extra_event_dims >= 1: - assert large.base_dist == reshaped_dist.base_dist - else: - if expand_shape[-1] == 1: - assert large.base_dist == reshaped_dist.base_dist - else: - assert large.base_dist.batch_shape == torch.Size(expand_shape) - # Throws error when batch shape cannot be broadcasted - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): reshaped_dist.expand(expand_shape + [3]) # Throws error when trying to shrink existing batch shape - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): large.expand(expand_shape[1:]) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 53dce68afe..f3bf152425 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -18,6 +18,22 @@ logger = logging.getLogger(__name__) +def mark_jit(*args, **kwargs): + jit_markers = kwargs.pop("marks", []) + jit_markers += [ + pytest.mark.skipif(torch.__version__ <= "0.4.1", + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228"), + pytest.mark.skipif('CI' in os.environ, + reason='slow test') + ] + kwargs["marks"] = jit_markers + return pytest.param(*args, **kwargs) + + +def jit_idfn(param): + return "JIT={}".format(param) + + class GaussianChain(object): def __init__(self, dim, chain_len, num_obs): @@ -93,8 +109,7 @@ def rmse(t1, t2): mean_tol=0.05, std_tol=0.05, ), marks=[pytest.mark.xfail(reason="flaky"), - pytest.mark.skipif('CI' in os.environ and os.environ['CI'] == 'true', - reason='Slow test - skip on CI')]), + pytest.mark.skip(reason='Slow test')]), pytest.param(*T( GaussianChain(dim=5, chain_len=9, num_obs=1), num_samples=3000, @@ -106,8 +121,7 @@ def rmse(t1, t2): mean_tol=0.08, std_tol=0.08, ), marks=[pytest.mark.xfail(reason="flaky"), - pytest.mark.skipif('CI' in os.environ and os.environ['CI'] == 'true', - reason='Slow test - skip on CI')]) + pytest.mark.skipif(reason='Slow test')]) ] TEST_IDS = [t[0].id_fn() if type(t).__name__ == 'TestExample' @@ -154,7 +168,8 @@ def test_hmc_conjugate_gaussian(fixture, assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) -def test_logistic_regression(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
@@ -166,13 +181,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, step_size=0.0855, num_steps=4) + hmc_kernel = HMC(model, step_size=0.0855, num_steps=4, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) beta_posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, beta_posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -182,13 +198,14 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, step_size=0.02, num_steps=3) + hmc_kernel = HMC(model, step_size=0.02, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) -def test_gamma_normal(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -198,13 +215,16 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_dirichlet_categorical(): +@pytest.mark.parametrize("jit", [False, + mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + ], ids=jit_idfn) +def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) @@ -213,13 +233,14 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_logistic_regression_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression_with_dual_averaging(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
@@ -231,13 +252,15 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(posterior.mean, true_coefs).item(), 0.0, prec=0.1) -def test_beta_bernoulli_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli_with_dual_averaging(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -248,13 +271,15 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) -def test_gamma_normal_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal_with_dual_averaging(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -264,13 +289,17 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_gaussian_mixture_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True, + marks=[pytest.mark.skip("FIXME: Slow on JIT.")])], + ids=jit_idfn) +def test_gaussian_mixture_model(jit): K, N = 3, 1000 @poutine.broadcast @@ -287,15 +316,17 @@ def gmm(data): true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1) + hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, + max_iarange_nesting=1, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=300, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @pytest.mark.parametrize("use_einsum", [False, True]) -def test_bernoulli_latent_model(use_einsum): +def test_bernoulli_latent_model(jit, use_einsum): @poutine.broadcast def model(data): y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0)) @@ -311,6 +342,7 @@ def model(data): z = dist.Bernoulli(0.65 * y + 0.1).sample() data = dist.Normal(2. 
* z, 1.0).sample() hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1, + jit_compile=jit, ignore_jit_warnings=True, experimental_use_einsum=use_einsum) mcmc_run = MCMC(hmc_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 817e85e470..3b294af5ab 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -12,12 +12,30 @@ from pyro.infer.mcmc.mcmc import MCMC from pyro.infer.mcmc.nuts import NUTS import pyro.poutine as poutine +from pyro.util import ignore_jit_warnings from tests.common import assert_equal from .test_hmc import TEST_CASES, TEST_IDS, T, rmse logger = logging.getLogger(__name__) + +def mark_jit(*args, **kwargs): + jit_markers = kwargs.pop("marks", []) + jit_markers += [ + pytest.mark.skipif(torch.__version__ <= "0.4.1", + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228"), + pytest.mark.skipif('CI' in os.environ, + reason='slow test') + ] + kwargs["marks"] = jit_markers + return pytest.param(*args, **kwargs) + + +def jit_idfn(param): + return "JIT={}".format(param) + + T2 = T(*TEST_CASES[2].values)._replace(num_samples=800, warmup_steps=200) TEST_CASES[2] = pytest.param(*T2, marks=pytest.mark.skipif( 'CI' in os.environ and os.environ['CI'] == 'true', reason='Slow test - skip on CI')) @@ -68,7 +86,8 @@ def test_nuts_conjugate_gaussian(fixture, assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) -def test_logistic_regression(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
@@ -80,13 +99,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, step_size=0.0855) + nuts_kernel = NUTS(model, step_size=0.0855, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -96,13 +116,14 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, step_size=0.02) + nuts_kernel = NUTS(model, step_size=0.02, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_gamma_normal(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -112,13 +133,14 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, step_size=0.01) + nuts_kernel = NUTS(model, step_size=0.01, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_logistic_regression_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression_with_dual_averaging(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
@@ -130,13 +152,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli_with_dual_averaging(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -146,13 +169,17 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, + ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="p_latent") assert_equal(posterior.mean, true_probs, prec=0.03) -def test_dirichlet_categorical(): +@pytest.mark.parametrize("jit", [False, + mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + ], ids=jit_idfn) +def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) @@ -161,13 +188,14 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_gamma_beta(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_beta(jit): def model(data): alpha_prior = pyro.sample('alpha', dist.Gamma(concentration=1., rate=1.)) beta_prior = pyro.sample('beta', dist.Gamma(concentration=1., rate=1.)) @@ -176,13 +204,16 @@ def model(data): true_alpha = torch.tensor(5.) true_beta = torch.tensor(1.) 
data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=['alpha', 'beta']) assert_equal(posterior.mean, torch.stack([true_alpha, true_beta]), prec=0.05) -def test_gaussian_mixture_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True, + marks=[pytest.mark.skip("FIXME: Slow on JIT.")])], + ids=jit_idfn) +def test_gaussian_mixture_model(jit): K, N = 3, 1000 @poutine.broadcast @@ -199,14 +230,16 @@ def gmm(data): true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - nuts_kernel = NUTS(gmm, adapt_step_size=True, max_iarange_nesting=1) + nuts_kernel = NUTS(gmm, adapt_step_size=True, max_iarange_nesting=1, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=300, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) -def test_bernoulli_latent_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_bernoulli_latent_model(jit): @poutine.broadcast def model(data): y_prob = pyro.sample("y_prob", dist.Beta(1., 1.)) @@ -220,12 +253,14 @@ def model(data): y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * y + 0.1).sample() data = dist.Normal(2. 
* z, 1.0).sample() - nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1) + nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean assert_equal(posterior, y_prob, prec=0.05) +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @pytest.mark.parametrize("num_steps,use_einsum", [ (2, False), (3, False), @@ -234,7 +269,7 @@ def model(data): pytest.param(30, True, marks=pytest.mark.skip(reason="https://github.com/pytorch/pytorch/issues/10661")), ]) -def test_gaussian_hmm_enum_shape(num_steps, use_einsum): +def test_gaussian_hmm_enum_shape(jit, num_steps, use_einsum): dim = 4 def model(data): @@ -243,14 +278,16 @@ def model(data): emission_loc = pyro.sample("emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))) emission_scale = pyro.sample("emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))) x = None - for t, y in enumerate(data): - x = pyro.sample("x_{}".format(t), dist.Categorical(initialize if x is None else transition[x])) - pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y) - # check shape - effective_dim = sum(1 for size in x.shape if size > 1) - assert effective_dim == 1 + with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): + for t, y in enumerate(data): + x = pyro.sample("x_{}".format(t), dist.Categorical(initialize if x is None else transition[x])) + pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y) + # check shape + effective_dim = sum(1 for size in x.shape if size > 1) + assert effective_dim == 1 data = torch.ones(num_steps) nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=0, + jit_compile=jit, ignore_jit_warnings=True, experimental_use_einsum=use_einsum) MCMC(nuts_kernel, num_samples=5, warmup_steps=5).run(data) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 38efc8b3f8..feb73d8b57 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -152,7 +152,7 @@ def gmm_model(data, verbose=False): z = pyro.sample("z_{}".format(i), dist.Bernoulli(p)) z = z.long() if verbose: - logger.debug("M{} z_{} = {}".format(" " * i, i, z.cpu().numpy())) + logger.debug("M{} z_{} = {}".format(" " * int(i), int(i), z.cpu().numpy())) pyro.sample("x_{}".format(i), dist.Normal(mus[z], scale), obs=data[i]) @@ -162,7 +162,7 @@ def gmm_guide(data, verbose=False): z = pyro.sample("z_{}".format(i), dist.Bernoulli(p)) z = z.long() if verbose: - logger.debug("G{} z_{} = {}".format(" " * i, i, z.cpu().numpy())) + logger.debug("G{} z_{} = {}".format(" " * int(i), int(i), z.cpu().numpy())) @pytest.mark.parametrize("data_size", [1, 2, 3]) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 1059fa2f95..2248f1b375 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +import warnings import logging import pytest @@ -9,12 +10,19 @@ import pyro import pyro.distributions as dist -from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, - Trace_ELBO, TraceEnum_ELBO, +import pyro.ops.jit +import pyro.poutine as poutine +from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam from 
pyro.poutine.indep_messenger import CondIndepStackFrame -from tests.common import assert_equal, xfail_param +from tests.common import assert_equal + + +def constant(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + return torch.tensor(*args, **kwargs) logger = logging.getLogger(__name__) @@ -23,71 +31,212 @@ def test_simple(): y = torch.ones(2) - @torch.jit.compile(nderivs=0) def f(x): logger.debug('Inside f') - assert x is y + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y return y + 1.0 + logger.debug('Compiling f') + f = torch.jit.trace(f, (y,), check_trace=False) logger.debug('Calling f(y)') - assert_equal(f(y), y.new_tensor([2, 2])) + assert_equal(f(y), y.new_tensor([2., 2.])) logger.debug('Calling f(y)') - assert_equal(f(y), y.new_tensor([2, 2])) + assert_equal(f(y), y.new_tensor([2., 2.])) logger.debug('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2)), y.new_tensor([1, 1])) - with pytest.raises(AssertionError): - assert_equal(f(torch.ones(5)), y.new_tensor([2, 2, 2, 2, 2])) + assert_equal(f(torch.zeros(2)), y.new_tensor([1., 1.])) + logger.debug('Calling f(torch.zeros(5))') + assert_equal(f(torch.ones(5)), y.new_tensor([2., 2., 2., 2., 2.])) + + +def test_multi_output(): + y = torch.ones(2) + + def f(x): + logger.debug('Inside f') + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y + return y - 1.0, y + 1.0 + + logger.debug('Compiling f') + f = torch.jit.trace(f, (y,), check_trace=False) + logger.debug('Calling f(y)') + assert_equal(f(y)[1], y.new_tensor([2., 2.])) + logger.debug('Calling f(y)') + assert_equal(f(y)[1], y.new_tensor([2., 2.])) + logger.debug('Calling f(torch.zeros(2))') + assert_equal(f(torch.zeros(2))[1], y.new_tensor([1., 1.])) + logger.debug('Calling f(torch.zeros(5))') + assert_equal(f(torch.ones(5))[1], y.new_tensor([2., 2., 2., 2., 2.])) def test_backward(): y = torch.ones(2, requires_grad=True) - @torch.jit.compile(nderivs=1) def f(x): logger.debug('Inside f') - assert x is y + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y return (y + 1.0).sum() + logger.debug('Compiling f') + f = torch.jit.trace(f, (y,), check_trace=False) logger.debug('Calling f(y)') f(y).backward() logger.debug('Calling f(y)') f(y) logger.debug('Calling f(torch.zeros(2))') f(torch.zeros(2, requires_grad=True)) - with pytest.raises(AssertionError): - f(torch.ones(5, requires_grad=True)) + logger.debug('Calling f(torch.zeros(5))') + f(torch.ones(5, requires_grad=True)) +@pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad(): - @torch.jit.compile(nderivs=0) def f(x, y): logger.debug('Inside f') loss = (x - y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) + logger.debug('Compiling f') + f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True))) logger.debug('Invoking f') f(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) logger.debug('Invoking f') f(torch.zeros(2, requires_grad=True), torch.zeros(2, requires_grad=True)) -@pytest.mark.xfail(reason='RuntimeError: ' - 'saved_variables() needed but not implemented in ExpandBackward') +@pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad_expand(): - @torch.jit.compile(nderivs=0) def f(x, y): logger.debug('Inside f') loss = (x 
- y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) + logger.debug('Compiling f') + f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True))) logger.debug('Invoking f') f(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) logger.debug('Invoking f') f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11555") +def test_masked_fill(): + + def f(y, mask): + return y.clone().masked_fill_(mask, 0.) + + x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + y = x / x.unsqueeze(-1) + mask = ~(y == y) + f = torch.jit.trace(f, (y, mask)) + + +def test_masked_fill_workaround(): + + def f(y, mask): + return y.clone().masked_fill_(mask, 0.) + + def g(y, mask): + y = y.clone() + y[mask] = 0. # this is much slower than .masked_fill_() + return y + + x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + y = x / x.unsqueeze(-1) + mask = ~(y == y) + assert_equal(f(y, mask), g(y, mask)) + g = torch.jit.trace(g, (y, mask)) + assert_equal(f(y, mask), g(y, mask)) + + +@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11614") +def test_scatter(): + + def make_one_hot(x, i): + return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) + + x = torch.randn(5, 4, 3) + i = torch.randint(0, 3, torch.Size((5, 4))) + torch.jit.trace(make_one_hot, (x, i)) + + +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') +def test_scatter_workaround(): + + def make_one_hot_expected(x, i): + return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) + + def make_one_hot_actual(x, i): + eye = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device) + return eye[i].clone() + + x = torch.randn(5, 4, 3) + i = torch.randint(0, 3, torch.Size((5, 4))) + torch.jit.trace(make_one_hot_actual, (x, i)) + expected = make_one_hot_expected(x, i) + actual = make_one_hot_actual(x, i) + assert_equal(actual, expected) + + +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python boolean') +def test_bernoulli_enumerate(shape, expand): + shape = torch.Size(shape) + probs = torch.empty(shape).fill_(0.25) + + @pyro.ops.jit.trace + def f(probs): + d = dist.Bernoulli(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + assert log_prob.shape == (2,) + shape + + +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +def test_categorical_enumerate(shape, expand): + shape = torch.Size(shape) + probs = torch.ones(shape) + + @pyro.ops.jit.trace + def f(probs): + d = dist.Categorical(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + batch_shape = shape[:-1] + assert log_prob.shape == shape[-1:] + batch_shape + + +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') +def test_one_hot_categorical_enumerate(shape, expand): + shape = torch.Size(shape) + probs = torch.ones(shape) + + @pyro.ops.jit.trace + def f(probs): + d = dist.OneHotCategorical(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + batch_shape = shape[:-1] + assert log_prob.shape == 
shape[-1:] + batch_shape + + @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, @@ -99,11 +248,11 @@ def f(x, y): ]) def test_svi(Elbo, num_particles): pyro.clear_param_store() - data = torch.arange(10) + data = torch.arange(10.) def model(data): - loc = pyro.param("loc", torch.tensor(0.0)) - scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) + loc = pyro.param("loc", constant(0.0)) + scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) pyro.sample("x", dist.Normal(loc, scale).expand_by(data.shape).independent(1), obs=data) def guide(data): @@ -118,18 +267,11 @@ def guide(data): @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) -@pytest.mark.parametrize('Elbo', [ - Trace_ELBO, - JitTrace_ELBO, - TraceGraph_ELBO, - JitTraceGraph_ELBO, - TraceEnum_ELBO, - JitTraceEnum_ELBO, -]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2): pyro.clear_param_store() num_particles = 10 - q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval) + q = pyro.param("q", constant(0.75), constraint=constraints.unit_interval) p = 0.2693204236205713 # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5 def model(): @@ -149,9 +291,10 @@ def guide(): inner_particles = 2 outer_particles = num_particles // inner_particles - elbo = TraceEnum_ELBO(max_iarange_nesting=0, - strict_enumeration_warning=any([enumerate1, enumerate2]), - num_particles=inner_particles) + elbo = Elbo(max_iarange_nesting=0, + strict_enumeration_warning=any([enumerate1, enumerate2]), + num_particles=inner_particles, + ignore_jit_warnings=True) actual_loss = sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles)) / outer_particles actual_grad = q.unconstrained().grad / outer_particles @@ -167,25 +310,21 @@ def guide(): @pytest.mark.parametrize('vectorized', [False, True]) -@pytest.mark.parametrize('Elbo', [ - TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, - reason="jit RuntimeError: Unsupported op descriptor: stack-2-dim_i"), -]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_beta_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) def model1(data): - alpha0 = torch.tensor(10.0) - beta0 = torch.tensor(10.0) + alpha0 = constant(10.0) + beta0 = constant(10.0) f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) for i in pyro.irange("irange", len(data)): pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i]) def model2(data): - alpha0 = torch.tensor(10.0) - beta0 = torch.tensor(10.0) + alpha0 = constant(10.0) + beta0 = constant(10.0) f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1), obs=data) @@ -193,36 +332,63 @@ def model2(data): model = model2 if vectorized else model1 def guide(data): - alpha_q = pyro.param("alpha_q", torch.tensor(15.0), + alpha_q = pyro.param("alpha_q", constant(15.0), constraint=constraints.positive) - beta_q = pyro.param("beta_q", torch.tensor(15.0), + beta_q = pyro.param("beta_q", constant(15.0), constraint=constraints.positive) pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False) + elbo = Elbo(num_particles=7, 
strict_enumeration_warning=False, ignore_jit_warnings=True) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): svi.step(data) -@pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [ + Trace_ELBO, + JitTrace_ELBO, + TraceGraph_ELBO, + JitTraceGraph_ELBO, TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, reason="jit RuntimeError in Dirichlet.rsample"), + JitTraceEnum_ELBO, ]) +def test_svi_irregular_batch_size(Elbo): + pyro.clear_param_store() + + @poutine.broadcast + def model(data): + loc = pyro.param("loc", constant(0.0)) + scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) + with pyro.iarange("data", data.shape[0]): + pyro.sample("x", + dist.Normal(loc, scale).expand([data.shape[0]]), + obs=data) + + def guide(data): + pass + + pyro.clear_param_store() + elbo = Elbo(strict_enumeration_warning=False, max_iarange_nesting=1) + inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) + inference.step(torch.ones(10)) + inference.step(torch.ones(3)) + + +@pytest.mark.parametrize('vectorized', [False, True]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_dirichlet_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) def model1(data): - concentration0 = torch.tensor([10.0, 10.0]) + concentration0 = constant([10.0, 10.0]) f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1] for i in pyro.irange("irange", len(data)): pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i]) def model2(data): - concentration0 = torch.tensor([10.0, 10.0]) + concentration0 = constant([10.0, 10.0]) f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1] pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1), obs=data) @@ -230,11 +396,11 @@ def model2(data): model = model2 if vectorized else model1 def guide(data): - concentration_q = pyro.param("concentration_q", torch.tensor([15.0, 15.0]), + concentration_q = pyro.param("concentration_q", constant([15.0, 15.0]), constraint=constraints.positive) pyro.sample("latent_fairness", dist.Dirichlet(concentration_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False) + elbo = Elbo(num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): @@ -249,3 +415,12 @@ def test_cond_indep_equality(x, y): assert x == y assert not x != y assert hash(x) == hash(y) + + +def test_jit_arange_workaround(): + def fn(x): + y = torch.ones(x.shape[0], dtype=torch.long, device=x.device) + return torch.cumsum(y, 0) - 1 + + compiled = torch.jit.trace(fn, torch.ones(3)) + assert_equal(compiled(torch.ones(10)), torch.arange(10)) diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 7b2f60f744..2a969a98f3 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -66,18 +66,8 @@ def model(): return lambda_latent def guide(): - alpha_q_log = pyro.param( - "alpha_q_log", - torch.tensor( - log_alpha_n.data + - 0.17, - requires_grad=True)) - beta_q_log = pyro.param( - "beta_q_log", - torch.tensor( - log_beta_n.data - - 0.143, - requires_grad=True)) + alpha_q_log = pyro.param("alpha_q_log", log_alpha_n + 0.17) + beta_q_log = pyro.param("beta_q_log", log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", 
Gamma(alpha_q, beta_q)) diff --git a/tests/test_examples.py b/tests/test_examples.py index dd3e93ace5..8c5ada06ca 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -66,26 +66,34 @@ 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', ] + +def xfail_jit(*args): + return pytest.param(*args, marks=[pytest.mark.xfail(reason="not jittable"), + pytest.mark.skipif('CI' in os.environ, reason='slow test')]) + + JIT_EXAMPLES = [ - 'air/main.py --num-steps=1 --jit', - 'bayesian_regression.py --num-epochs=1 --jit', - 'contrib/autoname/mixture.py --num-epochs=1 --jit', - 'dmm/dmm.py --num-epochs=1 --jit', - 'dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit', - 'eight_schools/svi.py --num-epochs=1 --jit', - 'examples/contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=1 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=2 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=3 --jit', - 'hmm.py --num-steps=1 --truncate=10 --model=4 --jit', - 'lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --jit', - 'vae/vae.py --num-epochs=1 --jit', - 'vae/vae_comparison.py --num-epochs=1 --jit', - 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', + xfail_jit('air/main.py --num-steps=1 --jit'), + xfail_jit('baseball.py --num-samples=200 --warmup-steps=100 --jit'), + xfail_jit('bayesian_regression.py --num-epochs=1 --jit'), + xfail_jit('contrib/autoname/mixture.py --num-epochs=1 --jit'), + xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), + xfail_jit('dmm/dmm.py --num-epochs=1 --jit'), + xfail_jit('dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit'), + xfail_jit('eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit'), + xfail_jit('eight_schools/svi.py --num-epochs=1 --jit'), + xfail_jit('examples/contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), + xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=1 --jit'), + xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=2 --jit'), + xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=3 --jit'), + xfail_jit('hmm.py --num-steps=1 --truncate=10 --model=4 --jit'), + xfail_jit('lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --jit'), + xfail_jit('vae/vae.py --num-epochs=1 --jit'), + xfail_jit('vae/vae_comparison.py --num-epochs=1 --jit'), ] @@ -129,8 +137,6 @@ def test_cuda(example): check_call([sys.executable, filename] + args) -@pytest.mark.skipif('CI' in os.environ, reason='slow test') -@pytest.mark.xfail(reason='not jittable') @pytest.mark.parametrize('example', JIT_EXAMPLES) def test_jit(example): logger.info('Running:\npython examples/{}'.format(example))
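
A minimal sketch (not part of the diff) of the torch.jit.trace pattern that the updated tests in tests/infer/test_jit.py rely on, mirroring test_simple: a plain Python function is traced once against example inputs and the traced callable is then re-invoked, including on inputs of a different shape. The names f, example and traced_f are illustrative only; check_trace=False is passed for the same reason as in the tests, to skip the post-trace verification re-run.

# Sketch only; assumes the torch_nightly build installed in .travis.yml.
import torch


def f(x):
    # A shape-agnostic elementwise op, like the body of test_simple above.
    return x + 1.0


# Trace once against example inputs. check_trace=False skips the post-trace
# verification re-run, matching how the tests call torch.jit.trace.
example = torch.ones(2)
traced_f = torch.jit.trace(f, (example,), check_trace=False)

# The traced callable is reusable, even on inputs of a different shape,
# because the recorded graph contains only the elementwise add.
print(traced_f(torch.zeros(2)))  # tensor([1., 1.])
print(traced_f(torch.ones(5)))   # tensor([2., 2., 2., 2., 2.])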