From 4c6d9b78afe7e3755cea153b660864aa585aa178 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 26 Jul 2018 17:44:12 -0700 Subject: [PATCH 001/157] Fix PyTorch 0.4.1 errors --- pyro/contrib/gp/likelihoods/binary.py | 4 ++-- tests/distributions/conftest.py | 8 ++++---- tests/distributions/test_delta.py | 14 +++++++------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyro/contrib/gp/likelihoods/binary.py b/pyro/contrib/gp/likelihoods/binary.py index 897c15e371..aa72b1a536 100644 --- a/pyro/contrib/gp/likelihoods/binary.py +++ b/pyro/contrib/gp/likelihoods/binary.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, division, print_function -import torch.nn.functional as F +import torch import pyro import pyro.distributions as dist @@ -23,7 +23,7 @@ class Binary(Likelihood): def __init__(self, response_function=None, name="Binary"): super(Binary, self).__init__(name) self.response_function = (response_function if response_function is not None - else F.sigmoid) + else torch.sigmoid) def forward(self, f_loc, f_var, y=None): r""" diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 714061ccb6..38e8bf8c4f 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -207,14 +207,14 @@ Fixture(pyro_dist=dist.HalfCauchy, scipy_dist=sp.halfcauchy, examples=[ - {'loc': [0.5], 'scale': [1.2], + {'scale': [1.2], 'test_data': [1.0]}, - {'loc': [0.5, -1.5], 'scale': [1.2, 1.2], + {'scale': [1.2, 1.2], 'test_data': [[1.0, -1.0], [1.0, -1.0]]}, - {'loc': [[0.5], [0.3]], 'scale': [[1.2], [1.0]], + {'scale': [[1.2], [1.0]], 'test_data': [[0.54], [0.35]]} ], - scipy_arg_fn=lambda loc, scale: ((), {"loc": np.array(loc), "scale": np.array(scale)})), + scipy_arg_fn=lambda scale: ((), {"scale": np.array(scale)})), Fixture(pyro_dist=dist.VonMises, scipy_dist=sp.vonmises, examples=[ diff --git a/tests/distributions/test_delta.py b/tests/distributions/test_delta.py index 08a6f016bc..3d178ef788 100644 --- a/tests/distributions/test_delta.py +++ b/tests/distributions/test_delta.py @@ -16,13 +16,13 @@ def setUp(self): self.vs = torch.tensor([[0.0], [1.0], [2.0], [3.0]]) self.vs_expanded = self.vs.expand(4, 3) self.test_data = torch.tensor([[3.0], [3.0], [3.0]]) - self.batch_test_data_1 = torch.arange(0, 4).unsqueeze(1).expand(4, 3) - self.batch_test_data_2 = torch.arange(4, 8).unsqueeze(1).expand(4, 3) - self.batch_test_data_3 = torch.Tensor([[3], [3], [3], [3]]) - self.expected_support = [[[0], [1], [2], [3]]] - self.expected_support_non_vec = [[3]] - self.analytic_mean = 3 - self.analytic_var = 0 + self.batch_test_data_1 = torch.arange(0., 4.).unsqueeze(1).expand(4, 3) + self.batch_test_data_2 = torch.arange(4., 8.).unsqueeze(1).expand(4, 3) + self.batch_test_data_3 = torch.Tensor([[3.], [3.], [3.], [3.]]) + self.expected_support = [[[0.], [1.], [2.], [3.]]] + self.expected_support_non_vec = [[3.]] + self.analytic_mean = 3. + self.analytic_var = 0. 
self.n_samples = 10 def test_log_prob_sum(self): From cd945f3aa0d838e68e481732b5b8667fe8077dcc Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 26 Jul 2018 17:49:06 -0700 Subject: [PATCH 002/157] fix tests for pytorch 0.4.1 --- .travis.yml | 4 ++-- pyro/contrib/tracking/hashing.py | 7 ++++--- setup.py | 2 +- tests/infer/mcmc/test_hmc.py | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index a686aa24d2..9b367e2a5b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,9 +9,9 @@ env: install: - pip install -U pip - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then - pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl; + pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; else - pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl; + pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi - pip install .[test] - pip freeze diff --git a/pyro/contrib/tracking/hashing.py b/pyro/contrib/tracking/hashing.py index ad71549a7e..a9105eefc5 100644 --- a/pyro/contrib/tracking/hashing.py +++ b/pyro/contrib/tracking/hashing.py @@ -30,12 +30,13 @@ class LSH(object): >>> lsh.add('a', a) >>> lsh.add('b', b) >>> lsh.add('c', c) - >>> lsh.nearby('a') # even though c is within 2radius of a + >>> # even though c is within 2radius of a + >>> lsh.nearby('a') # doctest: +SKIP set(['b']) - >>> lsh.nearby('b') + >>> lsh.nearby('b') # doctest: +SKIP set(['a', 'c']) >>> lsh.remove('b') - >>> lsh.nearby('a') + >>> lsh.nearby('a') # doctest: +SKIP set([]) diff --git a/setup.py b/setup.py index a5615d49f1..1980142a37 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,7 @@ 'networkx>=2.0.0', 'numpy>=1.7', 'six>=1.10.0', - 'torch>=0.4.0', + 'torch>=0.4.1', ], extras_require={ 'extras': EXTRAS_REQUIRE, diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 27edbe2fe9..3888afbead 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -155,8 +155,8 @@ def test_hmc_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 - true_coefs = torch.arange(1, dim+1) data = torch.randn(2000, dim) + true_coefs = torch.arange(1, dim+1).type(data.type()) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -220,8 +220,8 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 - true_coefs = torch.arange(1, dim+1) data = torch.randn(2000, dim) + true_coefs = torch.arange(1, dim+1).type(data.type()) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): From 4499b9cd7f539201bbded350a2436612dadf5bdf Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 26 Jul 2018 18:03:59 -0700 Subject: [PATCH 003/157] fix test_mask --- tests/distributions/test_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributions/test_mask.py b/tests/distributions/test_mask.py index f93b45f556..33e3944673 100644 --- a/tests/distributions/test_mask.py +++ b/tests/distributions/test_mask.py @@ -11,7 +11,7 @@ def checker_mask(shape): mask = tensor(0.) 
for size in shape: - mask = mask.unsqueeze(-1) + torch.arange(size) + mask = mask.unsqueeze(-1) + torch.arange(size).type(mask.type()) return mask.fmod(2) From f707101047f469ccb949d6b5867d4133909e398e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 26 Jul 2018 18:11:29 -0700 Subject: [PATCH 004/157] skip JIT tests --- tests/infer/test_jit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 468c70276f..633823b13c 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -13,6 +13,8 @@ from pyro.optim import Adam from tests.common import assert_equal, xfail_param +pytestmark = pytest.mark.skip(reason="Requires update - https://github.com/uber/pyro/issues/1063") + def test_simple(): y = torch.ones(2) @@ -82,6 +84,7 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +@pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1063") @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, From 4737e1041073501d42e05a8694df35b09253098f Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 26 Jul 2018 18:13:56 -0700 Subject: [PATCH 005/157] fix NUTS tests --- tests/infer/mcmc/test_nuts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 3ed8975aaf..135c146377 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -69,8 +69,8 @@ def test_nuts_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 - true_coefs = torch.arange(1, dim+1) data = torch.randn(2000, dim) + true_coefs = torch.arange(1, dim+1).type(data.type()) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -119,8 +119,8 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 - true_coefs = torch.arange(1, dim+1) data = torch.randn(2000, dim) + true_coefs = torch.arange(1, dim+1).type(data.type()) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): From 7dc2bc1ef8fa2ff0053a58258dff99096987e01e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 26 Jul 2018 18:20:01 -0700 Subject: [PATCH 006/157] fix gaussian_scale_mixture --- pyro/distributions/gaussian_scale_mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/distributions/gaussian_scale_mixture.py b/pyro/distributions/gaussian_scale_mixture.py index 7a588bcc4a..65c826f950 100644 --- a/pyro/distributions/gaussian_scale_mixture.py +++ b/pyro/distributions/gaussian_scale_mixture.py @@ -127,7 +127,7 @@ def backward(ctx, grad_output): q_tot = (pis * q_j).sum(-1, keepdim=True) # l Phi_j = torch.exp(-0.5 * r_sqr_j) # l j - exponents = - torch.arange(1, int(dim/2) + 1, 1) + exponents = - torch.arange(1, int(dim/2) + 1, 1).type(grad_output.type()) if z.dim() > 1: r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim/2)) # l j d/2 else: From b58ef93f2800b93ea0d37cf76cc2f4e0b1f63460 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 26 Jul 2018 18:27:16 -0700 Subject: [PATCH 007/157] Mark test_em.py failures as xfail --- tests/contrib/tracking/test_em.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py index c004a4553f..944a19ebdd 100644 --- a/tests/contrib/tracking/test_em.py +++ b/tests/contrib/tracking/test_em.py @@ -13,6 +13,7 @@ from pyro.infer import SVI, 
TraceEnum_ELBO from pyro.optim import Adam from pyro.optim.multi import MixedMultiOptimizer, Newton +from tests.common import xfail_param def make_args(): @@ -120,7 +121,10 @@ def generate_data(args): return detections -@pytest.mark.parametrize('assignment_grad', [False, True]) +@pytest.mark.parametrize('assignment_grad', [ + False, + xfail_param(True, reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor"), +]) def test_em(assignment_grad): args = make_args() args.assignment_grad = assignment_grad @@ -142,7 +146,10 @@ def test_em(assignment_grad): print('step {}, loss = {}'.format(step, loss.item())) -@pytest.mark.parametrize('assignment_grad', [False, True]) +@pytest.mark.parametrize('assignment_grad', [ + False, + xfail_param(True, reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor"), +]) def test_em_nested_in_svi(assignment_grad): args = make_args() args.assignment_grad = assignment_grad @@ -173,6 +180,7 @@ def test_em_nested_in_svi(assignment_grad): svi_step, loss, pyro.param('noise_scale').item())) +@pytest.mark.xfail(reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor") def test_svi_multi(): args = make_args() args.assignment_grad = True From 448a760355c3e5590f009feee254959541bc1850 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 26 Jul 2018 18:36:30 -0700 Subject: [PATCH 008/157] Allow inf values in assert_tensors_equal --- tests/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/common.py b/tests/common.py index 80c5089ec6..c7043aedd8 100644 --- a/tests/common.py +++ b/tests/common.py @@ -117,6 +117,7 @@ def assert_tensors_equal(a, b, prec=1e-5, msg=''): nan_mask = a != a assert torch.equal(nan_mask, b != b), msg diff = a - b + diff[a == b] = 0 # handle inf diff[nan_mask] = 0 if diff.is_signed(): diff = diff.abs() From 62f21d57666481020717f8a1a6bb4b249e0f0fde Mon Sep 17 00:00:00 2001 From: jpchen Date: Thu, 26 Jul 2018 23:00:38 -0700 Subject: [PATCH 009/157] update examples --- examples/air/modules.py | 6 +++--- examples/dmm/dmm.py | 6 ++---- examples/vae/vae.py | 3 +-- examples/vae/vae_comparison.py | 3 +-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/air/modules.py b/examples/air/modules.py index f3edc703fa..843e425864 100644 --- a/examples/air/modules.py +++ b/examples/air/modules.py @@ -1,5 +1,5 @@ import torch.nn as nn -from torch.nn.functional import sigmoid, softplus +from torch.nn.functional import softplus # Takes pixel intensities of the attention window to parameters (mean, @@ -29,7 +29,7 @@ def forward(self, z): a = self.mlp(z) if self.bias is not None: a = a + self.bias - return sigmoid(a) if self.use_sigmoid else a + return torch.sigmoid(a) if self.use_sigmoid else a # A general purpose module to construct networks that look like: @@ -68,7 +68,7 @@ def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_la def forward(self, h): out = self.mlp(h) - z_pres_p = sigmoid(out[:, 0:self.z_pres_size]) + z_pres_p = torch.sigmoid(out[:, 0:self.z_pres_size]) z_where_loc = out[:, self.z_pres_size:self.z_pres_size + self.z_where_size] z_where_scale = softplus(out[:, (self.z_pres_size + self.z_where_size):]) return z_pres_p, z_where_loc, z_where_scale diff --git a/examples/dmm/dmm.py b/examples/dmm/dmm.py index 6c17a896b5..a2c2af1f43 100644 --- a/examples/dmm/dmm.py +++ b/examples/dmm/dmm.py @@ -43,7 +43,6 @@ def __init__(self, input_dim, z_dim, emission_dim): self.lin_hidden_to_input = nn.Linear(emission_dim, input_dim) # initialize the two non-linearities 
used in the neural network self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() def forward(self, z_t): """ @@ -52,7 +51,7 @@ def forward(self, z_t): """ h1 = self.relu(self.lin_z_to_hidden(z_t)) h2 = self.relu(self.lin_hidden_to_hidden(h1)) - ps = self.sigmoid(self.lin_hidden_to_input(h2)) + ps = torch.sigmoid(self.lin_hidden_to_input(h2)) return ps @@ -76,7 +75,6 @@ def __init__(self, z_dim, transition_dim): self.lin_z_to_loc.bias.data = torch.zeros(z_dim) # initialize the three non-linearities used in the neural network self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() def forward(self, z_t_1): @@ -87,7 +85,7 @@ def forward(self, z_t_1): """ # compute the gating function _gate = self.relu(self.lin_gate_z_to_hidden(z_t_1)) - gate = self.sigmoid(self.lin_gate_hidden_to_z(_gate)) + gate = torch.sigmoid(self.lin_gate_hidden_to_z(_gate)) # compute the 'proposed mean' _proposed_mean = self.relu(self.lin_proposed_mean_z_to_hidden(z_t_1)) proposed_mean = self.lin_proposed_mean_hidden_to_z(_proposed_mean) diff --git a/examples/vae/vae.py b/examples/vae/vae.py index ddc735780f..f1189727f8 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -49,7 +49,6 @@ def __init__(self, z_dim, hidden_dim): self.fc21 = nn.Linear(hidden_dim, 784) # setup the non-linearities self.softplus = nn.Softplus() - self.sigmoid = nn.Sigmoid() def forward(self, z): # define the forward computation on the latent z @@ -57,7 +56,7 @@ def forward(self, z): hidden = self.softplus(self.fc1(z)) # return the parameter for the output Bernoulli # each is of size batch_size x 784 - loc_img = self.sigmoid(self.fc21(hidden)) + loc_img = torch.sigmoid(self.fc21(hidden)) return loc_img diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index 31696bcbc2..a61742137c 100644 --- a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -50,12 +50,11 @@ def __init__(self): super(Decoder, self).__init__() self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) - self.sigmoid = nn.Sigmoid() self.relu = nn.ReLU() def forward(self, z): h3 = self.relu(self.fc3(z)) - return self.sigmoid(self.fc4(h3)) + return torch.sigmoid(self.fc4(h3)) @add_metaclass(ABCMeta) From 814e025cab7037015145bc66c9f7383ba26780b8 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 27 Jul 2018 17:54:04 -0700 Subject: [PATCH 010/157] remove redundant xfail --- tests/infer/test_jit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 633823b13c..c3c1c2f147 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -84,7 +84,6 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) -@pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1063") @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, From e1cd9b4acea37d93389e80b12163f2036ae9069e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 30 Jul 2018 11:46:33 -0700 Subject: [PATCH 011/157] Update JIT usage to PyTorch 0.4.1 (#1276) * Update JIT usage to PyTorch 0.4.1 * Add _broadcast_all() patch for jit exploration * Uncomment debug code --- docs/source/primitives.rst | 2 +- examples/air/modules.py | 1 + pyro/distributions/torch_patch.py | 49 ++++++++++++++++++++++++++++ pyro/infer/trace_elbo.py | 4 +-- pyro/infer/traceenum_elbo.py | 2 +- pyro/infer/tracegraph_elbo.py | 6 ++-- pyro/ops/jit.py | 54 +++++++++++++++---------------- tests/infer/test_gradient.py | 6 ++-- 
tests/infer/test_jit.py | 43 +++++++++++------------- 9 files changed, 104 insertions(+), 63 deletions(-) diff --git a/docs/source/primitives.rst b/docs/source/primitives.rst index fd47624400..bc7100304f 100644 --- a/docs/source/primitives.rst +++ b/docs/source/primitives.rst @@ -20,4 +20,4 @@ Primitives .. autofunction:: pyro.validation_enabled .. autofunction:: pyro.enable_validation -.. autofunction:: pyro.ops.jit.compile +.. autofunction:: pyro.ops.jit.trace diff --git a/examples/air/modules.py b/examples/air/modules.py index 843e425864..4ddc1f9ebd 100644 --- a/examples/air/modules.py +++ b/examples/air/modules.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from torch.nn.functional import softplus diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 119d866159..2fba7357e2 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +from numbers import Number + import torch @@ -45,4 +47,51 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) +# This version of broadcast_all() is compatible with early versions of the PyTorch jit, +# since it avoids torch._C._infer_size(). However it is more expensive since it infers +# size by summing the tensors. It is mainly useful for working around one jit limitation +# to discovering additional jit limitations. +# +# To temporarily apply this patch, uncomment one or more of the decorators: +# +# @_patch('torch.distributions.beta.broadcast_all') +# @_patch('torch.distributions.dirichlet.broadcast_all') +# @_patch('torch.distributions.normal.broadcast_all') +# @_patch('torch.distributions.utils.broadcast_all') +def _broadcast_all(*values): + r""" + Given a list of values (possibly containing numbers), returns a list where each + value is broadcasted based on the following rules: + - `torch.*Tensor` instances are broadcasted as per the `broadcasting rules + `_ + - numbers.Number instances (scalars) are upcast to tensors having + the same size and type as the first tensor passed to `values`. If all the + values are scalars, then they are upcasted to Tensors having size + `(1,)`. 
+ + Args: + values (list of `numbers.Number` or `torch.*Tensor`) + + Raises: + ValueError: if any of the values is not a `numbers.Number` or + `torch.*Tensor` instance + """ + values = list(values) + scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)] + tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor'] + if len(scalar_idxs) + len(tensor_idxs) != len(values): + raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.') + if tensor_idxs: + broadcast_shape = sum(values).size() # expensive alternative to torch._C._infer_size() + for idx in tensor_idxs: + values[idx] = values[idx].expand(broadcast_shape) + template = values[tensor_idxs[0]] + for idx in scalar_idxs: + values[idx] = template.new(template.size()).fill_(values[idx]) + else: + for idx in scalar_idxs: + values[idx] = torch.tensor(float(values[idx])) + return values + + __all__ = [] diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py index f26537acbd..b1e4147029 100644 --- a/pyro/infer/trace_elbo.py +++ b/pyro/infer/trace_elbo.py @@ -160,7 +160,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 @@ -200,7 +200,7 @@ def loss_and_surrogate_loss(*args): # invoke _loss_and_surrogate_loss loss, surrogate_loss = self._loss_and_surrogate_loss(*args) - surrogate_loss.backward() # this line triggers jit compilation + surrogate_loss.backward() loss = loss.item() warn_if_nan(loss, "loss") diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 6efff0ff42..feb621b9ce 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -191,7 +191,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace def differentiable_loss(*args): self = weakself() elbo = 0.0 diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index f3b437b93b..867247e7ee 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -256,7 +256,7 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace): class JitTraceGraph_ELBO(TraceGraph_ELBO): """ - Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.compile` to + Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.trace` to compile :meth:`loss_and_grads`. This works only for a limited set of models: @@ -276,7 +276,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 @@ -304,7 +304,7 @@ def loss_and_surrogate_loss(*args): self._loss_and_surrogate_loss = loss_and_surrogate_loss loss, surrogate_loss = self._loss_and_surrogate_loss(*args) - surrogate_loss.backward() # this line triggers jit compilation + surrogate_loss.backward() loss = loss.item() warn_if_nan(loss, "loss") diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 8143b3c563..0fb2127301 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -7,22 +7,20 @@ class CompiledFunction(object): """ - Output type of :func:`pyro.ops.jit.compile`. + Output type of :func:`pyro.ops.jit.trace`. 
- Wrapper around the output of :func:`torch.jit.compile` + Wrapper around the output of :func:`torch.jit.trace` that handles parameter plumbing. The actual PyTorch compilation artifact is stored in :attr:`compiled`. Call diagnostic methods on this attribute. """ - def __init__(self, fn, **jit_options): + def __init__(self, fn): self.fn = fn - self._jit_options = jit_options self.compiled = None self._param_names = None def __call__(self, *args, **kwargs): - # if first time if self.compiled is None: # param capture @@ -31,48 +29,48 @@ def __call__(self, *args, **kwargs): self.fn(*args, **kwargs) self._param_names = list(set(first_param_capture.trace.nodes.keys())) - + unconstrained_params = [pyro.param(name).unconstrained() + for name in self._param_names] + params_and_args = unconstrained_params + list(args) weakself = weakref.ref(self) - @torch.jit.compile(**self._jit_options) - def compiled(unconstrained_params, *args): + @torch.jit.trace(*params_and_args) + def compiled(*params_and_args): self = weakself() + unconstrained_params = params_and_args[:len(self._param_names)] + args = params_and_args[len(self._param_names):] constrained_params = {} for name, unconstrained_param in zip(self._param_names, unconstrained_params): constrained_param = pyro.param(name) # assume param has been initialized assert constrained_param.unconstrained() is unconstrained_param constrained_params[name] = constrained_param - - return poutine.replay( - self.fn, params=constrained_params)(*args, **kwargs) + return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) self.compiled = compiled - - param_list = [pyro.param(name).unconstrained() - for name in self._param_names] + else: + unconstrained_params = [pyro.param(name).unconstrained() + for name in self._param_names] + params_and_args = unconstrained_params + list(args) with poutine.block(hide=self._param_names): with poutine.trace(param_only=True) as param_capture: - ret = self.compiled(param_list, *args, **kwargs) - - new_params = filter(lambda name: name not in self._param_names, - param_capture.trace.nodes.keys()) + ret = self.compiled(*params_and_args) - for name in new_params: - # enforce uniqueness + for name in param_capture.trace.nodes.keys(): if name not in self._param_names: - self._param_names.append(name) + raise NotImplementedError('pyro.ops.jit.trace assumes all params are created on ' + 'first invocation, but found new param: {}'.format(name)) return ret -def compile(fn=None, **jit_options): +def trace(fn=None): """ - Drop-in replacement for :func:`torch.jit.compile` that works with + Lazy replacement for :func:`torch.jit.trace` that works with Pyro functions that call :func:`pyro.param`. - The actual compilation artifact is stored in the ``compiled`` attribute of the output. - Call diagnostic methods on this attribute. + The actual compilation artifact is stored in the ``compiled`` attribute of + the output. Call diagnostic methods on this attribute. 
Example:: @@ -80,12 +78,12 @@ def model(x): scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive) return pyro.sample("y", dist.Normal(x, scale)) - @pyro.ops.jit.compile(nderivs=1) + @pyro.ops.jit.trace def model_log_prob_fn(x, y): cond_model = pyro.condition(model, data={"y": y}) tr = pyro.poutine.trace(cond_model).get_trace(x) return tr.log_prob_sum() """ if fn is None: - return lambda fn: compile(fn, **jit_options) - return CompiledFunction(fn, **jit_options) + return lambda fn: trace(fn) + return CompiledFunction(fn) diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index c1b7ce77cb..d5f0e76533 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -175,11 +175,11 @@ def guide(): TraceGraph_ELBO, TraceEnum_ELBO, xfail_param(JitTrace_ELBO, - reason="jit RuntimeError: Unsupported op descriptor: index-2"), + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), xfail_param(JitTraceGraph_ELBO, - reason="jit RuntimeError: Unsupported op descriptor: index-2"), + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), xfail_param(JitTraceEnum_ELBO, - reason="jit RuntimeError: Unsupported op descriptor: index-2"), + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), ]) def test_subsample_gradient_sequential(Elbo, reparameterized, subsample): pyro.clear_param_store() diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index c3c1c2f147..9bed5e17b5 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -13,32 +13,30 @@ from pyro.optim import Adam from tests.common import assert_equal, xfail_param -pytestmark = pytest.mark.skip(reason="Requires update - https://github.com/uber/pyro/issues/1063") - def test_simple(): y = torch.ones(2) - @torch.jit.compile(nderivs=0) + @torch.jit.trace(y) def f(x): print('Inside f') assert x is y return y + 1.0 print('Calling f(y)') - assert_equal(f(y), y.new_tensor([2, 2])) + assert_equal(f(y), y.new_tensor([2., 2.])) print('Calling f(y)') - assert_equal(f(y), y.new_tensor([2, 2])) + assert_equal(f(y), y.new_tensor([2., 2.])) print('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2)), y.new_tensor([1, 1])) - with pytest.raises(AssertionError): - assert_equal(f(torch.ones(5)), y.new_tensor([2, 2, 2, 2, 2])) + assert_equal(f(torch.zeros(2)), y.new_tensor([1., 1.])) + print('Calling f(torch.zeros(5))') + assert_equal(f(torch.ones(5)), y.new_tensor([2., 2., 2., 2., 2.])) def test_backward(): y = torch.ones(2, requires_grad=True) - @torch.jit.compile(nderivs=1) + @torch.jit.trace(y) def f(x): print('Inside f') assert x is y @@ -50,13 +48,13 @@ def f(x): f(y) print('Calling f(torch.zeros(2))') f(torch.zeros(2, requires_grad=True)) - with pytest.raises(AssertionError): - f(torch.ones(5, requires_grad=True)) + print('Calling f(torch.zeros(5))') + f(torch.ones(5, requires_grad=True)) def test_grad(): - @torch.jit.compile(nderivs=0) + @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) def f(x, y): print('Inside f') loss = (x - y).pow(2).sum() @@ -68,11 +66,9 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(2, requires_grad=True)) -@pytest.mark.xfail(reason='RuntimeError: ' - 'saved_variables() needed but not implemented in ExpandBackward') def test_grad_expand(): - @torch.jit.compile(nderivs=0) + @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) def f(x, y): print('Inside f') loss 
= (x - y).pow(2).sum() @@ -95,7 +91,7 @@ def f(x, y): ]) def test_svi(Elbo, num_particles): pyro.clear_param_store() - data = torch.arange(10) + data = torch.arange(10.) def model(data): loc = pyro.param("loc", torch.tensor(0.0)) @@ -115,10 +111,6 @@ def guide(data): @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) @pytest.mark.parametrize('Elbo', [ - Trace_ELBO, - JitTrace_ELBO, - TraceGraph_ELBO, - JitTraceGraph_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, ]) @@ -145,9 +137,9 @@ def guide(): inner_particles = 2 outer_particles = num_particles // inner_particles - elbo = TraceEnum_ELBO(max_iarange_nesting=0, - strict_enumeration_warning=any([enumerate1, enumerate2]), - num_particles=inner_particles) + elbo = Elbo(max_iarange_nesting=0, + strict_enumeration_warning=any([enumerate1, enumerate2]), + num_particles=inner_particles) actual_loss = sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles)) / outer_particles actual_grad = q.unconstrained().grad / outer_particles @@ -166,7 +158,7 @@ def guide(): @pytest.mark.parametrize('Elbo', [ TraceEnum_ELBO, xfail_param(JitTraceEnum_ELBO, - reason="jit RuntimeError: Unsupported op descriptor: stack-2-dim_i"), + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), ]) def test_beta_bernoulli(Elbo, vectorized): pyro.clear_param_store() @@ -205,7 +197,8 @@ def guide(data): @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [ TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, reason="jit RuntimeError in Dirichlet.rsample"), + xfail_param(JitTraceEnum_ELBO, + reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), ]) def test_dirichlet_bernoulli(Elbo, vectorized): pyro.clear_param_store() From 8cb89a30702e86d6c9f671761c6275d41b52f606 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 30 Jul 2018 18:35:57 -0700 Subject: [PATCH 012/157] use float in arange --- pyro/distributions/gaussian_scale_mixture.py | 2 +- tests/infer/mcmc/test_hmc.py | 4 ++-- tests/infer/mcmc/test_nuts.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyro/distributions/gaussian_scale_mixture.py b/pyro/distributions/gaussian_scale_mixture.py index 65c826f950..635a3b31a9 100644 --- a/pyro/distributions/gaussian_scale_mixture.py +++ b/pyro/distributions/gaussian_scale_mixture.py @@ -127,7 +127,7 @@ def backward(ctx, grad_output): q_tot = (pis * q_j).sum(-1, keepdim=True) # l Phi_j = torch.exp(-0.5 * r_sqr_j) # l j - exponents = - torch.arange(1, int(dim/2) + 1, 1).type(grad_output.type()) + exponents = - torch.arange(1, float(int(dim/2)) + 1, 1) if z.dim() > 1: r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim/2)) # l j d/2 else: diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 3888afbead..e04701f121 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -156,7 +156,7 @@ def test_hmc_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, dim+1).type(data.type()) + true_coefs = torch.arange(1, float(dim+1)) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -221,7 +221,7 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, dim+1).type(data.type()) + true_coefs = torch.arange(1, float(dim+1)) labels = 
dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 135c146377..ecb49e62f9 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -70,7 +70,7 @@ def test_nuts_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, dim+1).type(data.type()) + true_coefs = torch.arange(1, float(dim+1)) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -120,7 +120,7 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, dim+1).type(data.type()) + true_coefs = torch.arange(1, float(dim+1)) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): From 660b8d1979a64607280ddf567b2761cfc3a7ea16 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 11:30:05 -0700 Subject: [PATCH 013/157] Fix Categorical.enumerate_support to make JitTraceEnum_ELBO work --- pyro/distributions/torch.py | 8 +++++ tests/infer/test_jit.py | 66 +++++++++++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index a10c41ec6b..3d17dfc26b 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -39,6 +39,14 @@ def expand(self, batch_shape): class Categorical(torch.distributions.Categorical, TorchDistributionMixin): + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = torch.broadcast_tensors(value, self.logits) + value = value[..., :1] + return log_pmf.gather(-1, value).squeeze(-1) + def expand(self, batch_shape): try: return super(Categorical, self).expand(batch_shape) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 9bed5e17b5..9e77c33bf4 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -33,6 +33,25 @@ def f(x): assert_equal(f(torch.ones(5)), y.new_tensor([2., 2., 2., 2., 2.])) +def test_multi_output(): + y = torch.ones(2) + + @torch.jit.trace(y) + def f(x): + print('Inside f') + assert x is y + return y - 1.0, y + 1.0 + + print('Calling f(y)') + assert_equal(f(y)[1], y.new_tensor([2., 2.])) + print('Calling f(y)') + assert_equal(f(y)[1], y.new_tensor([2., 2.])) + print('Calling f(torch.zeros(2))') + assert_equal(f(torch.zeros(2))[1], y.new_tensor([1., 1.])) + print('Calling f(torch.zeros(5))') + assert_equal(f(torch.ones(5))[1], y.new_tensor([2., 2., 2., 2., 2.])) + + def test_backward(): y = torch.ones(2, requires_grad=True) @@ -52,6 +71,7 @@ def f(x): f(torch.ones(5, requires_grad=True)) +@pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad(): @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) @@ -66,6 +86,7 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(2, requires_grad=True)) +@pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad_expand(): @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) @@ -80,6 +101,39 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) +def test_bernoulli_enumerate(shape, expand): + shape = torch.Size(shape) + probs = 
torch.empty(shape).fill_(0.25) + + @torch.jit.trace(probs) + def f(probs): + d = dist.Bernoulli(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + assert log_prob.shape == (2,) + shape + + +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +def test_categorical_enumerate(shape, expand): + shape = torch.Size(shape) + probs = torch.ones(shape) + + @torch.jit.trace(probs) + def f(probs): + d = dist.Categorical(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + batch_shape = shape[:-1] + assert log_prob.shape == shape[-1:] + batch_shape + + @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, @@ -155,11 +209,7 @@ def guide(): @pytest.mark.parametrize('vectorized', [False, True]) -@pytest.mark.parametrize('Elbo', [ - TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), -]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_beta_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) @@ -195,11 +245,7 @@ def guide(data): @pytest.mark.parametrize('vectorized', [False, True]) -@pytest.mark.parametrize('Elbo', [ - TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, - reason="in broadcast_all: RuntimeError: expected int at position 0, but got: Tensor"), -]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_dirichlet_bernoulli(Elbo, vectorized): pyro.clear_param_store() data = torch.tensor([1.0] * 6 + [0.0] * 4) From d583dd183df9aa4ccccba855a96827924b9bbd8e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 12:31:07 -0700 Subject: [PATCH 014/157] Refactor test_examples.py to allow xfailing examples --- tests/test_examples.py | 105 ++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 54 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 3e97cddd97..cad18290bd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -14,57 +14,50 @@ CPU_EXAMPLES = [ - ['air/main.py', '--num-steps=1'], - ['baseball.py', '--num-samples=200', '--warmup-steps=100'], - ['bayesian_regression.py', '--num-epochs=1'], - ['contrib/autoname/scoping_mixture.py', '--num-epochs=1'], - ['contrib/autoname/mixture.py', '--num-epochs=1'], - ['contrib/autoname/tree_data.py', '--num-epochs=1'], - ['contrib/gp/sv-dkl.py', '--epochs=1', '--num-inducing=4'], - ['contrib/oed/ab_test.py', '--num-vi-steps=1000', '--num-acquisitions=2'], - ['dmm/dmm.py', '--num-epochs=1'], - ['dmm/dmm.py', '--num-epochs=1', '--num-iafs=1'], - ['eight_schools/mcmc.py', '--num-samples=500', '--warmup-steps=100'], - ['eight_schools/svi.py', '--num-epochs=1'], - ['inclined_plane.py', '--num-samples=1'], - ['rsa/generics.py', '--num-samples=10'], - ['rsa/hyperbole.py', '--price=10000'], - ['rsa/schelling.py', '--num-samples=10'], - ['rsa/schelling_false.py', '--num-samples=10'], - ['rsa/semantic_parsing.py', '--num-samples=10'], - ['sparse_gamma_def.py', '--num-epochs=1'], - ['vae/ss_vae_M2.py', '--num-epochs=1'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--aux-loss'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--enum-discrete=parallel'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--enum-discrete=sequential'], - ['vae/vae.py', '--num-epochs=1'], - ['vae/vae_comparison.py', '--num-epochs=1'], 
+ 'air/main.py --num-steps=1', + 'baseball.py --num-samples=200 --warmup-steps=100', + 'bayesian_regression.py --num-epochs=1', + 'contrib/autoname/scoping_mixture.py --num-epochs=1', + 'contrib/autoname/mixture.py --num-epochs=1', + 'contrib/autoname/tree_data.py --num-epochs=1', + 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4', + 'contrib/oed/ab_test.py --num-vi-steps=1000 --num-acquisitions=2', + 'dmm/dmm.py --num-epochs=1', + 'dmm/dmm.py --num-epochs=1 --num-iafs=1', + 'eight_schools/mcmc.py --num-samples=500 --warmup-steps=100', + 'eight_schools/svi.py --num-epochs=1', + 'inclined_plane.py --num-samples=1', + 'rsa/generics.py --num-samples=10', + 'rsa/hyperbole.py --price=10000', + 'rsa/schelling.py --num-samples=10', + 'rsa/schelling_false.py --num-samples=10', + 'rsa/semantic_parsing.py --num-samples=10', + 'sparse_gamma_def.py --num-epochs=1', + 'vae/ss_vae_M2.py --num-epochs=1', + 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential', + 'vae/vae.py --num-epochs=1', + 'vae/vae_comparison.py --num-epochs=1', ] CUDA_EXAMPLES = [ - ['air/main.py', '--num-steps=1', '--cuda'], - ['bayesian_regression.py', '--num-epochs=1', '--cuda'], - ['contrib/gp/sv-dkl.py', '--epochs=1', '--num-inducing=4', '--cuda'], - ['dmm/dmm.py', '--num-epochs=1', '--cuda'], - ['dmm/dmm.py', '--num-epochs=1', '--num-iafs=1', '--cuda'], - ['vae/vae.py', '--num-epochs=1', '--cuda'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--cuda'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--aux-loss', '--cuda'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--enum-discrete=parallel', '--cuda'], - ['vae/ss_vae_M2.py', '--num-epochs=1', '--enum-discrete=sequential', '--cuda'], + 'air/main.py --num-steps=1 --cuda', + 'bayesian_regression.py --num-epochs=1 --cuda', + 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', + 'dmm/dmm.py --num-epochs=1 --cuda', + 'dmm/dmm.py --num-epochs=1 --num-iafs=1 --cuda', + 'vae/vae.py --num-epochs=1 --cuda', + 'vae/ss_vae_M2.py --num-epochs=1 --cuda', + 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --cuda', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', ] -CPU_EXAMPLES = [(example[0], example[1:]) for example in sorted(CPU_EXAMPLES)] -CUDA_EXAMPLES = [(example[0], example[1:]) for example in sorted(CUDA_EXAMPLES)] - - -def make_ids(examples): - return ['{} {}'.format(example, ' '.join(args)) for example, args in examples] - def test_coverage(): - cpu_tests = set([name for name, _ in CPU_EXAMPLES]) - cuda_tests = set([name for name, _ in CUDA_EXAMPLES]) + cpu_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CPU_EXAMPLES) + cuda_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CUDA_EXAMPLES) for root, dirs, files in os.walk(EXAMPLES_DIR): for basename in files: if not basename.endswith('.py'): @@ -80,16 +73,20 @@ def test_coverage(): pytest.fail('Example: {} not covered by CUDA_TESTS.'.format(example)) -@pytest.mark.parametrize('example,args', CPU_EXAMPLES, ids=make_ids(CPU_EXAMPLES)) -def test_cpu(example, args): - logger.info('Running:\npython examples/{} {}'.format(example, ' '.join(args))) - example = os.path.join(EXAMPLES_DIR, example) - check_call([sys.executable, example] + args) +@pytest.mark.parametrize('example', CPU_EXAMPLES) +def test_cpu(example): + logger.info('Running:\npython examples/{}'.format(example)) + example = 
example.split() + filename, args = example[0], example[1:] + filename = os.path.join(EXAMPLES_DIR, filename) + check_call([sys.executable, filename] + args) @requires_cuda -@pytest.mark.parametrize('example,args', CUDA_EXAMPLES, ids=make_ids(CUDA_EXAMPLES)) -def test_cuda(example, args): - logger.info('Running:\npython examples/{} {}'.format(example, ' '.join(args))) - example = os.path.join(EXAMPLES_DIR, example) - check_call([sys.executable, example] + args) +@pytest.mark.parametrize('example', CUDA_EXAMPLES) +def test_cuda(example): + logger.info('Running:\npython examples/{}'.format(example)) + example = example.split() + filename, args = example[0], example[1:] + filename = os.path.join(EXAMPLES_DIR, filename) + check_call([sys.executable, filename] + args) From fc9a4a7fceb01dd923373ef2b01bb4d368661111 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 12:31:35 -0700 Subject: [PATCH 015/157] Add xfailing examples that use --jit --- examples/dmm/dmm.py | 18 ++++++++++-------- examples/vae/vae.py | 6 ++++-- tests/infer/test_jit.py | 1 - tests/test_examples.py | 6 +++++- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/dmm/dmm.py b/examples/dmm/dmm.py index 1eb676b2a9..c345947ae9 100644 --- a/examples/dmm/dmm.py +++ b/examples/dmm/dmm.py @@ -27,7 +27,7 @@ import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions import InverseAutoregressiveFlow, TransformedDistribution -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import ClippedAdam from util import get_logger @@ -322,7 +322,8 @@ def rep(x): adam = ClippedAdam(adam_params) # setup inference algorithm - elbo = SVI(dmm.model, dmm.guide, adam, Trace_ELBO()) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + svi = SVI(dmm.model, dmm.guide, adam, loss=elbo) # now we're going to define some functions we need to form the main training loop @@ -365,8 +366,8 @@ def process_minibatch(epoch, which_mini_batch, shuffled_indices): = poly.get_mini_batch(mini_batch_indices, training_data_sequences, training_seq_lengths, cuda=args.cuda) # do an actual gradient step - loss = elbo.step(mini_batch, mini_batch_reversed, mini_batch_mask, - mini_batch_seq_lengths, annealing_factor) + loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask, + mini_batch_seq_lengths, annealing_factor) # keep track of the training loss return loss @@ -376,10 +377,10 @@ def do_evaluation(): dmm.rnn.eval() # compute the validation and test loss n_samples many times - val_nll = elbo.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask, - val_seq_lengths) / np.sum(val_seq_lengths) - test_nll = elbo.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask, - test_seq_lengths) / np.sum(test_seq_lengths) + val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask, + val_seq_lengths) / np.sum(val_seq_lengths) + test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask, + test_seq_lengths) / np.sum(test_seq_lengths) # put the RNN back into training mode (i.e. 
turn on drop-out if applicable) dmm.rnn.train() @@ -443,6 +444,7 @@ def do_evaluation(): parser.add_argument('-sopt', '--save-opt', type=str, default='') parser.add_argument('-smod', '--save-model', type=str, default='') parser.add_argument('--cuda', action='store_true') + parser.add_argument('--jit', action='store_true') parser.add_argument('-l', '--log', type=str, default='dmm.log') args = parser.parse_args() diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 616e5cbb94..0bfff4621d 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -7,7 +7,7 @@ import pyro import pyro.distributions as dist -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam from utils.mnist_cached import MNISTCached as MNIST from utils.mnist_cached import setup_data_loaders @@ -131,7 +131,8 @@ def main(args): optimizer = Adam(adam_args) # setup the inference algorithm - svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO()) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) # setup visdom for visualization if args.visdom_flag: @@ -203,6 +204,7 @@ def main(args): parser.add_argument('-tf', '--test-frequency', default=5, type=int, help='how often we evaluate the test set') parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate') parser.add_argument('--cuda', action='store_true', default=False, help='whether to use cuda') + parser.add_argument('--jit', action='store_true', default=False, help='whether to use PyTorch jit') parser.add_argument('-visdom', '--visdom_flag', action="store_true", help='Whether plotting in visdom is desired') parser.add_argument('-i-tsne', '--tsne_iter', default=100, type=int, help='epoch when tsne visualization runs') args = parser.parse_args() diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 9e77c33bf4..f94fb0c792 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -11,7 +11,6 @@ Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam -from tests.common import assert_equal, xfail_param def test_simple(): diff --git a/tests/test_examples.py b/tests/test_examples.py index cad18290bd..4a8a42e42f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -7,7 +7,7 @@ import pytest -from tests.common import EXAMPLES_DIR, requires_cuda +from tests.common import EXAMPLES_DIR, requires_cuda, xfail_param logger = logging.getLogger(__name__) pytestmark = pytest.mark.stage('test_examples') @@ -39,6 +39,8 @@ 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential', 'vae/vae.py --num-epochs=1', 'vae/vae_comparison.py --num-epochs=1', + xfail_param('dmm/dmm.py --num-epochs=1 --jit', reason='not jittable'), + xfail_param('vae/vae.py --num-epochs=1 --jit', reason='not jittable'), ] CUDA_EXAMPLES = [ @@ -52,6 +54,8 @@ 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', + xfail_param('dmm/dmm.py --num-epochs=1 --cuda --jit', reason='not jittable'), + xfail_param('vae/vae.py --num-epochs=1 --cuda --jit', reason='not jittable'), ] From 15fba47ee2ebe987e781aa0cea05ae7d7a0fb5ab Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 14:47:50 -0700 Subject: [PATCH 016/157] Fix missing import in test_jit.py --- tests/infer/test_jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/infer/test_jit.py b/tests/infer/test_jit.py index f94fb0c792..e02f42c8ac 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -7,10 +7,10 @@ import pyro import pyro.distributions as dist -from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, - Trace_ELBO, TraceEnum_ELBO, +from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam +from tests.common import assert_equal def test_simple(): From 91bf7482d2060782e01d4f616ed2910e8cce7159 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 15:34:00 -0700 Subject: [PATCH 017/157] Enable jit in most SVI examples --- examples/air/air.py | 2 +- examples/air/main.py | 12 +++++---- examples/bayesian_regression.py | 13 ++++----- examples/contrib/autoname/mixture.py | 8 +++--- examples/contrib/gp/sv-dkl.py | 5 +++- examples/eight_schools/svi.py | 6 +++-- examples/vae/ss_vae_M2.py | 10 ++++--- examples/vae/vae_comparison.py | 6 +++-- tests/infer/test_jit.py | 5 +--- tests/test_examples.py | 40 +++++++++++++++++++++++----- 10 files changed, 73 insertions(+), 34 deletions(-) diff --git a/examples/air/air.py b/examples/air/air.py index f4524d85db..0c3d320c90 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -161,7 +161,7 @@ def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): return ModelState(x=x, z_pres=z_pres, z_where=z_where) - def model(self, data, _, **kwargs): + def model(self, data, batch_size, **kwargs): pyro.module("decode", self.decode) with pyro.iarange('data', data.size(0), use_cuda=self.use_cuda) as ix: batch = data[ix] diff --git a/examples/air/main.py b/examples/air/main.py index df89df210c..6912463503 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -22,7 +22,7 @@ import pyro.optim as optim import pyro.poutine as poutine from air import AIR, latents_to_tensor -from pyro.infer import SVI, TraceGraph_ELBO +from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO from viz import draw_many, tensor_to_objs @@ -198,9 +198,9 @@ def per_param_optim_args(module_name, param_name): lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate return {'lr': lr} - svi = SVI(air.model, air.guide, - optim.Adam(per_param_optim_args), - loss=TraceGraph_ELBO()) + adam = optim.Adam(per_param_optim_args) + elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO() + svi = SVI(air.model, air.guide, adam, loss=elbo) # Do inference. 
t0 = time.time() @@ -208,7 +208,7 @@ def per_param_optim_args(module_name, param_name): for i in range(1, args.num_steps + 1): - loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)) + loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)) if args.progress_every > 0 and i % args.progress_every == 0: print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format( @@ -284,6 +284,8 @@ def per_param_optim_args(module_name, param_name): help='number of steps between parameter saves') parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') + parser.add_argument('--jit', action='store_true', default=False, + help='use PyTorch jit') parser.add_argument('-t', '--model-steps', type=int, default=3, help='number of time steps') parser.add_argument('--rnn-hidden-size', type=int, default=256, diff --git a/examples/bayesian_regression.py b/examples/bayesian_regression.py index c3efd0f46b..2ca58da1df 100644 --- a/examples/bayesian_regression.py +++ b/examples/bayesian_regression.py @@ -7,7 +7,7 @@ import pyro from pyro.distributions import Bernoulli, Normal # noqa: F401 -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam @@ -94,11 +94,6 @@ def guide(data): return lifted_module() -# instantiate optim and inference objects -optim = Adam({"lr": 0.05}) -svi = SVI(model, guide, optim, loss=Trace_ELBO()) - - # get array of batch indices def get_batch_indices(N, batch_size): all_batches = np.arange(0, N, batch_size) @@ -115,6 +110,11 @@ def main(args): data = data.cuda() softplus.cuda() regression_model.cuda() + + # perform inference + optim = Adam({"lr": 0.05}) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + svi = SVI(model, guide, optim, loss=elbo) for j in range(args.num_epochs): if args.batch_size == N: # use the entire data set @@ -140,5 +140,6 @@ def main(args): parser.add_argument('-n', '--num-epochs', default=1000, type=int) parser.add_argument('-b', '--batch-size', default=N, type=int) parser.add_argument('--cuda', action='store_true') + parser.add_argument('--jit', action='store_true') args = parser.parse_args() main(args) diff --git a/examples/contrib/autoname/mixture.py b/examples/contrib/autoname/mixture.py index d436577262..eda73d62f5 100644 --- a/examples/contrib/autoname/mixture.py +++ b/examples/contrib/autoname/mixture.py @@ -8,7 +8,7 @@ import pyro import pyro.distributions as dist from pyro.contrib.autoname import named -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam # This is a simple gaussian mixture model. 
@@ -55,7 +55,8 @@ def main(args): pyro.enable_validation() optim = Adam({"lr": 0.1}) - inference = SVI(model, guide, optim, loss=Trace_ELBO()) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + inference = SVI(model, guide, optim, loss=elbo) data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0]) k = 2 @@ -65,7 +66,7 @@ def main(args): if step and step % 10 == 0: print('{}\t{:0.5g}'.format(step, loss)) loss = 0.0 - loss += inference.step(data, k) + loss += inference.step(data, k=k) print('Parameters:') for name in sorted(pyro.get_param_store().get_all_param_names()): @@ -75,5 +76,6 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=200, type=int) + parser.add_argument('--jit', action='store_true') args = parser.parse_args() main(args) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index ed0092b48a..1553d755ee 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -117,7 +117,8 @@ def cnn_fn(x): optimizer = optim.Adam({"lr": args.lr}) - svi = infer.SVI(gpmodel.model, gpmodel.guide, optimizer, infer.Trace_ELBO()) + elbo = infer.JitTrace_ELBO() if args.jit else infer.Trace_ELBO() + svi = infer.SVI(gpmodel.model, gpmodel.guide, optimizer, elbo) for epoch in range(1, args.epochs + 1): start_time = time.time() @@ -143,6 +144,8 @@ def cnn_fn(x): help='learning rate (default: 0.01)') parser.add_argument('--cuda', action='store_true', default=False, help='enables CUDA training') + parser.add_argument('--jit', action='store_true', default=False, + help='enables PyTorch jit') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', diff --git a/examples/eight_schools/svi.py b/examples/eight_schools/svi.py index 649948dbe7..d788c4e752 100644 --- a/examples/eight_schools/svi.py +++ b/examples/eight_schools/svi.py @@ -9,7 +9,7 @@ import pyro import pyro.distributions as dist from data import J, sigma, y -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam logging.basicConfig(format='%(message)s', level=logging.INFO) @@ -59,7 +59,8 @@ def guide(data): def main(args): optim = Adam({'lr': args.lr}) - svi = SVI(model, guide, optim, loss=Trace_ELBO()) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + svi = SVI(model, guide, optim, loss=elbo) pyro.clear_param_store() for j in range(args.num_epochs): @@ -78,6 +79,7 @@ def main(args): help='learning rate (default: 0.01)') parser.add_argument('--num-epochs', type=int, default=1000, help='number of epochs (default: 1000)') + parser.add_argument('--jit', action='store_true', default=False) args = parser.parse_args() main(args) diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index e30d80e629..b7456a237f 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -7,7 +7,7 @@ import pyro import pyro.distributions as dist from pyro.contrib.examples.util import print_and_log, set_seed -from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate +from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate from pyro.optim import Adam from utils.custom_mlp import MLP, Exp from utils.mnist_cached import MNISTCached, mkdir_p, setup_data_loaders @@ -300,14 +300,16 @@ def main(args): # set up the loss(es) for inference. 
wrapping the guide in config_enumerate builds the loss as a sum # by enumerating each class label for the sampled discrete categorical distribution in the model guide = config_enumerate(ss_vae.guide, args.enum_discrete) - loss_basic = SVI(ss_vae.model, guide, optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1)) + elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_iarange_nesting=1) + loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo) # build a list of all losses considered losses = [loss_basic] # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al) if args.aux_loss: - loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=Trace_ELBO()) + elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() + loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo) losses.append(loss_aux) try: @@ -383,6 +385,8 @@ def main(args): parser.add_argument('--cuda', action='store_true', help="use GPU(s) to speed up training") + parser.add_argument('--jit', action='store_true', + help="use PyTorch jit to speed up training") parser.add_argument('-n', '--num-epochs', default=50, type=int, help="number of epochs to run") parser.add_argument('--aux-loss', action="store_true", diff --git a/examples/vae/vae_comparison.py b/examples/vae/vae_comparison.py index a61742137c..40ca720c14 100644 --- a/examples/vae/vae_comparison.py +++ b/examples/vae/vae_comparison.py @@ -12,7 +12,7 @@ import pyro from pyro.contrib.examples import util from pyro.distributions import Bernoulli, Normal -from pyro.infer import SVI, Trace_ELBO +from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO from pyro.optim import Adam from utils.mnist_cached import DATA_DIR, RESULTS_DIR @@ -204,7 +204,8 @@ def compute_loss_and_gradient(self, x): def initialize_optimizer(self, lr): optimizer = Adam({'lr': lr}) - return SVI(self.model, self.guide, optimizer, loss=Trace_ELBO()) + elbo = JitTrace_ELBO() if self.args.jit else Trace_ELBO() + return SVI(self.model, self.guide, optimizer, loss=elbo) def setup(args): @@ -250,6 +251,7 @@ def main(args): parser.add_argument('--rng_seed', nargs='?', default=0, type=int) parser.add_argument('--impl', nargs='?', default='pyro', type=str) parser.add_argument('--skip_eval', action='store_true') + parser.add_argument('--jit', action='store_true') parser.set_defaults(skip_eval=False) args = parser.parse_args() main(args) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index e02f42c8ac..52cbc7dfae 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -163,10 +163,7 @@ def guide(data): @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) -@pytest.mark.parametrize('Elbo', [ - TraceEnum_ELBO, - JitTraceEnum_ELBO, -]) +@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2): pyro.clear_param_store() num_particles = 10 diff --git a/tests/test_examples.py b/tests/test_examples.py index 4a8a42e42f..af7fe55374 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -7,7 +7,7 @@ import pytest -from tests.common import EXAMPLES_DIR, requires_cuda, xfail_param +from tests.common import EXAMPLES_DIR, requires_cuda logger = logging.getLogger(__name__) pytestmark = pytest.mark.stage('test_examples') @@ -39,8 +39,6 @@ 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential', 'vae/vae.py --num-epochs=1', 
'vae/vae_comparison.py --num-epochs=1', - xfail_param('dmm/dmm.py --num-epochs=1 --jit', reason='not jittable'), - xfail_param('vae/vae.py --num-epochs=1 --jit', reason='not jittable'), ] CUDA_EXAMPLES = [ @@ -54,14 +52,30 @@ 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', - xfail_param('dmm/dmm.py --num-epochs=1 --cuda --jit', reason='not jittable'), - xfail_param('vae/vae.py --num-epochs=1 --cuda --jit', reason='not jittable'), +] + +JIT_EXAMPLES = [ + 'air/main.py --num-steps=1 --jit', + 'bayesian_regression.py --num-epochs=1 --jit', + 'contrib/autoname/mixture.py --num-epochs=1 --jit', + 'dmm/dmm.py --num-epochs=1 --jit', + 'dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit', + 'eight_schools/svi.py --num-epochs=1 --jit', + 'examples/contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', + 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit', + 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit', + 'vae/ss_vae_M2.py --num-epochs=1 --jit', + 'vae/vae.py --num-epochs=1 --jit', + 'vae/vae_comparison.py --num-epochs=1 --jit', + 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', ] def test_coverage(): cpu_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CPU_EXAMPLES) cuda_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in CUDA_EXAMPLES) + jit_tests = set((e if isinstance(e, str) else e.values[0]).split()[0] for e in JIT_EXAMPLES) for root, dirs, files in os.walk(EXAMPLES_DIR): for basename in files: if not basename.endswith('.py'): @@ -72,9 +86,11 @@ def test_coverage(): example = os.path.relpath(path, EXAMPLES_DIR) if '__main__' in text: if example not in cpu_tests: - pytest.fail('Example: {} not covered in CPU_TESTS.'.format(example)) + pytest.fail('Example: {} not covered in CPU_EXAMPLES.'.format(example)) if '--cuda' in text and example not in cuda_tests: - pytest.fail('Example: {} not covered by CUDA_TESTS.'.format(example)) + pytest.fail('Example: {} not covered by CUDA_EXAMPLES.'.format(example)) + if '--jit' in text and example not in jit_tests: + pytest.fail('Example: {} not covered by JIT_EXAMPLES.'.format(example)) @pytest.mark.parametrize('example', CPU_EXAMPLES) @@ -94,3 +110,13 @@ def test_cuda(example): filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) check_call([sys.executable, filename] + args) + + +@pytest.mark.xfail(reason="not jittable") +@pytest.mark.parametrize('example', JIT_EXAMPLES) +def test_jit(example): + logger.info('Running:\npython examples/{}'.format(example)) + example = example.split() + filename, args = example[0], example[1:] + filename = os.path.join(EXAMPLES_DIR, filename) + check_call([sys.executable, filename] + args) From 71779ecfe02b6344edbff7bd378fbfeedff82b59 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 7 Aug 2018 18:14:59 -0700 Subject: [PATCH 018/157] Revert changes to torch_patch.py --- pyro/distributions/torch_patch.py | 49 ------------------------------- 1 file changed, 49 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 2fba7357e2..119d866159 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -1,7 +1,5 @@ from __future__ import absolute_import, division, print_function -from numbers import Number - import torch @@ -47,51 +45,4 
@@ def _torch_dirichlet_grad(x, concentration, total):
     return unpatched_fn(x, concentration, total)
 
 
-# This version of broadcast_all() is compatible with early versions of the PyTorch jit,
-# since it avoids torch._C._infer_size(). However it is more expensive since it infers
-# size by summing the tensors. It is mainly useful for working around one jit limitation
-# to discovering additional jit limitations.
-#
-# To temporarily apply this patch, uncomment one or more of the decorators:
-#
-# @_patch('torch.distributions.beta.broadcast_all')
-# @_patch('torch.distributions.dirichlet.broadcast_all')
-# @_patch('torch.distributions.normal.broadcast_all')
-# @_patch('torch.distributions.utils.broadcast_all')
-def _broadcast_all(*values):
-    r"""
-    Given a list of values (possibly containing numbers), returns a list where each
-    value is broadcasted based on the following rules:
-      - `torch.*Tensor` instances are broadcasted as per the `broadcasting rules
-        `_
-      - numbers.Number instances (scalars) are upcast to tensors having
-        the same size and type as the first tensor passed to `values`. If all the
-        values are scalars, then they are upcasted to Tensors having size
-        `(1,)`.
-
-    Args:
-        values (list of `numbers.Number` or `torch.*Tensor`)
-
-    Raises:
-        ValueError: if any of the values is not a `numbers.Number` or
-            `torch.*Tensor` instance
-    """
-    values = list(values)
-    scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
-    tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor']
-    if len(scalar_idxs) + len(tensor_idxs) != len(values):
-        raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
-    if tensor_idxs:
-        broadcast_shape = sum(values).size()  # expensive alternative to torch._C._infer_size()
-        for idx in tensor_idxs:
-            values[idx] = values[idx].expand(broadcast_shape)
-        template = values[tensor_idxs[0]]
-        for idx in scalar_idxs:
-            values[idx] = template.new(template.size()).fill_(values[idx])
-    else:
-        for idx in scalar_idxs:
-            values[idx] = torch.tensor(float(values[idx]))
-    return values
-
-
 __all__ = []

From fcdddcc785cacb6fa4ec25f05a1e2f804ae44b9e Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 7 Aug 2018 20:54:33 -0700
Subject: [PATCH 019/157] Work around jit issues; bayesian_regression example now jits

---
 pyro/distributions/torch.py         |  2 ++
 pyro/distributions/util.py          | 16 ++++++++++++--
 pyro/poutine/broadcast_messenger.py |  3 ++-
 tests/test_examples.py              | 33 ++++++++++++++++-------------
 4 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py
index 3d17dfc26b..5db79e3306 100644
--- a/pyro/distributions/torch.py
+++ b/pyro/distributions/torch.py
@@ -39,6 +39,8 @@ def expand(self, batch_shape):
 
 
 class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
+
+    # log_prob can be deleted after https://github.com/pytorch/pytorch/pull/10321
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py
index bdab60c161..f843c2febb 100644
--- a/pyro/distributions/util.py
+++ b/pyro/distributions/util.py
@@ -52,7 +52,11 @@ def is_identically_zero(x):
     Check if argument is exactly the number zero. True for the number zero;
     false for other numbers; false for :class:`~torch.Tensor`s.
""" - return isinstance(x, numbers.Number) and x == 0 + if isinstance(x, numbers.Number): + return x == 0 + elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape: + return x.item() == 0 + return False def is_identically_one(x): @@ -60,7 +64,11 @@ def is_identically_one(x): Check if argument is exactly the number one. True for the number one; false for other numbers; false for :class:`~torch.Tensor`s. """ - return isinstance(x, numbers.Number) and x == 1 + if isinstance(x, numbers.Number): + return x == 1 + elif isinstance(x, torch.Tensor) and x.dtype == torch.int64 and not x.shape: + return x.item() == 1 + return False def broadcast_shape(*shapes, **kwargs): @@ -172,8 +180,12 @@ def scale_tensor(tensor, scale): # avoid NANs if not isinstance(scale, numbers.Number): + if scale.dtype == torch.int64: + scale = scale.float() result[(scale == 0).expand_as(result)] = 0 if not isinstance(tensor, numbers.Number): + if tensor.dtype == torch.int64: + tensor = tensor.float() result[(tensor == 0).expand_as(result)] = 0 return result diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index e4e04d28b6..03f2748d5e 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -21,7 +21,8 @@ def _pyro_sample(self, msg): dist = msg["fn"] actual_batch_shape = getattr(dist, "batch_shape", None) if actual_batch_shape is not None: - target_batch_shape = [None if size == 1 else size for size in actual_batch_shape] + target_batch_shape = [None if size == 1 else int(size) # int() is required by jit + for size in actual_batch_shape] for f in msg["cond_indep_stack"]: if f.dim is None or f.size == -1: continue diff --git a/tests/test_examples.py b/tests/test_examples.py index f5b0c5ae69..38d284864a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -54,21 +54,26 @@ 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --cuda', ] + +def xfail_jit(*args): + return pytest.param(*args, marks=[pytest.mark.xfail(reason="not jittable"), + pytest.mark.skipif('CI' in os.environ, reason='slow test')]) + + JIT_EXAMPLES = [ - 'air/main.py --num-steps=1 --jit', 'bayesian_regression.py --num-epochs=1 --jit', - 'contrib/autoname/mixture.py --num-epochs=1 --jit', - 'dmm/dmm.py --num-epochs=1 --jit', - 'dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit', - 'eight_schools/svi.py --num-epochs=1 --jit', - 'examples/contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit', - 'vae/ss_vae_M2.py --num-epochs=1 --jit', - 'vae/vae.py --num-epochs=1 --jit', - 'vae/vae_comparison.py --num-epochs=1 --jit', - 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit', + xfail_jit('air/main.py --num-steps=1 --jit'), + xfail_jit('contrib/autoname/mixture.py --num-epochs=1 --jit'), + xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), + xfail_jit('dmm/dmm.py --num-epochs=1 --jit'), + xfail_jit('dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit'), + xfail_jit('eight_schools/svi.py --num-epochs=1 --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit'), + xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --jit'), + xfail_jit('vae/vae.py --num-epochs=1 --jit'), + 
xfail_jit('vae/vae_comparison.py --num-epochs=1 --jit'), ] @@ -112,8 +117,6 @@ def test_cuda(example): check_call([sys.executable, filename] + args) -@pytest.mark.skipif('CI' in os.environ, reason='slow test') -@pytest.mark.xfail(reason='not jittable') @pytest.mark.parametrize('example', JIT_EXAMPLES) def test_jit(example): logger.info('Running:\npython examples/{}'.format(example)) From d260fe02e2c7611ded654db8ff03ad85a056cda0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 13:51:50 -0700 Subject: [PATCH 020/157] Fix doctests to pass on Python 2.7 --- pyro/contrib/tracking/hashing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/tracking/hashing.py b/pyro/contrib/tracking/hashing.py index a73fdd99f6..37c64a7d48 100644 --- a/pyro/contrib/tracking/hashing.py +++ b/pyro/contrib/tracking/hashing.py @@ -36,7 +36,7 @@ class LSH(object): >>> lsh.nearby('b') # doctest: +SKIP {'a', 'c'} >>> lsh.remove('b') - >>> lsh.nearby('a') + >>> lsh.nearby('a') # doctest: +SKIP set() From 9f7fd5426bd02bc7158cbd78ce75c9d64e2b401f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 14:21:51 -0700 Subject: [PATCH 021/157] Fix arange usage --- tests/infer/mcmc/test_valid_models.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/infer/mcmc/test_valid_models.py b/tests/infer/mcmc/test_valid_models.py index 434338dddd..f5b0266b0d 100644 --- a/tests/infer/mcmc/test_valid_models.py +++ b/tests/infer/mcmc/test_valid_models.py @@ -1,15 +1,13 @@ import logging -import torch - import pytest -from torch.nn.functional import sigmoid +import torch import pyro import pyro.distributions as dist -from pyro.infer import config_enumerate -from pyro.infer.mcmc import MCMC, HMC, NUTS import pyro.poutine as poutine +from pyro.infer import config_enumerate +from pyro.infer.mcmc import HMC, MCMC, NUTS from pyro.infer.mcmc.util import EnumTraceProbEvaluator from pyro.primitives import _Subsample from tests.common import assert_equal @@ -49,7 +47,7 @@ def test_model_error_stray_batch_dims(kernel, kwargs): def gmm(): data = torch.tensor([0., 0., 3., 3., 3., 5., 5.]) mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(3))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3), 1.)) + cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3.), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) @@ -74,7 +72,7 @@ def gmm(): data = torch.tensor([0., 0., 3., 3., 3., 5., 5.]) with pyro.iarange("num_clusters", 3): mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3), 1.)) + cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(3.), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) @@ -260,7 +258,7 @@ def model(data): mean = 2 * y - 1 n = pyro.sample("n", dist.Normal(mean, 1.)) with pyro.iarange("data", len(data)): - pyro.sample("obs", dist.Bernoulli(sigmoid(n)), obs=data) + pyro.sample("obs", dist.Bernoulli(torch.sigmoid(n)), obs=data) model_trace = poutine.trace(model).get_trace(data) print_debug_info(model_trace) From d3cafb12f3d8e0a98b0458264d14a877eadbb719 Mon Sep 17 00:00:00 2001 From: Fritz 
Obermeyer Date: Wed, 8 Aug 2018 14:22:36 -0700 Subject: [PATCH 022/157] Only patch Categorical if broadcast_tensors is defined --- pyro/distributions/torch.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 5db79e3306..818db52d1c 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -40,14 +40,17 @@ def expand(self, batch_shape): class Categorical(torch.distributions.Categorical, TorchDistributionMixin): - # log_prob can be deleted after https://github.com/pytorch/pytorch/pull/10321 - def log_prob(self, value): - if self._validate_args: - self._validate_sample(value) - value = value.long().unsqueeze(-1) - value, log_pmf = torch.broadcast_tensors(value, self.logits) - value = value[..., :1] - return log_pmf.gather(-1, value).squeeze(-1) + # workaround for jit errors in broadcast_all. + # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 + if hasattr(torch, 'broadcast_tensors'): + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = torch.broadcast_tensors(value, self.logits) + value = value[..., :1] + return log_pmf.gather(-1, value).squeeze(-1) def expand(self, batch_shape): try: From 17cb6368841522c17c199dc83a13cd76ace9b750 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 14:22:54 -0700 Subject: [PATCH 023/157] Add patches to work around bugs in 0.4.1 --- pyro/distributions/torch_patch.py | 38 +++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 119d866159..ba7b26559d 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -45,4 +45,42 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) +if torch.__version__.startswith('0.4.1'): + + # work around https://github.com/pytorch/pytorch/issues/9917 + @_patch('torch.bernoulli') + def _torch_bernoulli(input, out=None): + unpatched_fn = _torch_bernoulli._pyro_unpatched + input = input.contiguous() + return unpatched_fn(input) if out is None else unpatched_fn(input, out) + + # work around https://github.com/pytorch/pytorch/issues/9521 + @_patch('torch._standard_gamma') + def _torch_standard_gamma(concentration): + concentration = concentration.contiguous() + unpatched_fn = _torch_standard_gamma._pyro_unpatched + if concentration.is_cuda: + return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) + return unpatched_fn(concentration) + + # work around https://github.com/pytorch/pytorch/issues/9521 + @_patch('torch.distributions.gamma._standard_gamma') + def _standard_gamma(concentration): + concentration = concentration.contiguous() + if concentration.is_cuda: + return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) + return concentration._standard_gamma() + + # work around https://github.com/pytorch/pytorch/issues/9521 + @_patch('torch._dirichlet_grad') + def _torch_dirichlet_grad(x, concentration, total): + unpatched_fn = _torch_dirichlet_grad._pyro_unpatched + x = x.contiguous() + concentration = concentration.contiguous() + total = total.contiguous() + if x.is_cuda: + return unpatched_fn(x.cpu(), concentration.cpu(), total.cpu()).cuda(x.get_device()) + return unpatched_fn(x, concentration, total) + + __all__ = [] From a251d3de47e7303b69825e75c612875a764cc826 Mon Sep 17 
00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 15:21:16 -0700 Subject: [PATCH 024/157] Fix test failures --- pyro/distributions/torch_patch.py | 7 +++++++ tests/infer/mcmc/test_hmc.py | 6 +++--- tests/infer/test_jit.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index ba7b26559d..b23ad5c170 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -54,6 +54,13 @@ def _torch_bernoulli(input, out=None): input = input.contiguous() return unpatched_fn(input) if out is None else unpatched_fn(input, out) + # work around https://github.com/pytorch/pytorch/issues/9917 + @_patch('torch.poisson') + def _torch_poisson(input): + unpatched_fn = _torch_poisson._pyro_unpatched + input = input.contiguous() + return unpatched_fn(input) + # work around https://github.com/pytorch/pytorch/issues/9521 @_patch('torch._standard_gamma') def _torch_standard_gamma(concentration): diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 70af202cbb..231f637d62 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -157,7 +157,7 @@ def test_hmc_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, float(dim+1)) + true_coefs = torch.arange(1., dim + 1.) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -222,7 +222,7 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, float(dim+1)) + true_coefs = torch.arange(1., dim + 1.) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -277,7 +277,7 @@ def test_gaussian_mixture_model(): def gmm(data): with pyro.iarange("num_clusters", K): mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(K), 1.)) + cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 52cbc7dfae..90a1689fb3 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -100,6 +100,8 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) def test_bernoulli_enumerate(shape, expand): @@ -116,6 +118,8 @@ def f(probs): assert log_prob.shape == (2,) + shape +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) def test_categorical_enumerate(shape, expand): @@ -133,6 +137,8 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ 
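[The six skipif markers added below all repeat the same expression, and a later patch in this series has to fix the same torch.version / torch.__version__ typo in every copy. A hypothetical refactor, not part of these patches, that defines the marker once:

import pytest
import torch

# one shared marker instead of six duplicated decorators
requires_jit_fixes = pytest.mark.skipif(
    torch.__version__ <= '0.4.1',
    reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228")


@requires_jit_fixes
def test_bernoulli_enumerate_sketch():
    pass

Note also that lexicographic comparison of version strings is fragile once a component reaches two digits ('0.10.0' compares below '0.4.1'), so parsing the version before comparing would be safer still.]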
Trace_ELBO, @@ -160,6 +166,8 @@ def guide(data): inference.step(data) +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) @@ -204,6 +212,8 @@ def guide(): ])) +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_beta_bernoulli(Elbo, vectorized): @@ -240,6 +250,8 @@ def guide(data): svi.step(data) +@pytest.mark.skipif(torch.version <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_dirichlet_bernoulli(Elbo, vectorized): From 1e6fb98a077194615e7346c642ddd18044b07a0d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 15:37:59 -0700 Subject: [PATCH 025/157] flake8 --- pyro/distributions/torch_patch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index b23ad5c170..717e2e1f40 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -62,7 +62,7 @@ def _torch_poisson(input): return unpatched_fn(input) # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch._standard_gamma') + @_patch('torch._standard_gamma') # noqa: F811 def _torch_standard_gamma(concentration): concentration = concentration.contiguous() unpatched_fn = _torch_standard_gamma._pyro_unpatched @@ -71,7 +71,7 @@ def _torch_standard_gamma(concentration): return unpatched_fn(concentration) # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch.distributions.gamma._standard_gamma') + @_patch('torch.distributions.gamma._standard_gamma') # noqa: F811 def _standard_gamma(concentration): concentration = concentration.contiguous() if concentration.is_cuda: @@ -79,7 +79,7 @@ def _standard_gamma(concentration): return concentration._standard_gamma() # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch._dirichlet_grad') + @_patch('torch._dirichlet_grad') # noqa: F811 def _torch_dirichlet_grad(x, concentration, total): unpatched_fn = _torch_dirichlet_grad._pyro_unpatched x = x.contiguous() From 80cf76f9b88e598b4e0680e592f074fedbca7dee Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 17:35:47 -0700 Subject: [PATCH 026/157] Fix typo in skipif markers --- tests/infer/test_jit.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 90a1689fb3..df656d0256 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -100,7 +100,7 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) -@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) @@ -118,7 +118,7 @@ def f(probs): assert log_prob.shape == (2,) + shape 
-@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) @@ -137,7 +137,7 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape -@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ @@ -166,7 +166,7 @@ def guide(data): inference.step(data) -@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @@ -212,7 +212,7 @@ def guide(): ])) -@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) @@ -250,7 +250,7 @@ def guide(data): svi.step(data) -@pytest.mark.skipif(torch.version <= '0.4.1', +@pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) From 00087255bedc49f6fa01a13c8d44a7c54c6abcab Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 17:48:18 -0700 Subject: [PATCH 027/157] Work around bugs in torch unwind backward --- tests/contrib/tracking/test_em.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/contrib/tracking/test_em.py b/tests/contrib/tracking/test_em.py index 944a19ebdd..d89d20f4c4 100644 --- a/tests/contrib/tracking/test_em.py +++ b/tests/contrib/tracking/test_em.py @@ -13,7 +13,6 @@ from pyro.infer import SVI, TraceEnum_ELBO from pyro.optim import Adam from pyro.optim.multi import MixedMultiOptimizer, Newton -from tests.common import xfail_param def make_args(): @@ -64,23 +63,23 @@ def model(detections, args): # This should match detection_model's existence part. -def exists_log_likelihood(objects, args): +def compute_exists_logits(objects, args): p_exists = args.expected_num_objects / args.max_num_objects real_part = dist.Normal(0., 1.).log_prob(objects) real_part = real_part + math.log(p_exists) spurious_part = torch.empty(real_part.shape).fill_(math.log(1 - p_exists)) - return torch.stack([spurious_part, real_part], -1) + return real_part - spurious_part # This should match detection_model's assignment part. 
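[Both rewrites in this patch, compute_exists_logits above and compute_assign_logits below, replace a stacked two-column [spurious, real] log-likelihood with the single logit real - spurious, which is the form the guide ultimately feeds to MarginalAssignment. The two encodings carry the same information; a quick self-check with made-up values:

import torch
import torch.nn.functional as F

real = torch.randn(5)
spurious = torch.randn(5)
# softmax over [spurious, real] picks out sigmoid(real - spurious)
p_stack = F.softmax(torch.stack([spurious, real], -1), dim=-1)[..., 1]
p_logit = torch.sigmoid(real - spurious)
assert torch.allclose(p_stack, p_logit)]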
-def assign_log_likelihood(objects, detections, noise_scale, args): +def compute_assign_logits(objects, detections, noise_scale, args): num_detections = len(detections) p_fake = args.num_fake_detections / num_detections real_part = dist.Normal(objects, noise_scale).log_prob(detections) real_part = real_part + math.log((1 - p_fake) / args.max_num_objects) fake_part = dist.Normal(0., 1.).log_prob(detections) fake_part = fake_part + math.log(p_fake) - return torch.cat([real_part, fake_part], -1) + return real_part - fake_part def guide(detections, args): @@ -91,14 +90,12 @@ def guide(detections, args): with torch.set_grad_enabled(args.assignment_grad): # Evaluate log likelihoods. TODO make this more pyronic. - exists_loglike = exists_log_likelihood(objects, args) - assign_loglike = assign_log_likelihood(objects, detections.unsqueeze(-1), noise_scale, args) - assert exists_loglike.shape == (max_num_objects, 2) - assert assign_loglike.shape == (num_detections, max_num_objects + 1) + exists_logits = compute_exists_logits(objects, args) + assign_logits = compute_assign_logits(objects, detections.unsqueeze(-1), noise_scale, args) + assert exists_logits.shape == (max_num_objects,) + assert assign_logits.shape == (num_detections, max_num_objects) # Compute soft assignments. - exists_logits = exists_loglike[:, 1] - exists_loglike[:, 0] - assign_logits = assign_loglike[:, :-1] - assign_loglike[:, -1:] assignment = MarginalAssignment(exists_logits, assign_logits, bp_iters=10) with pyro.iarange('objects_iarange', max_num_objects): @@ -121,10 +118,7 @@ def generate_data(args): return detections -@pytest.mark.parametrize('assignment_grad', [ - False, - xfail_param(True, reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor"), -]) +@pytest.mark.parametrize('assignment_grad', [False, True]) def test_em(assignment_grad): args = make_args() args.assignment_grad = assignment_grad @@ -146,10 +140,7 @@ def test_em(assignment_grad): print('step {}, loss = {}'.format(step, loss.item())) -@pytest.mark.parametrize('assignment_grad', [ - False, - xfail_param(True, reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor"), -]) +@pytest.mark.parametrize('assignment_grad', [False, True]) def test_em_nested_in_svi(assignment_grad): args = make_args() args.assignment_grad = assignment_grad @@ -180,7 +171,6 @@ def test_em_nested_in_svi(assignment_grad): svi_step, loss, pyro.param('noise_scale').item())) -@pytest.mark.xfail(reason="pytorch 0.4.1 RuntimeError: dim() called on undefined Tensor") def test_svi_multi(): args = make_args() args.assignment_grad = True From 55003ec2cdee12fb829f94dc1a94be299e0e99f7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 17:54:54 -0700 Subject: [PATCH 028/157] Mark xfailing jit test --- tests/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 38d284864a..fddb59b639 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -61,8 +61,8 @@ def xfail_jit(*args): JIT_EXAMPLES = [ - 'bayesian_regression.py --num-epochs=1 --jit', xfail_jit('air/main.py --num-steps=1 --jit'), + xfail_jit('bayesian_regression.py --num-epochs=1 --jit'), # this works on PyTorch master xfail_jit('contrib/autoname/mixture.py --num-epochs=1 --jit'), xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), xfail_jit('dmm/dmm.py --num-epochs=1 --jit'), From 8ea0461ccaa819c45a98c10d31bdd3f4c1f725c1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 8 Aug 2018 
18:47:20 -0700 Subject: [PATCH 029/157] Update all uses of torch.arange --- pyro/distributions/gaussian_scale_mixture.py | 2 +- pyro/distributions/torch.py | 2 +- tests/contrib/gp/test_models.py | 16 ++++++++-------- tests/distributions/test_von_mises.py | 2 +- tests/infer/mcmc/test_nuts.py | 6 +++--- tests/infer/test_enum.py | 4 ++-- tutorial/source/gp.ipynb | 4 ++-- tutorial/source/tensor_shapes.ipynb | 8 ++++---- tutorial/source/tracking_1d.ipynb | 4 ++-- 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pyro/distributions/gaussian_scale_mixture.py b/pyro/distributions/gaussian_scale_mixture.py index 635a3b31a9..4a56883ec5 100644 --- a/pyro/distributions/gaussian_scale_mixture.py +++ b/pyro/distributions/gaussian_scale_mixture.py @@ -127,7 +127,7 @@ def backward(ctx, grad_output): q_tot = (pis * q_j).sum(-1, keepdim=True) # l Phi_j = torch.exp(-0.5 * r_sqr_j) # l j - exponents = - torch.arange(1, float(int(dim/2)) + 1, 1) + exponents = - torch.arange(1., int(dim/2) + 1., 1.) if z.dim() > 1: r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim/2)) # l j d/2 else: diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 818db52d1c..61d1507ff7 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -67,7 +67,7 @@ def expand(self, batch_shape): def enumerate_support(self, expand=True): num_events = self._num_events - values = torch.arange(num_events).long() + values = torch.arange(num_events, dtype=torch.long) values = values.view((-1,) + (1,) * len(self._batch_shape)) if expand: values = values.expand((-1,) + self._batch_shape) diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index e2063bb1e4..6d2fddba52 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -187,12 +187,12 @@ def test_inference_sgpr(): X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() kernel = RBF(input_dim=1) - Xu = torch.arange(0, 5.5, 0.5) + Xu = torch.arange(0., 5.5, 0.5) sgpr = SparseGPRegression(X, y, kernel, Xu) sgpr.optimize(optim.Adam({"lr": 0.01}), num_steps=1000) - Xnew = torch.arange(0, 5.05, 0.05) + Xnew = torch.arange(0., 5.05, 0.05) loc, var = sgpr(Xnew, full_cov=False) target = 0.5 * torch.sin(3*Xnew) @@ -205,12 +205,12 @@ def test_inference_vsgp(): X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() kernel = RBF(input_dim=1) - Xu = torch.arange(0, 5.5, 0.5) + Xu = torch.arange(0., 5.5, 0.5) vsgp = VariationalSparseGP(X, y, kernel, Xu, Gaussian()) vsgp.optimize(optim.Adam({"lr": 0.03}), num_steps=1000) - Xnew = torch.arange(0, 5.05, 0.05) + Xnew = torch.arange(0., 5.05, 0.05) loc, var = vsgp(Xnew, full_cov=False) target = 0.5 * torch.sin(3*Xnew) @@ -223,12 +223,12 @@ def test_inference_whiten_vsgp(): X = dist.Uniform(torch.zeros(N), torch.ones(N)*5).sample() y = 0.5 * torch.sin(3*X) + dist.Normal(torch.zeros(N), torch.ones(N)*0.5).sample() kernel = RBF(input_dim=1) - Xu = torch.arange(0, 5.5, 0.5) + Xu = torch.arange(0., 5.5, 0.5) vsgp = VariationalSparseGP(X, y, kernel, Xu, Gaussian(), whiten=True) vsgp.optimize(optim.Adam({"lr": 0.01}), num_steps=1000) - Xnew = torch.arange(0, 5.05, 0.05) + Xnew = torch.arange(0., 5.05, 0.05) loc, var = vsgp(Xnew, full_cov=False) target = 0.5 * torch.sin(3*Xnew) @@ -331,9 +331,9 @@ def f(x): return 2 * x + 3 + 5 * torch.sin(7 * x) tensor_holder = torch.tensor([]) - X = 
tensor_holder.new_tensor(torch.arange(100)) + X = tensor_holder.new_tensor(torch.arange(100.)) y = f(X) - Xnew = tensor_holder.new_tensor(torch.arange(100, 150)) + Xnew = tensor_holder.new_tensor(torch.arange(100., 150.)) ynew = f(Xnew) kernel = Cosine(input_dim=1) diff --git a/tests/distributions/test_von_mises.py b/tests/distributions/test_von_mises.py index db497980fa..94c3f74be4 100644 --- a/tests/distributions/test_von_mises.py +++ b/tests/distributions/test_von_mises.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize('concentration', [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]) def test_log_prob_normalized(concentration): - grid = torch.arange(0, 2 * math.pi, 1e-4) + grid = torch.arange(0., 2 * math.pi, 1e-4) prob = VonMises(0.0, concentration).log_prob(grid).exp() norm = prob.mean().item() * 2 * math.pi assert abs(norm - 1) < 1e-3, norm diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index a69c75834d..469d12495f 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -71,7 +71,7 @@ def test_nuts_conjugate_gaussian(fixture, def test_logistic_regression(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, float(dim+1)) + true_coefs = torch.arange(1., dim + 1.) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -121,7 +121,7 @@ def model(data): def test_logistic_regression_with_dual_averaging(): dim = 3 data = torch.randn(2000, dim) - true_coefs = torch.arange(1, float(dim+1)) + true_coefs = torch.arange(1., dim + 1.) labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() def model(data): @@ -189,7 +189,7 @@ def test_gaussian_mixture_model(): def gmm(data): with pyro.iarange("num_clusters", K): mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) - cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(K), 1.)) + cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 1e6c952ab7..f88d20b185 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -154,7 +154,7 @@ def gmm_guide(data, verbose=False): @pytest.mark.parametrize("model", [gmm_model, gmm_guide]) def test_gmm_iter_discrete_traces(data_size, graph_type, model): pyro.clear_param_store() - data = torch.arange(0, data_size) + data = torch.arange(0., float(data_size)) model = config_enumerate(model) traces = list(iter_discrete_traces(graph_type, model, data=data, verbose=True)) # This non-vectorized version is exponential in data_size: @@ -189,7 +189,7 @@ def gmm_batch_guide(data): @pytest.mark.parametrize("model", [gmm_batch_model, gmm_batch_guide]) def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type): pyro.clear_param_store() - data = torch.arange(0, data_size) + data = torch.arange(0., float(data_size)) model = config_enumerate(model) traces = list(iter_discrete_traces(graph_type, model, data=data)) # This vectorized version is independent of data_size: diff --git a/tutorial/source/gp.ipynb b/tutorial/source/gp.ipynb index 7dd4d0c256..be8a013d94 100644 --- a/tutorial/source/gp.ipynb +++ b/tutorial/source/gp.ipynb @@ -535,7 +535,7 @@ ], "source": [ "# initialize the inducing inputs\n", - "Xu = torch.arange(20) / 4.0\n", + "Xu = torch.arange(20.) 
/ 4.0\n", "\n", "# initialize the kernel and model\n", "kernel = gp.kernels.RBF(input_dim=1)\n", @@ -634,7 +634,7 @@ ], "source": [ "# initialize the inducing inputs\n", - "Xu = torch.arange(10) / 2.0\n", + "Xu = torch.arange(10.) / 2.0\n", "\n", "# initialize the kernel, likelihood, and model\n", "kernel = gp.kernels.RBF(input_dim=1)\n", diff --git a/tutorial/source/tensor_shapes.ipynb b/tutorial/source/tensor_shapes.ipynb index 608a4838c2..42f9b87930 100644 --- a/tutorial/source/tensor_shapes.ipynb +++ b/tutorial/source/tensor_shapes.ipynb @@ -335,7 +335,7 @@ "metadata": {}, "outputs": [], "source": [ - "data = torch.arange(100)\n", + "data = torch.arange(100.)\n", "\n", "def model2():\n", " mean = pyro.param(\"mean\", torch.zeros(len(data)))\n", @@ -404,7 +404,7 @@ "source": [ "@config_enumerate(default=\"parallel\")\n", "def model3():\n", - " p = pyro.param(\"p\", torch.arange(6) / 6)\n", + " p = pyro.param(\"p\", torch.arange(6.) / 6)\n", " locs = pyro.param(\"locs\", torch.tensor([-1., 1.]))\n", "\n", " a = pyro.sample(\"a\", Categorical(torch.ones(6) / 6))\n", @@ -414,7 +414,7 @@ " with pyro.iarange(\"d_iarange\", 5):\n", " d = pyro.sample(\"d\", Bernoulli(0.4).expand_by([5,4]))\n", " e_loc = locs[d.long()].unsqueeze(-1)\n", - " e_scale = torch.arange(1, 8)\n", + " e_scale = torch.arange(1., 8.)\n", " e = pyro.sample(\"e\", Normal(e_loc, e_scale)\n", " .independent(1)) # Note this depends on d.\n", "\n", @@ -448,7 +448,7 @@ " | | with pyro.iarange(\"d_iarange\", 5):\n", " 2 1 1 1|5 4| d = pyro.sample(\"d\", Bernoulli(0.4).expand_by([5,4]))\n", " 2 1 1 1|5 4|1 e_loc = locs[d.long()].unsqueeze(-1)\n", - " | |7 e_scale = torch.arange(1, 8)\n", + " | |7 e_scale = torch.arange(1., 8.)\n", " 2 1 1 1|5 4|7 e = pyro.sample(\"e\", Normal(e_loc, e_scale)\n", " | | .independent(1))\n", "```\n", diff --git a/tutorial/source/tracking_1d.ipynb b/tutorial/source/tracking_1d.ipynb index 4c44152f12..57570f2e35 100644 --- a/tutorial/source/tracking_1d.ipynb +++ b/tutorial/source/tracking_1d.ipynb @@ -48,7 +48,7 @@ "outputs": [], "source": [ "def get_dynamics(num_frames):\n", - " time = torch.arange(num_frames) / 4\n", + " time = torch.arange(float(num_frames)) / 4\n", " return torch.stack([time.cos(), time.sin()], -1)" ] }, @@ -263,7 +263,7 @@ " pyplot.plot(true_positions.numpy(), 'k--')\n", " is_observed = (observations[..., -1] > 0)\n", " pos = observations[..., 0]\n", - " time = torch.arange(args.num_frames).unsqueeze(-1).expand_as(pos)\n", + " time = torch.arange(float(args.num_frames)).unsqueeze(-1).expand_as(pos)\n", " pyplot.scatter(time[is_observed].view(-1).numpy(),\n", " pos[is_observed].view(-1).numpy(), color='k', marker='+',\n", " label='observation')\n", From 09cadfbaaf4226f7f2542c0e9752923f0b530ee5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 12:07:31 -0700 Subject: [PATCH 030/157] Remove obsolete logsumexp implementation --- pyro/distributions/util.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index f843c2febb..b14ecdf1f3 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -5,9 +5,12 @@ import torch import torch.distributions as torch_dist +from torch import logsumexp _VALIDATION_ENABLED = False +log_sum_exp = logsumexp # DEPRECATED + def copy_docs_from(source_class, full_text=False): """ @@ -200,27 +203,6 @@ def torch_sign(value): return torch.sign(value) -try: - from torch import logsumexp # for pytorch 0.4.1 and later -except 
ImportError: - def logsumexp(tensor, dim=-1, keepdim=False): - """ - Numerically stable implementation for the `LogSumExp` operation. The - summing is done along the dimension specified by ``dim``. - - :param torch.Tensor tensor: Input tensor. - :param dim: Dimension to be summed out. - :param keepdim: Whether to retain the dimension - that is summed out. - """ - max_val = tensor.max(dim, keepdim=True)[0] - log_sum_exp = max_val + (tensor - max_val).exp().sum(dim=dim, keepdim=True).log() - return log_sum_exp if keepdim else log_sum_exp.squeeze(dim) - - -log_sum_exp = logsumexp # DEPRECATED - - def enable_validation(is_validate): global _VALIDATION_ENABLED _VALIDATION_ENABLED = is_validate From 2a93a9a4e049181ebdbc48bc5b153e40d33d2cde Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 12:31:04 -0700 Subject: [PATCH 031/157] Patch torch.distributions.Categorical.log_prob --- pyro/distributions/torch.py | 12 ------------ pyro/distributions/torch_patch.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 61d1507ff7..f0c1fa9cfa 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -40,18 +40,6 @@ def expand(self, batch_shape): class Categorical(torch.distributions.Categorical, TorchDistributionMixin): - # workaround for jit errors in broadcast_all. - # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 - if hasattr(torch, 'broadcast_tensors'): - - def log_prob(self, value): - if self._validate_args: - self._validate_sample(value) - value = value.long().unsqueeze(-1) - value, log_pmf = torch.broadcast_tensors(value, self.logits) - value = value[..., :1] - return log_pmf.gather(-1, value).squeeze(-1) - def expand(self, batch_shape): try: return super(Categorical, self).expand(batch_shape) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 717e2e1f40..79793972d0 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -90,4 +90,19 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) +# these patches work after https://github.com/pytorch/pytorch/pull/10075 +if hasattr(torch, 'broadcast_tensors'): + + # workaround lack of jit support for Categorical.log_prob() + # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 + @_patch('torch.distributions.categorical.Categorical.log_prob') + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = torch.broadcast_tensors(value, self.logits) + value = value[..., :1] + return log_pmf.gather(-1, value).squeeze(-1) + + __all__ = [] From e91aa861361a668224114b4f7c44642d43cf0ca6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 12:35:34 -0700 Subject: [PATCH 032/157] Work around lack of jit support for torch.eye(_, out=_) --- pyro/contrib/gp/models/gplvm.py | 8 ++++---- pyro/contrib/gp/models/vgp.py | 5 +++-- pyro/contrib/gp/models/vsgp.py | 5 +++-- pyro/distributions/lowrank_mvn.py | 5 +++-- pyro/distributions/omt_mvn.py | 4 ++-- pyro/distributions/torch.py | 4 ++-- pyro/distributions/torch_patch.py | 2 +- pyro/distributions/util.py | 9 +++++++++ tests/distributions/test_util.py | 12 +++++++++++- 9 files changed, 38 insertions(+), 16 deletions(-) diff --git a/pyro/contrib/gp/models/gplvm.py b/pyro/contrib/gp/models/gplvm.py index 51260d99e8..ac68cb162f 100644 --- 
a/pyro/contrib/gp/models/gplvm.py +++ b/pyro/contrib/gp/models/gplvm.py @@ -1,14 +1,14 @@ from __future__ import absolute_import, division, print_function -import torch from torch.distributions import constraints from torch.nn import Parameter import pyro -from pyro.contrib.gp.util import Parameterized import pyro.distributions as dist import pyro.infer as infer import pyro.optim as optim +from pyro.contrib.gp.util import Parameterized +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -74,7 +74,7 @@ def __init__(self, base_model, name="GPLVM"): C = self.X_loc.shape[1] X_scale_tril_shape = self.X_loc.shape + (C,) - Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) + Id = eye_like(self.X_loc, C) X_scale_tril = Id.expand(X_scale_tril_shape) self.X_scale_tril = Parameter(X_scale_tril) self.set_constraint("X_scale_tril", constraints.lower_cholesky) @@ -87,7 +87,7 @@ def model(self): # sample X from unit multivariate normal distribution zero_loc = self.X_loc.new_zeros(self.X_loc.shape) C = self.X_loc.shape[1] - Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) + Id = eye_like(self.X_loc, C) X_name = param_with_module_name(self.name, "X") X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim()-1)) diff --git a/pyro/contrib/gp/models/vgp.py b/pyro/contrib/gp/models/vgp.py index 2acfed9ef9..4a8f8b0a23 100644 --- a/pyro/contrib/gp/models/vgp.py +++ b/pyro/contrib/gp/models/vgp.py @@ -8,6 +8,7 @@ import pyro.distributions as dist from pyro.contrib.gp.models.model import GPModel from pyro.contrib.gp.util import conditional +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -74,7 +75,7 @@ def __init__(self, X, y, kernel, likelihood, mean_function=None, self.f_loc = Parameter(f_loc) f_scale_tril_shape = self.latent_shape + (N, N) - Id = torch.eye(N, out=self.X.new_empty(N, N)) + Id = eye_like(self.X, N) f_scale_tril = Id.expand(f_scale_tril_shape) self.f_scale_tril = Parameter(f_scale_tril) self.set_constraint("f_scale_tril", constraints.lower_cholesky) @@ -96,7 +97,7 @@ def model(self): f_name = param_with_module_name(self.name, "f") if self.whiten: - Id = torch.eye(N, out=self.X.new_empty(N, N)) + Id = eye_like(self.X, N) pyro.sample(f_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) diff --git a/pyro/contrib/gp/models/vsgp.py b/pyro/contrib/gp/models/vsgp.py index f0c67e084c..2d6586bad6 100644 --- a/pyro/contrib/gp/models/vsgp.py +++ b/pyro/contrib/gp/models/vsgp.py @@ -9,6 +9,7 @@ import pyro.poutine as poutine from pyro.contrib.gp.models.model import GPModel from pyro.contrib.gp.util import conditional +from pyro.distributions.util import eye_like from pyro.params import param_with_module_name @@ -98,7 +99,7 @@ def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None, self.u_loc = Parameter(u_loc) u_scale_tril_shape = self.latent_shape + (M, M) - Id = torch.eye(M, out=self.Xu.new_empty(M, M)) + Id = eye_like(self.Xu, M) u_scale_tril = Id.expand(u_scale_tril_shape) self.u_scale_tril = Parameter(u_scale_tril) self.set_constraint("u_scale_tril", constraints.lower_cholesky) @@ -120,7 +121,7 @@ def model(self): zero_loc = Xu.new_zeros(u_loc.shape) u_name = param_with_module_name(self.name, "u") if self.whiten: - Id = torch.eye(M, out=Xu.new_empty(M, M)) + Id = eye_like(Xu, M) pyro.sample(u_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) diff --git 
a/pyro/distributions/lowrank_mvn.py b/pyro/distributions/lowrank_mvn.py index e88b2d84f6..5beef012af 100644 --- a/pyro/distributions/lowrank_mvn.py +++ b/pyro/distributions/lowrank_mvn.py @@ -7,6 +7,7 @@ from torch.distributions.utils import lazy_property from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistribution +from pyro.distributions.util import eye_like def _matrix_triangular_solve_compat(b, A, upper=True): @@ -84,7 +85,7 @@ def scale_tril(self): A = self.covariance_matrix_W_term / Dsqrt At_A = A.t().matmul(A) N = A.shape[1] - Id = torch.eye(N, N, out=A.new_empty(N, N)) + Id = eye_like(A, N) K = Id + At_A L = K.potrf(upper=False) return Dsqrt.unsqueeze(1) * L @@ -111,7 +112,7 @@ def _compute_logdet_and_mahalanobis(self, D, W, y, trace_term=0): """ W_Dinv = W / D M = W.shape[0] - Id = torch.eye(M, M, out=W.new_empty(M, M)) + Id = eye_like(W, M) K = Id + W_Dinv.matmul(W.t()) L = K.potrf(upper=False) if y.dim() == 1: diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index 8fed96f85f..c3a393bc85 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -6,7 +6,7 @@ from torch.distributions import constraints from pyro.distributions.torch import MultivariateNormal -from pyro.distributions.util import sum_leftmost +from pyro.distributions.util import eye_like, sum_leftmost class OMTMultivariateNormal(MultivariateNormal): @@ -51,7 +51,7 @@ def backward(ctx, grad_output): g = grad_output loc_grad = sum_leftmost(grad_output, -1) - identity = torch.eye(dim, out=torch.tensor(g.new_empty(dim, dim))) + identity = eye_like(g, dim) R_inv = torch.trtrs(identity, L.t(), transpose=False, upper=True)[0] z_ja = z.unsqueeze(-1) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index f0c1fa9cfa..633c7acf47 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -4,6 +4,7 @@ from torch.distributions import constraints from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistributionMixin +from pyro.distributions.util import eye_like class Bernoulli(torch.distributions.Bernoulli, TorchDistributionMixin): @@ -253,8 +254,7 @@ def expand(self, batch_shape): def enumerate_support(self, expand=True): n = self.event_shape[0] - values = self._new((n, n)) - torch.eye(n, out=values) + values = eye_like(self._categorical._param, n) values = values.view((n,) + (1,) * len(self.batch_shape) + (n,)) if expand: values = values.expand((n,) + self.batch_shape + (n,)) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 79793972d0..0eb1908105 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -96,7 +96,7 @@ def _torch_dirichlet_grad(x, concentration, total): # workaround lack of jit support for Categorical.log_prob() # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 @_patch('torch.distributions.categorical.Categorical.log_prob') - def log_prob(self, value): + def _log_prob(self, value): if self._validate_args: self._validate_sample(value) value = value.long().unsqueeze(-1) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index b14ecdf1f3..fedc27ecba 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -203,6 +203,15 @@ def torch_sign(value): return torch.sign(value) +# work around lack of jit support for torch.eye(..., out=value) +def eye_like(value, m, n=None): + if n is None: + n = m + eye = value.new_zeros(m, n) + eye.view(-1)[::n + 1] = 
1 + return eye + + def enable_validation(is_validate): global _VALIDATION_ENABLED _VALIDATION_ENABLED = is_validate diff --git a/tests/distributions/test_util.py b/tests/distributions/test_util.py index f94fb97f6e..31cf1d7abe 100644 --- a/tests/distributions/test_util.py +++ b/tests/distributions/test_util.py @@ -4,7 +4,8 @@ import pytest import torch -from pyro.distributions.util import broadcast_shape, scale_tensor, sum_leftmost, sum_rightmost +from pyro.distributions.util import broadcast_shape, eye_like, scale_tensor, sum_leftmost, sum_rightmost +from tests.common import assert_equal INF = float('inf') @@ -136,3 +137,12 @@ def test_scale_tensor(tensor, scale, expected): assert (actual == expected).all() else: assert actual == expected + + +@pytest.mark.parametrize("m", [1, 2, 3, 4]) +@pytest.mark.parametrize("n", [1, 2, 3, 4, None]) +def test_eye_like(m, n): + x = torch.tensor(0.) + expected = torch.eye(m) if n is None else torch.eye(m, n) + actual = eye_like(x, m, n) + assert_equal(expected, actual, '{} vs {}'.format(expected.cpu().numpy(), actual.cpu().numpy())) From 5d71162d75e3b7158a10b0a4fa33b28b5a48de46 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 12:51:39 -0700 Subject: [PATCH 033/157] Add test-jit target to Makefile --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index b8b90fde41..33f0c72447 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,10 @@ test-cuda: lint FORCE CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx -n 4 --stage unit CUDA_TEST=1 pytest -vx -n 4 tests/test_examples.py::test_cuda +test-jit: FORCE + @echo See jit.log + pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log + clean: FORCE git clean -dfx -e pyro-egg.info From 0fb6870748d9ac1366cc954f6f0f2bb1caa653b8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 13:47:51 -0700 Subject: [PATCH 034/157] Fix bug in eye_like when m!=n --- pyro/distributions/util.py | 2 +- tests/distributions/test_util.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index fedc27ecba..da34d062d2 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -208,7 +208,7 @@ def eye_like(value, m, n=None): if n is None: n = m eye = value.new_zeros(m, n) - eye.view(-1)[::n + 1] = 1 + eye.view(-1)[:min(m, n) * n:n + 1] = 1 return eye diff --git a/tests/distributions/test_util.py b/tests/distributions/test_util.py index 31cf1d7abe..32b832d1c1 100644 --- a/tests/distributions/test_util.py +++ b/tests/distributions/test_util.py @@ -139,10 +139,11 @@ def test_scale_tensor(tensor, scale, expected): assert actual == expected -@pytest.mark.parametrize("m", [1, 2, 3, 4]) -@pytest.mark.parametrize("n", [1, 2, 3, 4, None]) +@pytest.mark.parametrize("m", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("n", [1, 2, 3, 4, 5, None]) def test_eye_like(m, n): x = torch.tensor(0.) 
expected = torch.eye(m) if n is None else torch.eye(m, n) actual = eye_like(x, m, n) - assert_equal(expected, actual, '{} vs {}'.format(expected.cpu().numpy(), actual.cpu().numpy())) + assert_equal(expected, actual, + msg='Expected:\n{}\nActual:\n{}'.format(expected.cpu().numpy(), actual.cpu().numpy())) From 6f88bbc133bdd6297ee24c4482f18eac74128df1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 14:18:58 -0700 Subject: [PATCH 035/157] Fix jit errors: torch_scale and variable len(args) --- pyro/distributions/util.py | 2 ++ pyro/ops/jit.py | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index da34d062d2..bf07027c07 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -200,6 +200,8 @@ def torch_sign(value): """ if isinstance(value, numbers.Number): return (value > 0) - (value < 0) + if value.dtype == torch.int64: + value = value.float() return torch.sign(value) diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 0fb2127301..b1ce7dcf20 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -17,12 +17,14 @@ class CompiledFunction(object): """ def __init__(self, fn): self.fn = fn - self.compiled = None + self.compiled = {} # len(args) -> callable self._param_names = None def __call__(self, *args, **kwargs): + argc = len(args) + # if first time - if self.compiled is None: + if argc not in self.compiled: # param capture with poutine.block(): with poutine.trace(param_only=True) as first_param_capture: @@ -46,7 +48,7 @@ def compiled(*params_and_args): constrained_params[name] = constrained_param return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) - self.compiled = compiled + self.compiled[argc] = compiled else: unconstrained_params = [pyro.param(name).unconstrained() for name in self._param_names] @@ -54,7 +56,7 @@ def compiled(*params_and_args): with poutine.block(hide=self._param_names): with poutine.trace(param_only=True) as param_capture: - ret = self.compiled(*params_and_args) + ret = self.compiled[argc](*params_and_args) for name in param_capture.trace.nodes.keys(): if name not in self._param_names: From 2f190e5c20aa9a77b13f69d91f8a3f6559c2cea1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 15:26:30 -0700 Subject: [PATCH 036/157] Patch multivariate normal __init__ methods to be jittable --- pyro/distributions/torch_patch.py | 75 ++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 0eb1908105..0630204aa8 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -45,7 +45,7 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) -if torch.__version__.startswith('0.4.1'): +if torch.__version__ >= '0.4.1': # work around https://github.com/pytorch/pytorch/issues/9917 @_patch('torch.bernoulli') @@ -93,7 +93,7 @@ def _torch_dirichlet_grad(x, concentration, total): # these patches work after https://github.com/pytorch/pytorch/pull/10075 if hasattr(torch, 'broadcast_tensors'): - # workaround lack of jit support for Categorical.log_prob() + # work around lack of jit support for torch._C._infer_size() # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 @_patch('torch.distributions.categorical.Categorical.log_prob') def _log_prob(self, value): @@ -104,5 +104,76 @@ def _log_prob(self, value): value = value[..., :1] return 
log_pmf.gather(-1, value).squeeze(-1) + # work around lack of jit support for torch._C._infer_size() + # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 + @_patch('torch.distributions.multivariate_normal.MultivariateNormal.__init__') + def _MultivariateNormal_init(self, loc, covariance_matrix=None, precision_matrix=None, + scale_tril=None, validate_args=None): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: + raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") + + loc_ = loc.unsqueeze(-1) # temporarily add dim on right + if scale_tril is not None: + if scale_tril.dim() < 2: + raise ValueError("scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions") + self._unbroadcasted_scale_tril = scale_tril + self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) + elif covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError("covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions") + self._unbroadcasted_scale_tril = torch.distributions.multivariate_normal._batch_potrf_lower( + covariance_matrix) + self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) + else: + if precision_matrix.dim() < 2: + raise ValueError("precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions") + covariance_matrix = torch.distributions.multivariate_normal._batch_inverse(precision_matrix) + self._unbroadcasted_scale_tril = torch.distributions.multivariate_normal._batch_potrf_lower( + covariance_matrix) + self.covariance_matrix, self.precision_matrix, loc_ = torch.broadcast_tensors( + covariance_matrix, precision_matrix, loc_) + self.loc = loc_[..., 0] # drop rightmost dim + + batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + super(torch.distributions.multivariate_normal.MultivariateNormal, self).__init__( + batch_shape, event_shape, validate_args=validate_args) + + # work around lack of jit support for torch._C._infer_size() + # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 + @_patch('torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.__init__') + def __init__(self, loc, cov_factor, cov_diag, validate_args=None): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + event_shape = loc.shape[-1:] + if cov_factor.dim() < 2: + raise ValueError("cov_factor must be at least two-dimensional, " + "with optional leading batch dimensions") + if cov_factor.shape[-2:-1] != event_shape: + raise ValueError("cov_factor must be a batch of matrices with shape {} x m" + .format(event_shape[0])) + if cov_diag.shape[-1:] != event_shape: + raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape)) + + loc_ = loc.unsqueeze(-1) + cov_diag_ = cov_diag.unsqueeze(-1) + try: + loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_) + except RuntimeError: + raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}" + .format(loc.shape, cov_factor.shape, cov_diag.shape)) + self.loc = loc_[..., 0] + self.cov_diag = cov_diag_[..., 0] + batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + + self._capacitance_tril = 
torch.distributions.lowrank_multivariate_normal._batch_capacitance_tril( + self.cov_factor, self.cov_diag) + super(torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal, self).__init__( + batch_shape, event_shape, validate_args=validate_args) + __all__ = [] From f7ef56e54e613875f2dcf31dce3c7412b106c936 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 15:40:47 -0700 Subject: [PATCH 037/157] Patch torch.log --- pyro/distributions/torch_patch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 0630204aa8..0e441728a1 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -47,6 +47,14 @@ def _torch_dirichlet_grad(x, concentration, total): if torch.__version__ >= '0.4.1': + # work around https://github.com/pytorch/pytorch/issues/10241 + # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 + @_patch('torch.log') + def _torch_log(input, out=None): + unpatched_fn = _torch_log._pyro_unpatched + input = input.contiguous() + return unpatched_fn(input) if out is None else unpatched_fn(input, out) + # work around https://github.com/pytorch/pytorch/issues/9917 @_patch('torch.bernoulli') def _torch_bernoulli(input, out=None): From b5ba5f1901436d5e159801b8f33e43ac776425f9 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 15:55:40 -0700 Subject: [PATCH 038/157] Patch torch.Tensor.log --- pyro/distributions/torch_patch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 0e441728a1..53552bba46 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -55,6 +55,14 @@ def _torch_log(input, out=None): input = input.contiguous() return unpatched_fn(input) if out is None else unpatched_fn(input, out) + # work around https://github.com/pytorch/pytorch/issues/10241 + # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 + @_patch('torch.Tensor.log') + def _Tensor_log(self): + unpatched_fn = _Tensor_log._pyro_unpatched + self = self.contiguous() + return unpatched_fn(self) + # work around https://github.com/pytorch/pytorch/issues/9917 @_patch('torch.bernoulli') def _torch_bernoulli(input, out=None): From 3f64101b77b465b6e8e9907bb88a075ae1b394a5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 9 Aug 2018 16:05:44 -0700 Subject: [PATCH 039/157] Patch torch.exp and torch.Tensor.exp --- pyro/distributions/torch_patch.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 53552bba46..906096539e 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -63,6 +63,22 @@ def _Tensor_log(self): self = self.contiguous() return unpatched_fn(self) + # work around https://github.com/pytorch/pytorch/issues/10241 + # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 + @_patch('torch.exp') + def _torch_exp(input, out=None): + unpatched_fn = _torch_exp._pyro_unpatched + input = input.contiguous() + return unpatched_fn(input) if out is None else unpatched_fn(input, out) + + # work around https://github.com/pytorch/pytorch/issues/10241 + # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 + @_patch('torch.Tensor.exp') + def _Tensor_exp(self): + unpatched_fn = _Tensor_exp._pyro_unpatched + self = self.contiguous() + return 
unpatched_fn(self) + # work around https://github.com/pytorch/pytorch/issues/9917 @_patch('torch.bernoulli') def _torch_bernoulli(input, out=None): From 894fa5837400834498d1fe2c7ca2b7cf17ed49ef Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 13 Aug 2018 23:15:02 -0700 Subject: [PATCH 040/157] Use JIT traced potential energy computation in HMC (#1299) --- Makefile | 2 + examples/baseball.py | 4 +- examples/eight_schools/mcmc.py | 3 +- pyro/infer/mcmc/hmc.py | 36 ++++++++++++++++- pyro/infer/mcmc/nuts.py | 11 ++++-- tests/infer/mcmc/test_hmc.py | 66 +++++++++++++++++++++++--------- tests/infer/mcmc/test_nuts.py | 70 ++++++++++++++++++++++++---------- tests/test_examples.py | 2 + 8 files changed, 150 insertions(+), 44 deletions(-) diff --git a/Makefile b/Makefile index 33f0c72447..ec3cda44c5 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,8 @@ test-cuda: lint FORCE test-jit: FORCE @echo See jit.log pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log + pytest -v -n auto --tb=short --runxfail tests/infer/mcmc/test_hmc.py tests/infer/mcmc/test_nuts.py \ + -k JIT=True | tee -a jit.log clean: FORCE git clean -dfx -e pyro-egg.info diff --git a/examples/baseball.py b/examples/baseball.py index b46ca4fd62..27b283c62a 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -218,7 +218,7 @@ def main(args): baseball_dataset = pd.read_csv(DATA_URL, "\t") train, _, player_names = train_test_split(baseball_dataset) at_bats, hits = train[:, 0], train[:, 1] - nuts_kernel = NUTS(conditioned_model, adapt_step_size=True) + nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit) logging.info("Original Dataset:") logging.info(baseball_dataset) @@ -270,5 +270,7 @@ def main(args): parser.add_argument("-n", "--num-samples", nargs="?", default=1200, type=int) parser.add_argument("--warmup-steps", nargs='?', default=300, type=int) parser.add_argument("--rng_seed", nargs='?', default=0, type=int) + parser.add_argument('--jit', action='store_true', default=False, + help='use PyTorch jit') args = parser.parse_args() main(args) diff --git a/examples/eight_schools/mcmc.py b/examples/eight_schools/mcmc.py index b8a577e14b..fe29f96e6d 100644 --- a/examples/eight_schools/mcmc.py +++ b/examples/eight_schools/mcmc.py @@ -34,7 +34,7 @@ def conditioned_model(model, sigma, y): def main(args): - nuts_kernel = NUTS(conditioned_model, adapt_step_size=True) + nuts_kernel = NUTS(conditioned_model, adapt_step_size=True, jit_compile=args.jit) posterior = MCMC(nuts_kernel, num_samples=args.num_samples, warmup_steps=args.warmup_steps)\ .run(model, data.sigma, data.y) marginal_mu_tau = EmpiricalMarginal(posterior, sites=["mu", "tau"])\ @@ -54,6 +54,7 @@ def main(args): help='number of MCMC samples (default: 1000)') parser.add_argument('--warmup-steps', type=int, default=1000, help='number of MCMC samples for warmup (default: 1000)') + parser.add_argument('--jit', action='store_true', default=False) args = parser.parse_args() main(args) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 9cd304a279..299137d63b 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -50,6 +50,9 @@ class HMC(TraceKernel): :param int max_iarange_nesting: Optional bound on max number of nested :func:`pyro.iarange` contexts. This is required if model contains discrete sample sites that can be enumerated over in parallel. 
+ :param bool jit_compile: Optional parameter denoting whether to use + the PyTorch JIT to trace the log density computation, and use this + optimized executable trace in the integrator. Example: @@ -78,7 +81,8 @@ def __init__(self, num_steps=None, adapt_step_size=False, transforms=None, - max_iarange_nesting=float("inf")): + max_iarange_nesting=float("inf"), + jit_compile=False): # Wrap model in `poutine.enum` to enumerate over discrete latent sites. # No-op if model does not have any discrete latents. self.model = poutine.enum(config_enumerate(model, default="parallel", expand=False), @@ -94,6 +98,7 @@ def __init__(self, self.trajectory_length = 2 * math.pi # from Stan self.num_steps = max(1, int(self.trajectory_length / self.step_size)) self.adapt_step_size = adapt_step_size + self._jit_compile = jit_compile self._target_accept_prob = 0.8 # from Stan self.transforms = {} if transforms is None else transforms @@ -123,6 +128,8 @@ def _kinetic_energy(self, r): return 0.5 * sum(x.pow(2).sum() for x in r.values()) def _potential_energy(self, z): + if self._jit_compile: + return self._potential_energy_jit(z) # Since the model is specified in the constrained space, transform the # unconstrained R.V.s `z` to the constrained space. z_constrained = z.copy() @@ -135,6 +142,32 @@ def _potential_energy(self, z): potential_energy += transform.log_abs_det_jacobian(z_constrained[name], z[name]).sum() return potential_energy + def _potential_energy_jit(self, z): + names, vals = zip(*sorted(z.items())) + if self._compiled_potential_fn: + return self._compiled_potential_fn(*vals) + + @torch.jit.trace(*vals, optimize=True) + def wrapped(*zi): + z_constrained = list(zi) + # transform to constrained space. + for i, name in enumerate(names): + if name in self.transforms: + transform = self.transforms[name] + z_constrained[i] = transform.inv(z_constrained[i]) + z_constrained = dict(zip(names, z_constrained)) + trace = self._get_trace(z_constrained) + potential_energy = -self._compute_trace_log_prob(trace) + # adjust by the jacobian for this transformation. + for i, name in enumerate(names): + if name in self.transforms: + transform = self.transforms[name] + potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum() + return potential_energy + + self._compiled_potential_fn = wrapped + return self._compiled_potential_fn(*vals) + def _energy(self, z, r): return self._kinetic_energy(r) + self._potential_energy(z) @@ -143,6 +176,7 @@ def _reset(self): self._accept_cnt = 0 self._r_dist = OrderedDict() self._args = None + self._compiled_potential_fn = None self._kwargs = None self._prototype_trace = None self._adapt_phase = False diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index fd8bbb3c51..a290ab537f 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -56,6 +56,9 @@ class NUTS(HMC): :param int max_iarange_nesting: Optional bound on max number of nested :func:`pyro.iarange` contexts. This is required if model contains discrete sample sites that can be enumerated over in parallel. + :param bool jit_compile: Optional parameter denoting whether to use + the PyTorch JIT to trace the log density computation, and use this + optimized executable trace in the integrator. 
Example: @@ -82,12 +85,14 @@ def __init__(self, step_size=None, adapt_step_size=False, transforms=None, - max_iarange_nesting=float("inf")): + max_iarange_nesting=float("inf"), + jit_compile=False): super(NUTS, self).__init__(model, step_size, adapt_step_size=adapt_step_size, transforms=transforms, - max_iarange_nesting=max_iarange_nesting) + max_iarange_nesting=max_iarange_nesting, + jit_compile=jit_compile) self._max_tree_depth = 10 # from Stan # There are three conditions to stop doubling process: @@ -133,7 +138,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): else: diverging = (sliced_energy >= self._max_sliced_energy) delta_energy = energy_new - energy_current - accept_prob = (-delta_energy).exp().clamp(max=1) + accept_prob = (-delta_energy).exp().clamp(max=1.0) return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, tree_size, False, diverging, accept_prob, 1) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 231f637d62..f5b1681401 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -18,6 +18,22 @@ logger = logging.getLogger(__name__) +def mark_jit(*args, **kwargs): + jit_markers = kwargs.pop("marks", []) + jit_markers += [ + pytest.mark.skipif(torch.__version__ <= "0.4.1", + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228"), + pytest.mark.skipif('CI' in os.environ, + reason='slow test') + ] + kwargs["marks"] = jit_markers + return pytest.param(*args, **kwargs) + + +def jit_idfn(param): + return "JIT={}".format(param) + + class GaussianChain(object): def __init__(self, dim, chain_len, num_obs): @@ -154,7 +170,8 @@ def test_hmc_conjugate_gaussian(fixture, assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) -def test_logistic_regression(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
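
The parametrization idiom introduced above recurs throughout these test diffs: each test gains a `jit` argument, `mark_jit(True)` attaches the PyTorch-version and CI skip marks to the JIT variant, and `jit_idfn` renders the test ids as `JIT=False`/`JIT=True`, which is what the Makefile's `-k JIT=True` selector matches. A minimal self-contained sketch of the same idiom (toy test for illustration only, not part of this patch):

    import pytest

    def jit_idfn(param):
        # render ids as "JIT=False" / "JIT=True" so `pytest -k JIT=True` can select them
        return "JIT={}".format(param)

    @pytest.mark.parametrize("jit", [False, True], ids=jit_idfn)
    def test_toy(jit):
        # a real test would forward the flag, e.g. HMC(model, ..., jit_compile=jit)
        assert jit in (False, True)
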
@@ -166,13 +183,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, step_size=0.0855, num_steps=4) + hmc_kernel = HMC(model, step_size=0.0855, num_steps=4, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) beta_posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, beta_posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -182,13 +200,14 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, step_size=0.02, num_steps=3) + hmc_kernel = HMC(model, step_size=0.02, num_steps=3, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) -def test_gamma_normal(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -198,13 +217,14 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_dirichlet_categorical(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) @@ -213,13 +233,14 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_logistic_regression_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression_with_dual_averaging(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) 
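
The `jit_compile` flag these tests toggle reduces, at its core, to a compile-once cache: the first call traces the potential-energy computation with the JIT, and every subsequent integrator step reuses the compiled trace. A stripped-down sketch of that idea, written against the `torch.jit.trace(fn, inputs)` form that later commits in this series adopt (toy function shown, not the actual kernel code):

    import torch

    class CompileOnce(object):
        def __init__(self, fn):
            self.fn = fn
            self.compiled = None  # populated on first call

        def __call__(self, *args):
            if self.compiled is None:
                # trace once with the first concrete inputs seen
                self.compiled = torch.jit.trace(self.fn, args)
            return self.compiled(*args)

    potential_fn = CompileOnce(lambda z: (z ** 2).sum())
    z = torch.randn(3)
    potential_fn(z)        # traces here
    potential_fn(z + 1.)   # reuses the compiled trace
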
@@ -231,13 +252,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(posterior.mean, true_coefs).item(), 0.0, prec=0.1) -def test_beta_bernoulli_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli_with_dual_averaging(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -248,13 +270,15 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2, + jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) -def test_gamma_normal_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal_with_dual_averaging(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -264,13 +288,16 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_gaussian_mixture_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True, + marks=[pytest.mark.skip("FIXME: Slow on JIT.")])], + ids=jit_idfn) +def test_gaussian_mixture_model(jit): K, N = 3, 1000 @poutine.broadcast @@ -287,14 +314,16 @@ def gmm(data): true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1) + hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, + max_iarange_nesting=1, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) -def test_bernoulli_latent_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_bernoulli_latent_model(jit): @poutine.broadcast def model(data): y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0)) @@ -309,7 +338,8 @@ def model(data): y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * y + 0.1).sample() data = dist.Normal(2. 
* z, 1.0).sample() - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, + max_iarange_nesting=1, jit_compile=jit) mcmc_run = MCMC(hmc_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean assert_equal(posterior, y_prob, prec=0.05) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 7ff51e8be2..04f9da6c08 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -18,6 +18,23 @@ logger = logging.getLogger(__name__) + +def mark_jit(*args, **kwargs): + jit_markers = kwargs.pop("marks", []) + jit_markers += [ + pytest.mark.skipif(torch.__version__ <= "0.4.1", + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228"), + pytest.mark.skipif('CI' in os.environ, + reason='slow test') + ] + kwargs["marks"] = jit_markers + return pytest.param(*args, **kwargs) + + +def jit_idfn(param): + return "JIT={}".format(param) + + T2 = T(*TEST_CASES[2].values)._replace(num_samples=800, warmup_steps=200) TEST_CASES[2] = pytest.param(*T2, marks=pytest.mark.skipif( 'CI' in os.environ and os.environ['CI'] == 'true', reason='Slow test - skip on CI')) @@ -68,7 +85,8 @@ def test_nuts_conjugate_gaussian(fixture, assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) -def test_logistic_regression(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression(jit): dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) @@ -80,13 +98,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, step_size=0.0855) + nuts_kernel = NUTS(model, step_size=0.0855, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -96,13 +115,14 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, step_size=0.02) + nuts_kernel = NUTS(model, step_size=0.02, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_gamma_normal(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_normal(jit): def model(data): rate = torch.tensor([1.0, 1.0]) concentration = torch.tensor([1.0, 1.0]) @@ -112,13 +132,14 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, step_size=0.01) + nuts_kernel = NUTS(model, step_size=0.01, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) -def test_logistic_regression_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_logistic_regression_with_dual_averaging(jit): 
dim = 3 data = torch.randn(2000, dim) true_coefs = torch.arange(1., dim + 1.) @@ -130,13 +151,14 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) -def test_beta_bernoulli_with_dual_averaging(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_beta_bernoulli_with_dual_averaging(jit): def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) @@ -146,13 +168,14 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="p_latent") assert_equal(posterior.mean, true_probs, prec=0.03) -def test_dirichlet_categorical(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration)) @@ -161,13 +184,14 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) -def test_gamma_beta(): +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +def test_gamma_beta(jit): def model(data): alpha_prior = pyro.sample('alpha', dist.Gamma(concentration=1., rate=1.)) beta_prior = pyro.sample('beta', dist.Gamma(concentration=1., rate=1.)) @@ -176,13 +200,16 @@ def model(data): true_alpha = torch.tensor(5.) true_beta = torch.tensor(1.) 
data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,))) - nuts_kernel = NUTS(model, adapt_step_size=True) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=['alpha', 'beta']) assert_equal(posterior.mean, torch.stack([true_alpha, true_beta]), prec=0.05) -def test_gaussian_mixture_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True, + marks=[pytest.mark.skip("FIXME: Slow on JIT.")])], + ids=jit_idfn) +def test_gaussian_mixture_model(jit): K, N = 3, 1000 @poutine.broadcast @@ -199,14 +226,16 @@ def gmm(data): true_mix_proportions = torch.tensor([0.1, 0.3, 0.6]) cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() - nuts_kernel = NUTS(gmm, adapt_step_size=True, max_iarange_nesting=1) + nuts_kernel = NUTS(gmm, adapt_step_size=True, max_iarange_nesting=1, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) -def test_bernoulli_latent_model(): +@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[ + pytest.mark.xfail(reason="FIXME: log not implemented for 'CPULongType'")])], ids=jit_idfn) +def test_bernoulli_latent_model(jit): @poutine.broadcast def model(data): y_prob = pyro.sample("y_prob", dist.Beta(1., 1.)) @@ -220,14 +249,15 @@ def model(data): y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * y + 0.1).sample() data = dist.Normal(2. 
* z, 1.0).sample() - nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1) + nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1, jit_compile=jit) mcmc_run = MCMC(nuts_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean assert_equal(posterior, y_prob, prec=0.05) +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @pytest.mark.parametrize("num_steps", [2, 3, 5]) -def test_gaussian_hmm_enum_shape(num_steps): +def test_gaussian_hmm_enum_shape(jit, num_steps): dim = 4 def model(data): @@ -244,5 +274,5 @@ def model(data): assert effective_dim == 1 data = torch.ones(num_steps) - nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=0) + nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=0, jit_compile=jit) MCMC(nuts_kernel, num_samples=5, warmup_steps=5).run(data) diff --git a/tests/test_examples.py b/tests/test_examples.py index fddb59b639..e53f649048 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -62,11 +62,13 @@ def xfail_jit(*args): JIT_EXAMPLES = [ xfail_jit('air/main.py --num-steps=1 --jit'), + xfail_jit('baseball.py --num-samples=200 --warmup-steps=100 --jit'), xfail_jit('bayesian_regression.py --num-epochs=1 --jit'), # this works on PyTorch master xfail_jit('contrib/autoname/mixture.py --num-epochs=1 --jit'), xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), xfail_jit('dmm/dmm.py --num-epochs=1 --jit'), xfail_jit('dmm/dmm.py --num-epochs=1 --num-iafs=1 --jit'), + xfail_jit('eight_schools/mcmc.py --num-samples=500 --warmup-steps=100 --jit'), xfail_jit('eight_schools/svi.py --num-epochs=1 --jit'), xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit'), From 9313b0f7fdf7561d7e5879be65fa9941dc3bdbd6 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 3 Sep 2018 19:56:36 -0700 Subject: [PATCH 041/157] add xfailing test --- tests/infer/test_jit.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index df656d0256..66ac8f6802 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -10,7 +10,7 @@ from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam -from tests.common import assert_equal +from tests.common import assert_equal, xfail_param def test_simple(): @@ -250,6 +250,37 @@ def guide(data): svi.step(data) +@pytest.mark.skipif(torch.__version__ <= '0.4.1', + reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") +@pytest.mark.parametrize('Elbo', [ + Trace_ELBO, + xfail_param(JitTrace_ELBO, reason="https://github.com/uber/pyro/issues/1358"), + TraceGraph_ELBO, + xfail_param(JitTraceGraph_ELBO, reason="https://github.com/uber/pyro/issues/1358"), + TraceEnum_ELBO, + xfail_param(JitTraceEnum_ELBO, reason="https://github.com/uber/pyro/issues/1358"), +]) +def test_svi_irregular_batch_size(Elbo): + pyro.clear_param_store() + + def model(data): + loc = pyro.param("loc", torch.tensor(0.0)) + scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) + with pyro.iarange("data", data.shape[0]): + pyro.sample("x", + dist.Normal(loc, scale).expand([data.shape[0]]), + obs=data) + + def guide(data): + pass + + pyro.clear_param_store() + elbo = 
Elbo(strict_enumeration_warning=False, max_iarange_nesting=1) + inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo) + inference.step(torch.ones(10)) + inference.step(torch.ones(3)) + + @pytest.mark.skipif(torch.__version__ <= '0.4.1', reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) From 3cfba13ada741b2340aa782c1cbe972a3c061430 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 10 Sep 2018 17:29:48 -0700 Subject: [PATCH 042/157] Remove obsolete PyTorch patches --- pyro/distributions/torch_patch.py | 172 ------------------------------ 1 file changed, 172 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 2ebf2e4aeb..7e6e3a967f 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -45,169 +45,6 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) -if torch.__version__ >= '0.4.1': - - # work around https://github.com/pytorch/pytorch/issues/10241 - # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 - @_patch('torch.log') - def _torch_log(input, out=None): - unpatched_fn = _torch_log._pyro_unpatched - input = input.contiguous() - return unpatched_fn(input) if out is None else unpatched_fn(input, out) - - # work around https://github.com/pytorch/pytorch/issues/10241 - # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 - @_patch('torch.Tensor.log') - def _Tensor_log(self): - unpatched_fn = _Tensor_log._pyro_unpatched - self = self.contiguous() - return unpatched_fn(self) - - # work around https://github.com/pytorch/pytorch/issues/10241 - # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 - @_patch('torch.exp') - def _torch_exp(input, out=None): - unpatched_fn = _torch_exp._pyro_unpatched - input = input.contiguous() - return unpatched_fn(input) if out is None else unpatched_fn(input, out) - - # work around https://github.com/pytorch/pytorch/issues/10241 - # this can be deleted after https://github.com/pytorch/pytorch/pull/10269 - @_patch('torch.Tensor.exp') - def _Tensor_exp(self): - unpatched_fn = _Tensor_exp._pyro_unpatched - self = self.contiguous() - return unpatched_fn(self) - - # work around https://github.com/pytorch/pytorch/issues/9917 - @_patch('torch.bernoulli') - def _torch_bernoulli(input, out=None): - unpatched_fn = _torch_bernoulli._pyro_unpatched - input = input.contiguous() - return unpatched_fn(input) if out is None else unpatched_fn(input, out) - - # work around https://github.com/pytorch/pytorch/issues/9917 - @_patch('torch.poisson') - def _torch_poisson(input): - unpatched_fn = _torch_poisson._pyro_unpatched - input = input.contiguous() - return unpatched_fn(input) - - # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch._standard_gamma') # noqa: F811 - def _torch_standard_gamma(concentration): - concentration = concentration.contiguous() - unpatched_fn = _torch_standard_gamma._pyro_unpatched - if concentration.is_cuda: - return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) - return unpatched_fn(concentration) - - # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch.distributions.gamma._standard_gamma') # noqa: F811 - def _standard_gamma(concentration): - concentration = concentration.contiguous() - if concentration.is_cuda: - return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) - return 
concentration._standard_gamma() - - # work around https://github.com/pytorch/pytorch/issues/9521 - @_patch('torch._dirichlet_grad') # noqa: F811 - def _torch_dirichlet_grad(x, concentration, total): - unpatched_fn = _torch_dirichlet_grad._pyro_unpatched - x = x.contiguous() - concentration = concentration.contiguous() - total = total.contiguous() - if x.is_cuda: - return unpatched_fn(x.cpu(), concentration.cpu(), total.cpu()).cuda(x.get_device()) - return unpatched_fn(x, concentration, total) - - -# these patches work after https://github.com/pytorch/pytorch/pull/10075 -if hasattr(torch, 'broadcast_tensors'): - - # work around lack of jit support for torch._C._infer_size() - # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 - @_patch('torch.distributions.categorical.Categorical.log_prob') - def _log_prob(self, value): - if self._validate_args: - self._validate_sample(value) - value = value.long().unsqueeze(-1) - value, log_pmf = torch.broadcast_tensors(value, self.logits) - value = value[..., :1] - return log_pmf.gather(-1, value).squeeze(-1) - - # work around lack of jit support for torch._C._infer_size() - # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 - @_patch('torch.distributions.multivariate_normal.MultivariateNormal.__init__') - def _MultivariateNormal_init(self, loc, covariance_matrix=None, precision_matrix=None, - scale_tril=None, validate_args=None): - if loc.dim() < 1: - raise ValueError("loc must be at least one-dimensional.") - if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: - raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") - - loc_ = loc.unsqueeze(-1) # temporarily add dim on right - if scale_tril is not None: - if scale_tril.dim() < 2: - raise ValueError("scale_tril matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - self._unbroadcasted_scale_tril = scale_tril - self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) - elif covariance_matrix is not None: - if covariance_matrix.dim() < 2: - raise ValueError("covariance_matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - self._unbroadcasted_scale_tril = torch.distributions.multivariate_normal._batch_potrf_lower( - covariance_matrix) - self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) - else: - if precision_matrix.dim() < 2: - raise ValueError("precision_matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - covariance_matrix = torch.distributions.multivariate_normal._batch_inverse(precision_matrix) - self._unbroadcasted_scale_tril = torch.distributions.multivariate_normal._batch_potrf_lower( - covariance_matrix) - self.covariance_matrix, self.precision_matrix, loc_ = torch.broadcast_tensors( - covariance_matrix, precision_matrix, loc_) - self.loc = loc_[..., 0] # drop rightmost dim - - batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] - super(torch.distributions.multivariate_normal.MultivariateNormal, self).__init__( - batch_shape, event_shape, validate_args=validate_args) - - # work around lack of jit support for torch._C._infer_size() - # this can be deleted after https://github.com/pytorch/pytorch/pull/10321 - @_patch('torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.__init__') - def __init__(self, loc, cov_factor, cov_diag, validate_args=None): - if loc.dim() < 1: - raise 
ValueError("loc must be at least one-dimensional.") - event_shape = loc.shape[-1:] - if cov_factor.dim() < 2: - raise ValueError("cov_factor must be at least two-dimensional, " - "with optional leading batch dimensions") - if cov_factor.shape[-2:-1] != event_shape: - raise ValueError("cov_factor must be a batch of matrices with shape {} x m" - .format(event_shape[0])) - if cov_diag.shape[-1:] != event_shape: - raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape)) - - loc_ = loc.unsqueeze(-1) - cov_diag_ = cov_diag.unsqueeze(-1) - try: - loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_) - except RuntimeError: - raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}" - .format(loc.shape, cov_factor.shape, cov_diag.shape)) - self.loc = loc_[..., 0] - self.cov_diag = cov_diag_[..., 0] - batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] - - self._capacitance_tril = torch.distributions.lowrank_multivariate_normal._batch_capacitance_tril( - self.cov_factor, self.cov_diag) - super(torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal, self).__init__( - batch_shape, event_shape, validate_args=validate_args) - - def _einsum(equation, operands): # work around torch.einsum performance issues # see https://github.com/pytorch/pytorch/issues/10661 @@ -218,15 +55,6 @@ def _einsum(equation, operands): y, x = operands return (x.unsqueeze(1) * y).sum(0).transpose(0, 1) - # work around torch.einsum's limitation to 26 letters - symbols = sorted(set(equation) - set(',->')) - rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz')) - equation = ''.join(rename.get(s, s) for s in equation) - - # this workaround can be deleted after this issue is fixed in release: - # https://github.com/pytorch/pytorch/issues/7763 - operands = [t.clone() for t in operands] - return _einsum._pyro_unpatched(equation, operands) From a677512b350da67581c6a6dcb329bb30cd49eb6c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 10 Sep 2018 22:32:57 -0700 Subject: [PATCH 043/157] Remove patch for Tensor._standard_gamma --- pyro/distributions/torch_patch.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 7e6e3a967f..144f124124 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -30,13 +30,6 @@ def _torch_standard_gamma(concentration): return unpatched_fn(concentration) -@_patch('torch.distributions.gamma._standard_gamma') -def _standard_gamma(concentration): - if concentration.is_cuda: - return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) - return concentration._standard_gamma() - - @_patch('torch._dirichlet_grad') def _torch_dirichlet_grad(x, concentration, total): unpatched_fn = _torch_dirichlet_grad._pyro_unpatched From 8f665baacd1e0bf94fde2d0ed3ff3b62d7203c3f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 10 Sep 2018 23:35:24 -0700 Subject: [PATCH 044/157] Fix some jit errors --- examples/hmm.py | 2 +- pyro/ops/jit.py | 10 +++--- pyro/primitives.py | 9 +++-- tests/infer/test_jit.py | 74 +++++++++++++++++++++-------------------- 4 files changed, 51 insertions(+), 44 deletions(-) diff --git a/examples/hmm.py b/examples/hmm.py index c89b257718..8922d94d19 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -258,7 +258,7 @@ def main(args): # We'll train on small minibatches. 
logging.info('Step\tLoss') for step in range(args.num_steps): - loss = svi.step(sequences, lengths, args, batch_size=args.batch_size) + loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) logging.info('{: >5d}\t{}'.format(step, loss / num_observations)) # We evaluate on the entire training dataset, diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index b1ce7dcf20..6cf5adeafb 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -31,12 +31,11 @@ def __call__(self, *args, **kwargs): self.fn(*args, **kwargs) self._param_names = list(set(first_param_capture.trace.nodes.keys())) - unconstrained_params = [pyro.param(name).unconstrained() - for name in self._param_names] - params_and_args = unconstrained_params + list(args) + unconstrained_params = tuple(pyro.param(name).unconstrained() + for name in self._param_names) + params_and_args = unconstrained_params + args weakself = weakref.ref(self) - @torch.jit.trace(*params_and_args) def compiled(*params_and_args): self = weakself() unconstrained_params = params_and_args[:len(self._param_names)] @@ -48,7 +47,8 @@ def compiled(*params_and_args): constrained_params[name] = constrained_param return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) - self.compiled[argc] = compiled + with pyro.validation_enabled(False): + self.compiled[argc] = torch.jit.trace(compiled, params_and_args) else: unconstrained_params = [pyro.param(name).unconstrained() for name in self._param_names] diff --git a/pyro/primitives.py b/pyro/primitives.py index d441743e77..071ac82e56 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -100,7 +100,12 @@ def __init__(self, size, subsample_size, use_cuda=None): """ self.size = size self.subsample_size = subsample_size - self.use_cuda = torch.Tensor().is_cuda if use_cuda is None else use_cuda + if use_cuda is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + self.use_cuda = torch.Tensor().is_cuda + else: + self.use_cuda = use_cuda def sample(self, sample_shape=torch.Size()): """ @@ -139,7 +144,7 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No subsample = sample(name, _Subsample(size, subsample_size, use_cuda)) if subsample_size is None: - subsample_size = len(subsample) + subsample_size = subsample.shape[0] elif subsample is not None and subsample_size != len(subsample): raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample)) + diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 66ac8f6802..383dfc2e94 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import warnings + import pytest import torch from torch.autograd import grad @@ -7,21 +9,31 @@ import pyro import pyro.distributions as dist +import pyro.ops.jit from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam from tests.common import assert_equal, xfail_param +def constant(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + return torch.tensor(*args, **kwargs) + + def test_simple(): y = torch.ones(2) - @torch.jit.trace(y) def f(x): print('Inside f') - assert x is y + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y return y + 
1.0 + print('Compiling f') + f = torch.jit.trace(f, (y,)) print('Calling f(y)') assert_equal(f(y), y.new_tensor([2., 2.])) print('Calling f(y)') @@ -35,12 +47,13 @@ def f(x): def test_multi_output(): y = torch.ones(2) - @torch.jit.trace(y) def f(x): print('Inside f') assert x is y return y - 1.0, y + 1.0 + print('Compiling f') + f = torch.jit.trace(f, (y,)) print('Calling f(y)') assert_equal(f(y)[1], y.new_tensor([2., 2.])) print('Calling f(y)') @@ -54,12 +67,13 @@ def f(x): def test_backward(): y = torch.ones(2, requires_grad=True) - @torch.jit.trace(y) def f(x): print('Inside f') assert x is y return (y + 1.0).sum() + print('Compiling f') + f = torch.jit.trace(f, (y,)) print('Calling f(y)') f(y).backward() print('Calling f(y)') @@ -73,12 +87,13 @@ def f(x): @pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad(): - @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) def f(x, y): print('Inside f') loss = (x - y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) + print('Compiling f') + f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True))) print('Invoking f') f(torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)) print('Invoking f') @@ -88,27 +103,26 @@ def f(x, y): @pytest.mark.xfail(reason="grad cannot appear in jitted code") def test_grad_expand(): - @torch.jit.trace(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) def f(x, y): print('Inside f') loss = (x - y).pow(2).sum() return torch.autograd.grad(loss, [x, y], allow_unused=True) + print('Compiling f') + f = torch.jit.trace(f, (torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True))) print('Invoking f') f(torch.zeros(2, requires_grad=True), torch.ones(1, requires_grad=True)) print('Invoking f') f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) def test_bernoulli_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.empty(shape).fill_(0.25) - @torch.jit.trace(probs) + @pyro.ops.jit.trace def f(probs): d = dist.Bernoulli(probs) support = d.enumerate_support(expand=expand) @@ -118,15 +132,13 @@ def f(probs): assert log_prob.shape == (2,) + shape -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) def test_categorical_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.ones(shape) - @torch.jit.trace(probs) + @pyro.ops.jit.trace def f(probs): d = dist.Categorical(probs) support = d.enumerate_support(expand=expand) @@ -137,8 +149,6 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, @@ -153,8 +163,8 @@ def test_svi(Elbo, num_particles): data = torch.arange(10.) 
def model(data): - loc = pyro.param("loc", torch.tensor(0.0)) - scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) + loc = pyro.param("loc", constant(0.0)) + scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) pyro.sample("x", dist.Normal(loc, scale).expand_by(data.shape).independent(1), obs=data) def guide(data): @@ -166,8 +176,6 @@ def guide(data): inference.step(data) -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) @@ -175,7 +183,7 @@ def guide(data): def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2): pyro.clear_param_store() num_particles = 10 - q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval) + q = pyro.param("q", constant(0.75), constraint=constraints.unit_interval) p = 0.2693204236205713 # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5 def model(): @@ -212,8 +220,6 @@ def guide(): ])) -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_beta_bernoulli(Elbo, vectorized): @@ -221,15 +227,15 @@ def test_beta_bernoulli(Elbo, vectorized): data = torch.tensor([1.0] * 6 + [0.0] * 4) def model1(data): - alpha0 = torch.tensor(10.0) - beta0 = torch.tensor(10.0) + alpha0 = constant(10.0) + beta0 = constant(10.0) f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) for i in pyro.irange("irange", len(data)): pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i]) def model2(data): - alpha0 = torch.tensor(10.0) - beta0 = torch.tensor(10.0) + alpha0 = constant(10.0) + beta0 = constant(10.0) f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1), obs=data) @@ -237,9 +243,9 @@ def model2(data): model = model2 if vectorized else model1 def guide(data): - alpha_q = pyro.param("alpha_q", torch.tensor(15.0), + alpha_q = pyro.param("alpha_q", constant(15.0), constraint=constraints.positive) - beta_q = pyro.param("beta_q", torch.tensor(15.0), + beta_q = pyro.param("beta_q", constant(15.0), constraint=constraints.positive) pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) @@ -250,8 +256,6 @@ def guide(data): svi.step(data) -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") @pytest.mark.parametrize('Elbo', [ Trace_ELBO, xfail_param(JitTrace_ELBO, reason="https://github.com/uber/pyro/issues/1358"), @@ -264,8 +268,8 @@ def test_svi_irregular_batch_size(Elbo): pyro.clear_param_store() def model(data): - loc = pyro.param("loc", torch.tensor(0.0)) - scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive) + loc = pyro.param("loc", constant(0.0)) + scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) with pyro.iarange("data", data.shape[0]): pyro.sample("x", dist.Normal(loc, scale).expand([data.shape[0]]), @@ -281,8 +285,6 @@ def guide(data): inference.step(torch.ones(3)) -@pytest.mark.skipif(torch.__version__ <= '0.4.1', - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228") 
@pytest.mark.parametrize('vectorized', [False, True]) @pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) def test_dirichlet_bernoulli(Elbo, vectorized): @@ -290,13 +292,13 @@ def test_dirichlet_bernoulli(Elbo, vectorized): data = torch.tensor([1.0] * 6 + [0.0] * 4) def model1(data): - concentration0 = torch.tensor([10.0, 10.0]) + concentration0 = constant([10.0, 10.0]) f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1] for i in pyro.irange("irange", len(data)): pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i]) def model2(data): - concentration0 = torch.tensor([10.0, 10.0]) + concentration0 = constant([10.0, 10.0]) f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1] pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1), obs=data) @@ -304,7 +306,7 @@ def model2(data): model = model2 if vectorized else model1 def guide(data): - concentration_q = pyro.param("concentration_q", torch.tensor([15.0, 15.0]), + concentration_q = pyro.param("concentration_q", constant([15.0, 15.0]), constraint=constraints.positive) pyro.sample("latent_fairness", dist.Dirichlet(concentration_q)) From 06b0e63ee4e4d746831a2f5f2395db4c2950550f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 10 Sep 2018 23:52:38 -0700 Subject: [PATCH 045/157] Convert to valid einsum chars in torch_log backend --- pyro/ops/einsum/torch_log.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index 06886248ab..afb7f5e991 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function import torch - +from opt_einsum.parser import convert_to_valid_einsum_chars EINSUM_SYMBOLS_BASE = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -14,6 +14,10 @@ def einsum(equation, *operands): """ Log-sum-exp implementation of einsum. """ + # rename symbols to support PyTorch 0.4.1 and earlier, + # which allow only symbols a-z. 
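+    # e.g. an equation such as 'AB,BC->AC' would first be mapped to
+    # 'ab,bc->ac' (a hypothetical illustration, not part of this change).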
+ equation = convert_to_valid_einsum_chars(equation) + inputs, output = equation.split('->') inputs = inputs.split(',') From 1785b6c0af616f8fa8bc4fd63cb5e18bed3003b3 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 11 Sep 2018 12:55:15 -0400 Subject: [PATCH 046/157] Updating distributions module with PyTorch master (#1377) * remove lowrank mvn * remove custom enumerate_support * remove unused import * remove reshaped dist, custom wrappers * restore patch file --- pyro/distributions/__init__.py | 2 - pyro/distributions/lowrank_mvn.py | 135 ---------- pyro/distributions/torch.py | 289 ---------------------- pyro/distributions/torch_distribution.py | 165 +----------- pyro/distributions/torch_patch.py | 2 +- pyro/infer/traceenum_elbo.py | 3 - tests/distributions/conftest.py | 4 +- tests/distributions/test_distributions.py | 44 ++-- 8 files changed, 23 insertions(+), 621 deletions(-) delete mode 100644 pyro/distributions/lowrank_mvn.py diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 280f434de7..3f39b98fc6 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -11,7 +11,6 @@ from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture from pyro.distributions.half_cauchy import HalfCauchy from pyro.distributions.iaf import InverseAutoregressiveFlow -from pyro.distributions.lowrank_mvn import LowRankMultivariateNormal from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal from pyro.distributions.rejector import Rejector @@ -37,7 +36,6 @@ "GaussianScaleMixture", "HalfCauchy", "InverseAutoregressiveFlow", - "LowRankMultivariateNormal", "MaskedMixture", "MixtureOfDiagNormalsSharedCovariance", "MixtureOfDiagNormals", diff --git a/pyro/distributions/lowrank_mvn.py b/pyro/distributions/lowrank_mvn.py deleted file mode 100644 index 5beef012af..0000000000 --- a/pyro/distributions/lowrank_mvn.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import math - -import torch -from torch.distributions import constraints -from torch.distributions.utils import lazy_property - -from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistribution -from pyro.distributions.util import eye_like - - -def _matrix_triangular_solve_compat(b, A, upper=True): - """ - Computes the solution to the linear equation AX = b, - where A is a triangular matrix. - - :param b: A 1D or 2D tensor of size N or N x C. - :param A: A 2D tensor of size N X N. - :param upper: A flag if A is a upper triangular matrix or not. - """ - return b.view(b.shape[0], -1).trtrs(A, upper=upper)[0].view(b.shape) - - -class LowRankMultivariateNormal(TorchDistribution): - """ - Low Rank Multivariate Normal distribution. - - Implements fast computation for log probability of Multivariate Normal distribution - when the covariance matrix has the form:: - - covariance_matrix = W @ W.T + D. - - Here D is a diagonal vector and ``W`` is a matrix of size ``N x M``. The - computation will be beneficial when ``M << N``. - - :param torch.Tensor loc: Mean. - Must be a 1D or 2D tensor with the last dimension of size N. - :param torch.Tensor W_term: W term of covariance matrix. - Must be in 2 dimensional of size N x M. - :param torch.Tensor D_term: D term of covariance matrix. - Must be in 1 dimensional of size N. - :param float trace_term: A optional term to be added into Mahalabonis term - according to p(y) = N(y|loc, cov).exp(-1/2 * trace_term). 
- """ - arg_constraints = {"loc": constraints.real, - "covariance_matrix_D_term": constraints.positive, - "scale_tril": constraints.lower_triangular} - support = IndependentConstraint(constraints.real, 1) - has_rsample = True - - def __init__(self, loc, W_term, D_term, trace_term=None): - W_term = W_term.t() - if loc.shape[-1] != D_term.shape[0]: - raise ValueError("Expected loc.shape == D_term.shape, but got {} vs {}".format( - loc.shape, D_term.shape)) - if D_term.shape[0] != W_term.shape[1]: - raise ValueError("The dimension of D_term must match the first dimension of W_term.") - if D_term.dim() != 1 or W_term.dim() != 2 or loc.dim() > 2: - raise ValueError("D_term, W_term must be 1D, 2D tensors respectively and " - "loc must be a 1D or 2D tensor.") - - self.loc = loc - self.covariance_matrix_D_term = D_term - self.covariance_matrix_W_term = W_term - self.trace_term = trace_term if trace_term is not None else 0 - - batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:] - super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape) - - @property - def mean(self): - return self.loc - - @property - def variance(self): - return self.covariance_matrix_D_term + (self.covariance_matrix_W_term ** 2).sum(0) - - @lazy_property - def scale_tril(self): - # We use the following formula to increase the numerically computation stability - # when using Cholesky decomposition (see GPML section 3.4.3): - # D + W.T @ W = D1/2 @ (I + D-1/2 @ W.T @ W @ D-1/2) @ D1/2 - Dsqrt = self.covariance_matrix_D_term.sqrt() - A = self.covariance_matrix_W_term / Dsqrt - At_A = A.t().matmul(A) - N = A.shape[1] - Id = eye_like(A, N) - K = Id + At_A - L = K.potrf(upper=False) - return Dsqrt.unsqueeze(1) * L - - def rsample(self, sample_shape=torch.Size()): - white = self.loc.new_empty(sample_shape + self.loc.shape).normal_() - return self.loc + torch.matmul(white, self.scale_tril.t()) - - def log_prob(self, value): - delta = value - self.loc - logdet, mahalanobis_squared = self._compute_logdet_and_mahalanobis( - self.covariance_matrix_D_term, self.covariance_matrix_W_term, delta, self.trace_term) - normalization_const = 0.5 * (self.event_shape[-1] * math.log(2 * math.pi) + logdet) - return -(normalization_const + 0.5 * mahalanobis_squared) - - def _compute_logdet_and_mahalanobis(self, D, W, y, trace_term=0): - """ - Calculates log determinant and (squared) Mahalanobis term of covariance - matrix ``(D + Wt.W)``, where ``D`` is a diagonal matrix, based on the - "Woodbury matrix identity" and "matrix determinant lemma":: - - inv(D + Wt.W) = inv(D) - inv(D).Wt.inv(I + W.inv(D).Wt).W.inv(D) - log|D + Wt.W| = log|Id + Wt.inv(D).W| + log|D| - """ - W_Dinv = W / D - M = W.shape[0] - Id = eye_like(W, M) - K = Id + W_Dinv.matmul(W.t()) - L = K.potrf(upper=False) - if y.dim() == 1: - W_Dinv_y = W_Dinv.matmul(y) - elif y.dim() == 2: - W_Dinv_y = W_Dinv.matmul(y.t()) - else: - raise NotImplementedError("SparseMultivariateNormal distribution does not support " - "computing log_prob for a tensor with more than 2 dimensionals.") - Linv_W_Dinv_y = _matrix_triangular_solve_compat(W_Dinv_y, L, upper=False) - if y.dim() == 2: - Linv_W_Dinv_y = Linv_W_Dinv_y.t() - - logdet = 2 * L.diag().log().sum() + D.log().sum() - - mahalanobis1 = (y * y / D).sum(-1) - mahalanobis2 = (Linv_W_Dinv_y * Linv_W_Dinv_y).sum(-1) - mahalanobis_squared = mahalanobis1 - mahalanobis2 + trace_term - - return logdet, mahalanobis_squared diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 633c7acf47..d996cf010d 100644 --- 
a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -4,300 +4,11 @@ from torch.distributions import constraints from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistributionMixin -from pyro.distributions.util import eye_like - - -class Bernoulli(torch.distributions.Bernoulli, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Bernoulli, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - values = self._param.new_tensor([0., 1.]) - values = values.reshape((2,) + (1,) * len(self.batch_shape)) - if expand: - values = values.expand((2,) + self.batch_shape) - return values - - -class Beta(torch.distributions.Beta, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Beta, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - concentration1 = self.concentration1.expand(batch_shape) - concentration0 = self.concentration0.expand(batch_shape) - return type(self)(concentration1, concentration0, validate_args=validate_args) - - -class Categorical(torch.distributions.Categorical, TorchDistributionMixin): - - def expand(self, batch_shape): - try: - return super(Categorical, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape + self.logits.shape[-1:]) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - num_events = self._num_events - values = torch.arange(num_events, dtype=torch.long) - values = values.view((-1,) + (1,) * len(self._batch_shape)) - if expand: - values = values.expand((-1,) + self._batch_shape) - if self._param.is_cuda: - values = values.cuda(self._param.get_device()) - return values - - -class Cauchy(torch.distributions.Cauchy, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Cauchy, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class Chi2(torch.distributions.Chi2, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Chi2, self).expand_by(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - df = self.df.expand(batch_shape) - return type(self)(df, validate_args=validate_args) - - -class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Dirichlet, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - concentration = self.concentration.expand(batch_shape + self.event_shape) - return type(self)(concentration, validate_args=validate_args) - - -class 
Exponential(torch.distributions.Exponential, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Exponential, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - rate = self.rate.expand(batch_shape) - return type(self)(rate, validate_args=validate_args) - - -class Gamma(torch.distributions.Gamma, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Gamma, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - concentration = self.concentration.expand(batch_shape) - rate = self.rate.expand(batch_shape) - return type(self)(concentration, rate, validate_args=validate_args) - - -class Geometric(torch.distributions.Geometric, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Geometric, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(logits=logits, validate_args=validate_args) - - -class Gumbel(torch.distributions.Gumbel, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Gumbel, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class Independent(torch.distributions.Independent, TorchDistributionMixin): - @constraints.dependent_property - def support(self): - return IndependentConstraint(self.base_dist.support, self.reinterpreted_batch_ndims) - - @property - def _validate_args(self): - return self.base_dist._validate_args - - @_validate_args.setter - def _validate_args(self, value): - self.base_dist._validate_args = value - - def expand(self, batch_shape): - batch_shape = torch.Size(batch_shape) - base_shape = self.base_dist.batch_shape - reinterpreted_shape = base_shape[len(base_shape) - self.reinterpreted_batch_ndims:] - base_dist = self.base_dist.expand(batch_shape + reinterpreted_shape) - return type(self)(base_dist, self.reinterpreted_batch_ndims) - - def enumerate_support(self, expand=expand): - if self.reinterpreted_batch_ndims: - raise NotImplementedError("Pyro does not enumerate over cartesian products") - return self.base_dist.enumerate_support(expand=expand) - - -class Laplace(torch.distributions.Laplace, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Laplace, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(LogNormal, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Multinomial, self).expand(batch_shape) - 
except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.event_shape) - return type(self)(self.total_count, probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape + self.event_shape) - return type(self)(self.total_count, logits=logits, validate_args=validate_args) class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin): support = IndependentConstraint(constraints.real, 1) # TODO move upstream - def expand(self, batch_shape): - try: - return super(MultivariateNormal, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape + self.event_shape) - if 'scale_tril' in self.__dict__: - scale_tril = self.scale_tril.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, scale_tril=scale_tril, validate_args=validate_args) - elif 'covariance_matrix' in self.__dict__: - covariance_matrix = self.covariance_matrix.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, covariance_matrix=covariance_matrix, validate_args=validate_args) - else: - precision_matrix = self.precision_matrix.expand(batch_shape + self.event_shape + self.event_shape) - return type(self)(loc, precision_matrix=precision_matrix, validate_args=validate_args) - - -class Normal(torch.distributions.Normal, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Normal, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale, validate_args=validate_args) - - -class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(OneHotCategorical, self).expand(batch_shape) - except NotImplementedError: - batch_shape = torch.Size(batch_shape) - validate_args = self.__dict__.get('_validate_args') - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape + self.event_shape) - return type(self)(probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape + self.event_shape) - return type(self)(logits=logits, validate_args=validate_args) - - def enumerate_support(self, expand=True): - n = self.event_shape[0] - values = eye_like(self._categorical._param, n) - values = values.view((n,) + (1,) * len(self.batch_shape) + (n,)) - if expand: - values = values.expand((n,) + self.batch_shape + (n,)) - return values - - -class Poisson(torch.distributions.Poisson, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Poisson, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - rate = self.rate.expand(batch_shape) - return type(self)(rate, validate_args=validate_args) - - -class StudentT(torch.distributions.StudentT, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(StudentT, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - df = self.df.expand(batch_shape) - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(df, loc, scale, 
validate_args=validate_args) - - -class TransformedDistribution(torch.distributions.TransformedDistribution, TorchDistributionMixin): - def expand(self, batch_shape): - return super(TransformedDistribution, self).expand(batch_shape) - - -class Uniform(torch.distributions.Uniform, TorchDistributionMixin): - def expand(self, batch_shape): - try: - return super(Uniform, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - low = self.low.expand(batch_shape) - high = self.high.expand(batch_shape) - return type(self)(low, high, validate_args=validate_args) - # Programmatically load all distributions from PyTorch. __all__ = [] diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 7276478b14..05d3702800 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -1,14 +1,11 @@ from __future__ import absolute_import, division, print_function -import numbers - import torch from torch.distributions import biject_to, constraints, transform_to import pyro.distributions.torch from pyro.distributions.distribution import Distribution -from pyro.distributions.score_parts import ScoreParts -from pyro.distributions.util import broadcast_shape, scale_and_mask, sum_rightmost +from pyro.distributions.util import broadcast_shape, scale_and_mask class TorchDistributionMixin(Distribution): @@ -65,35 +62,6 @@ def shape(self, sample_shape=torch.Size()): """ return sample_shape + self.batch_shape + self.event_shape - def expand(self, batch_shape): - """ - Expands a distribution to a desired - :attr:`~torch.distributions.distribution.Distribution.batch_shape`. - - Note that this is more general than :meth:`expand_by` because - ``d.expand_by(sample_shape)`` can be reduced to - ``d.expand(sample_shape + d.batch_shape)``. - - :param torch.Size batch_shape: The target ``batch_shape``. This must - compatible with ``self.batch_shape`` similar to the requirements - of :func:`torch.Tensor.expand`: the target ``batch_shape`` must - be at least as long as ``self.batch_shape``, and for each - non-singleton dim of ``self.batch_shape``, ``batch_shape`` must - either agree or be set to ``-1``. - :return: An expanded version of this distribution. - :rtype: :class:`ReshapedDistribution` - """ - batch_shape = torch.Size(batch_shape) - cut = len(batch_shape) - len(self.batch_shape) - left, right = batch_shape[:cut], batch_shape[cut:] - if right == self.batch_shape: - return self.expand_by(left) - else: - raise NotImplementedError("`TorchDistributionMixin.expand()` cannot expand " - "distribution's existing batch shape. Consider " - "overriding the default implementation for the " - "distribution class.") - def expand_by(self, sample_shape): """ Expands a distribution by adding ``sample_shape`` to the left side of @@ -107,9 +75,7 @@ def expand_by(self, sample_shape): :return: An expanded version of this distribution. 
:rtype: :class:`ReshapedDistribution` """ - if not sample_shape: - return self - return ReshapedDistribution(self, sample_shape=sample_shape) + return self.expand(torch.Size(sample_shape) + self.batch_shape) def reshape(self, sample_shape=None, extra_event_dims=None): raise Exception(''' @@ -254,133 +220,6 @@ def check(self, value): transform_to.register(IndependentConstraint, lambda c: transform_to(c.base_constraint)) -class ReshapedDistribution(TorchDistribution): - """ - Reshapes a distribution by adding ``sample_shape`` to its total shape - and adding ``reinterpreted_batch_ndims`` to its - :attr:`~torch.distributions.distribution.Distribution.event_shape`. - - :param torch.Size sample_shape: The size of the iid batch to be drawn from - the distribution. - :param int reinterpreted_batch_ndims: The number of extra event dimensions that will - be considered dependent. - """ - arg_constraints = {} - - def __init__(self, base_dist, sample_shape=torch.Size(), reinterpreted_batch_ndims=0): - sample_shape = torch.Size(sample_shape) - if reinterpreted_batch_ndims > len(sample_shape + base_dist.batch_shape): - raise ValueError('Expected reinterpreted_batch_ndims <= len(sample_shape + base_dist.batch_shape), ' - 'actual {} vs {}'.format(reinterpreted_batch_ndims, - len(sample_shape + base_dist.batch_shape))) - self.base_dist = base_dist - self.sample_shape = sample_shape - self.reinterpreted_batch_ndims = reinterpreted_batch_ndims - shape = sample_shape + base_dist.batch_shape + base_dist.event_shape - batch_dim = len(shape) - reinterpreted_batch_ndims - len(base_dist.event_shape) - batch_shape, event_shape = shape[:batch_dim], shape[batch_dim:] - super(ReshapedDistribution, self).__init__(batch_shape, event_shape) - - def expand(self, batch_shape): - batch_shape = torch.Size(batch_shape) - # Raise error if existing batch shape is being shrunk. - # e.g. (2, 4) -> (2, 1) - proposed_shape = broadcast_shape(self.batch_shape, batch_shape) - if tuple(reversed(proposed_shape)) > tuple(reversed(batch_shape)): - raise ValueError("Existing batch shape {} cannot be expanded " - "to the new batch shape {}." - .format(self.batch_shape, batch_shape)) - # Adjust existing sample shape if possible. - base_dist = self.base_dist - base_batch_shape = batch_shape + self.event_shape[:self.reinterpreted_batch_ndims] - cut = len(base_batch_shape) - len(base_dist.batch_shape) - left, right = base_batch_shape[:cut], base_batch_shape[cut:] - if right == base_dist.batch_shape: - sample_shape = left - # Modify the base distribution's batch shape, - # if existing sample shape cannot be adjusted. 
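# (aside: with ReshapedDistribution removed, this reshaping logic is
# subsumed by the native Distribution.expand on PyTorch master; a hedged
# sketch of the replacement behavior:
#     d = dist.Normal(torch.zeros(3), torch.ones(3)).expand([2, 3])
#     assert d.batch_shape == torch.Size([2, 3])
# )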
- else: - base_dist = self.base_dist.expand(base_batch_shape) - assert not isinstance(base_dist, ReshapedDistribution) - sample_shape = torch.Size(()) - return ReshapedDistribution(base_dist, sample_shape, self.reinterpreted_batch_ndims) - - def expand_by(self, sample_shape): - base_dist = self.base_dist - sample_shape = torch.Size(sample_shape) + self.sample_shape - reinterpreted_batch_ndims = self.reinterpreted_batch_ndims - return ReshapedDistribution(base_dist, sample_shape, reinterpreted_batch_ndims) - - def independent(self, reinterpreted_batch_ndims=None): - if reinterpreted_batch_ndims is None: - reinterpreted_batch_ndims = len(self.batch_shape) - base_dist = self.base_dist - sample_shape = self.sample_shape - reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + reinterpreted_batch_ndims - return ReshapedDistribution(base_dist, sample_shape, reinterpreted_batch_ndims) - - @property - def has_rsample(self): - return self.base_dist.has_rsample - - @property - def has_enumerate_support(self): - return self.base_dist.has_enumerate_support - - @constraints.dependent_property - def support(self): - return IndependentConstraint(self.base_dist.support, self.reinterpreted_batch_ndims) - - @property - def _validate_args(self): - return self.base_dist._validate_args - - @_validate_args.setter - def _validate_args(self, value): - self.base_dist._validate_args = value - - def sample(self, sample_shape=torch.Size()): - return self.base_dist.sample(sample_shape + self.sample_shape) - - def rsample(self, sample_shape=torch.Size()): - return self.base_dist.rsample(sample_shape + self.sample_shape) - - def log_prob(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) - return sum_rightmost(self.base_dist.log_prob(value), self.reinterpreted_batch_ndims).expand(shape) - - def score_parts(self, value): - shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim]) - log_prob, score_function, entropy_term = self.base_dist.score_parts(value) - log_prob = sum_rightmost(log_prob, self.reinterpreted_batch_ndims).expand(shape) - if not isinstance(score_function, numbers.Number): - score_function = sum_rightmost(score_function, self.reinterpreted_batch_ndims).expand(shape) - if not isinstance(entropy_term, numbers.Number): - entropy_term = sum_rightmost(entropy_term, self.reinterpreted_batch_ndims).expand(shape) - return ScoreParts(log_prob, score_function, entropy_term) - - def enumerate_support(self, expand=True): - if self.reinterpreted_batch_ndims: - raise NotImplementedError("Pyro does not enumerate over cartesian products") - - samples = self.base_dist.enumerate_support(expand=False) - samples = samples.reshape(samples.shape[:1] + (1,) * len(self.batch_shape) + self.event_shape) - if expand: - samples = samples.expand(samples.shape[:1] + self.batch_shape + self.event_shape) - return samples - - @property - def mean(self): - return self.base_dist.mean.expand(self.batch_shape + self.event_shape) - - @property - def variance(self): - return self.base_dist.variance.expand(self.batch_shape + self.event_shape) - - def entropy(self): - return sum_rightmost(self.base_dist.entropy(), self.reinterpreted_batch_ndims) - - class MaskedDistribution(TorchDistribution): """ Masks a distribution by a zero-one tensor that is broadcastable to the diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 144f124124..3310ce4b23 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py 
@@ -51,4 +51,4 @@ def _einsum(equation, operands): return _einsum._pyro_unpatched(equation, operands) -__all__ = [] +__all__ = [] \ No newline at end of file diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 2c6dbeaae2..3715e3f340 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -13,7 +13,6 @@ import pyro.distributions as dist import pyro.ops.jit import pyro.poutine as poutine -from pyro.distributions.torch_distribution import ReshapedDistribution from pyro.distributions.util import is_identically_zero, scale_and_mask from pyro.infer.contract import contract_tensor_tree, contract_to_tensor from pyro.infer.elbo import ELBO @@ -130,8 +129,6 @@ def _make_dist(dist_, logits): # Reshape for Bernoulli vs Categorical, OneHotCategorical, etc.. if isinstance(dist_, dist.Bernoulli): logits = logits[..., 1] - logits[..., 0] - elif isinstance(dist_, ReshapedDistribution): - return _make_dist(dist_.base_dist, logits=logits) return type(dist_)(logits=logits) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 38e8bf8c4f..2e74ecca08 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -153,10 +153,10 @@ Fixture(pyro_dist=dist.LowRankMultivariateNormal, scipy_dist=sp.multivariate_normal, examples=[ - {'loc': [2.0, 1.0], 'D_term': [0.5, 0.5], 'W_term': [[1.0], [0.5]], + {'loc': [2.0, 1.0], 'cov_diag': [0.5, 0.5], 'cov_factor': [[1.0], [0.5]], 'test_data': [[2.0, 1.0], [9.0, 3.4]]}, ], - scipy_arg_fn=lambda loc, D_term=None, W_term=None: + scipy_arg_fn=lambda loc, cov_diag=None, cov_factor=None: ((), {"mean": np.array(loc), "cov": np.array([[1.5, 0.5], [0.5, 0.75]])}), prec=0.01, min_samples=500000), diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index d997ed207f..b893afed17 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -6,7 +6,6 @@ import pyro import pyro.distributions as dist -from pyro.distributions.torch_distribution import ReshapedDistribution from pyro.distributions.util import broadcast_shape from tests.common import assert_equal, xfail_if_not_implemented @@ -130,8 +129,7 @@ def test_distribution_validate_args(dist_class, args, validate_args): def check_sample_shapes(small, large): - dist_instance = small.base_dist if isinstance(small, ReshapedDistribution) \ - else small + dist_instance = small if isinstance(dist_instance, (dist.LogNormal, dist.LowRankMultivariateNormal, dist.VonMises)): # Ignore broadcasting bug in LogNormal: # https://github.com/pytorch/pytorch/pull/7269 @@ -147,9 +145,10 @@ def check_sample_shapes(small, large): def test_expand_by(dist, sample_shape, shape_type): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand_by(shape_type(sample_shape)) - assert large.batch_shape == sample_shape + small.batch_shape - check_sample_shapes(small, large) + with xfail_if_not_implemented(): + large = small.expand_by(shape_type(sample_shape)) + assert large.batch_shape == sample_shape + small.batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize('sample_shape', [(), (2,), (2, 3)]) @@ -157,9 +156,10 @@ def test_expand_by(dist, sample_shape, shape_type): def test_expand_new_dim(dist, sample_shape, shape_type): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand(shape_type(sample_shape + small.batch_shape)) - assert 
large.batch_shape == sample_shape + small.batch_shape - check_sample_shapes(small, large) + with xfail_if_not_implemented(): + large = small.expand(shape_type(sample_shape + small.batch_shape)) + assert large.batch_shape == sample_shape + small.batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize('shape_type', [torch.Size, tuple, list]) @@ -174,8 +174,8 @@ def test_expand_existing_dim(dist, shape_type): batch_shape = torch.Size(batch_shape) with xfail_if_not_implemented(): large = small.expand(shape_type(batch_shape)) - assert large.batch_shape == batch_shape - check_sample_shapes(small, large) + assert large.batch_shape == batch_shape + check_sample_shapes(small, large) @pytest.mark.parametrize("sample_shapes", [ @@ -203,10 +203,11 @@ def test_subsequent_expands_ok(dist, sample_shapes): def test_expand_error(dist, initial_shape, proposed_shape): for idx in range(dist.get_num_test_data()): small = dist.pyro_dist(**dist.get_dist_params(idx)) - large = small.expand(torch.Size(initial_shape) + small.batch_shape) - proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape - with pytest.raises(ValueError): - large.expand(proposed_batch_shape) + with xfail_if_not_implemented(): + large = small.expand(torch.Size(initial_shape) + small.batch_shape) + proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape + with pytest.raises(RuntimeError): + large.expand(proposed_batch_shape) @pytest.mark.parametrize("extra_event_dims,expand_shape", [ @@ -228,19 +229,10 @@ def test_expand_reshaped_distribution(extra_event_dims, expand_shape): assert large.batch_shape == torch.Size(expand_shape) assert large.event_shape == torch.Size(event_shape) - # Change base_dist only if sample_shape cannot be adjusted. - if extra_event_dims >= 1: - assert large.base_dist == reshaped_dist.base_dist - else: - if expand_shape[-1] == 1: - assert large.base_dist == reshaped_dist.base_dist - else: - assert large.base_dist.batch_shape == torch.Size(expand_shape) - # Throws error when batch shape cannot be broadcasted - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): reshaped_dist.expand(expand_shape + [3]) # Throws error when trying to shrink existing batch shape - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): large.expand(expand_shape[1:]) From 29bb3ed2a2e14a0c8c162fd2dee0c971065d8989 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 10:37:33 -0700 Subject: [PATCH 047/157] Use native torch.tensordot --- pyro/ops/einsum/torch_log.py | 53 ++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index afb7f5e991..f8730630e2 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function import torch -from opt_einsum.parser import convert_to_valid_einsum_chars EINSUM_SYMBOLS_BASE = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -16,7 +15,9 @@ def einsum(equation, *operands): """ # rename symbols to support PyTorch 0.4.1 and earlier, # which allow only symbols a-z. 
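# The tensordot rewrite below rests on the standard log-sum-exp shift:
# for a contracted index k, with xs = max_k x_k and ys = max_k y_k taken
# over the contracted axes,
#
#     log sum_k exp(x_k + y_k)
#         = xs + ys + log sum_k exp(x_k - xs) * exp(y_k - ys)
#
# which keeps every exp() argument non-positive and so avoids overflow.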
- equation = convert_to_valid_einsum_chars(equation) + symbols = sorted(set(equation) - set(',->')) + rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz')) + equation = ''.join(rename.get(s, s) for s in equation) inputs, output = equation.split('->') inputs = inputs.split(',') @@ -52,7 +53,7 @@ def tensordot(x, y, axes=2): # convert int argument to (list[int], list[int]) if isinstance(axes, int): - axes = range(xnd - axes, xnd), range(axes) + axes = list(range(xnd - axes, xnd)), list(range(axes)) # convert (int, int) to (list[int], list[int]) if isinstance(axes[0], int): @@ -60,30 +61,22 @@ def tensordot(x, y, axes=2): if isinstance(axes[1], int): axes = axes[0], (axes[1],) - # initialize empty indices - x_ix = [None] * xnd - y_ix = [None] * ynd - out_ix = [] - - # fill in repeated indices - available_ix = iter(EINSUM_SYMBOLS_BASE) - for ax1, ax2 in zip(*axes): - repeat = next(available_ix) - x_ix[ax1] = repeat - y_ix[ax2] = repeat - - # fill in the rest, and maintain output order - for i in range(xnd): - if x_ix[i] is None: - leave = next(available_ix) - x_ix[i] = leave - out_ix.append(leave) - for i in range(ynd): - if y_ix[i] is None: - leave = next(available_ix) - y_ix[i] = leave - out_ix.append(leave) - - # form full string and contract! - einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix))) - return einsum(einsum_str, x, y) + # compute shifts + assert all(dim >= 0 for axis in axes for dim in axes) + x_shift = x + y_shift = y + for dim in axes[0]: + x_shift = x_shift.max(dim, keepdim=True)[0] + for dim in axes[1]: + y_shift = y_shift.max(dim, keepdim=True)[0] + + result = torch.tensordot((x - x_shift).exp(), (y - y_shift).exp(), axes).log() + + # apply shifts to result + x_part = x.dim() - len(axes[0]) + y_part = y.dim() - len(axes[1]) + assert result.dim() == x_part + y_part + result += x_shift.reshape(result.shape[:x_part] + (1,) * y_part) + result += y_shift.reshape(result.shape[x_part:]) + + return result From 8c914b891152e09ff1a0bbf3f10e3907c5a2eabf Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 10:38:00 -0700 Subject: [PATCH 048/157] Remove duplicate implementation of logsumexp --- pyro/contrib/oed/eig.py | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/pyro/contrib/oed/eig.py b/pyro/contrib/oed/eig.py index bc624c4ccc..43f0d80f1c 100644 --- a/pyro/contrib/oed/eig.py +++ b/pyro/contrib/oed/eig.py @@ -126,7 +126,7 @@ def naive_rainforth_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M_prime, N) retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - conditional_lp = logsumexp(sum(retrace.nodes[l]["log_prob"] for l in observation_labels), 0) \ + conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - np.log(M_prime) else: # This assumes that y are independent conditional on theta @@ -141,7 +141,7 @@ def naive_rainforth_eig(model, design, observation_labels, target_labels=None, reexpanded_design = lexpand(design, M, 1) retrace = poutine.trace(conditional_model).get_trace(reexpanded_design) retrace.compute_log_prob() - marginal_lp = logsumexp(sum(retrace.nodes[l]["log_prob"] for l in observation_labels), 0) \ + marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \ - np.log(M) return (conditional_lp - marginal_lp).sum(0)/N @@ -334,30 +334,6 @@ def loss_fn(design, num_particles): return loss_fn -def logsumexp(inputs, dim=None, 
keepdim=False): - """Numerically stable logsumexp. - - Args: - inputs: A Variable with any shape. - dim: An integer. - keepdim: A boolean. - - Returns: - Equivalent of `log(sum(exp(inputs), dim=dim, keepdim=keepdim))`. - """ - # For a 1-D array x (any array along a single dimension), - # log sum exp(x) = s + log sum exp(x - s) - # with s = max(x) being a common choice. - if dim is None: - inputs = inputs.view(-1) - dim = 0 - s, _ = torch.max(inputs, dim=dim, keepdim=True) - outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() - if not keepdim: - outputs = outputs.squeeze(dim) - return outputs - - class EwmaLog(torch.autograd.Function): """Logarithm function with exponentially weighted moving average for gradients. From 35a6965ccca7fce17699b400382bfde7216dfa57 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 11:55:48 -0700 Subject: [PATCH 049/157] Ignore jit warnings --- pyro/distributions/torch_patch.py | 13 +++++++++++++ pyro/infer/enum.py | 8 +++++++- pyro/infer/util.py | 9 +++++---- pyro/primitives.py | 7 +++++-- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 144f124124..cd3f64903b 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import warnings + import torch @@ -38,6 +40,17 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) +@_patch('torch.distributions.utils._default_promotion') +def _default_promotion(v): + # Ignore jit warnings about promoting Python numbers to tensors, + # assuming all numbers are constant literals. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="torch.tensor might cause the trace to be incorrect") + return _default_promotion._pyro_unpatched(v) + + +@_patch('torch.einsum') def _einsum(equation, operands): # work around torch.einsum performance issues # see https://github.com/pytorch/pytorch/issues/10661 diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 5185eadd54..972083d2ca 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -1,7 +1,9 @@ from __future__ import absolute_import, division, print_function import numbers +import warnings +import torch from six.moves.queue import LifoQueue from pyro import poutine @@ -20,10 +22,14 @@ def iter_discrete_escape(trace, msg): def iter_discrete_extend(trace, site, **ignored): values = site["fn"].enumerate_support(expand=site["infer"].get("expand", False)) + enum_total = values.shape[0] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="Iterating over a tensor") + values = iter(values) for i, value in enumerate(values): extended_site = site.copy() extended_site["infer"] = site["infer"].copy() - extended_site["infer"]["_enum_total"] = len(values) + extended_site["infer"]["_enum_total"] = enum_total extended_site["value"] = value extended_trace = trace.copy() extended_trace.add_node(site["name"], **extended_site) diff --git a/pyro/infer/util.py b/pyro/infer/util.py index 681c5b3f8e..808016735d 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -259,10 +259,11 @@ def compute_expectation(self, costs): for cost in cost_terms: prob = sumproduct(factors, cost.shape) mask = prob > 0 - if torch.is_tensor(mask) and not mask.all(): - cost, prob, mask = broadcast_all(cost, prob, mask) - prob = prob[mask] 
- cost = cost[mask] + if torch.is_tensor(mask): + if torch._C._get_tracing_state() or not mask.all(): + cost, prob, mask = broadcast_all(cost, prob, mask) + prob = prob[mask] + cost = cost[mask] expected_cost = expected_cost + (prob * cost).sum() LAST_CACHE_SIZE[0] = count_cached_ops(cache) return expected_cost diff --git a/pyro/primitives.py b/pyro/primitives.py index 071ac82e56..e000d767bf 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -302,13 +302,16 @@ def __init__(self, name, size, subsample_size=None, subsample=None, use_cuda=Non self.size, self.subsample_size, self.subsample = _subsample(name, size, subsample_size, subsample, use_cuda) def __iter__(self): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="Iterating over a tensor") + subsample = iter(self.subsample) if not am_i_wrapped(): - for i in self.subsample: + for i in subsample: yield i if isinstance(i, numbers.Number) else i.item() else: indep_context = poutine.indep(name=self.name, size=self.subsample_size) with poutine.scale(scale=self.size / self.subsample_size): - for i in self.subsample: + for i in subsample: indep_context.next_context() with indep_context: # convert to python numeric type as functions like torch.ones(*args) From 692b8833d198fd9781f19f1837d1238226453ddc Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 12:12:41 -0700 Subject: [PATCH 050/157] Ignore a couple TracerWarnings in pyro.ops.jit.trace --- pyro/distributions/torch_patch.py | 12 ------------ pyro/infer/enum.py | 1 - pyro/ops/einsum/torch_log.py | 5 +---- pyro/ops/jit.py | 15 +++++++++++++-- tests/infer/test_jit.py | 12 ++++++++---- 5 files changed, 22 insertions(+), 23 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index cd3f64903b..9f9e9f4b02 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -1,7 +1,5 @@ from __future__ import absolute_import, division, print_function -import warnings - import torch @@ -40,16 +38,6 @@ def _torch_dirichlet_grad(x, concentration, total): return unpatched_fn(x, concentration, total) -@_patch('torch.distributions.utils._default_promotion') -def _default_promotion(v): - # Ignore jit warnings about promoting Python numbers to tensors, - # assuming all numbers are constant literals. 
- with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="torch.tensor might cause the trace to be incorrect") - return _default_promotion._pyro_unpatched(v) - - @_patch('torch.einsum') def _einsum(equation, operands): # work around torch.einsum performance issues diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 972083d2ca..89577f23b1 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -3,7 +3,6 @@ import numbers import warnings -import torch from six.moves.queue import LifoQueue from pyro import poutine diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index f8730630e2..85b715ab19 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -48,12 +48,9 @@ def einsum(equation, *operands): # This function is copied and adapted from: # https://github.com/dgasmith/opt_einsum/blob/a6dd686/opt_einsum/backends/torch.py def tensordot(x, y, axes=2): - xnd = x.ndimension() - ynd = y.ndimension() - # convert int argument to (list[int], list[int]) if isinstance(axes, int): - axes = list(range(xnd - axes, xnd)), list(range(axes)) + axes = list(range(x.dim() - axes, x.dim())), list(range(axes)) # convert (int, int) to (list[int], list[int]) if isinstance(axes[0], int): diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 6cf5adeafb..f2499f2f92 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -1,4 +1,8 @@ +from __future__ import absolute_import, division, print_function + +import warnings import weakref + import torch import pyro @@ -47,8 +51,15 @@ def compiled(*params_and_args): constrained_params[name] = constrained_param return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) - with pyro.validation_enabled(False): - self.compiled[argc] = torch.jit.trace(compiled, params_and_args) + with pyro.validation_enabled(False), warnings.catch_warnings(): + # Ignore jit warnings about promoting Python numbers to tensors, + # assuming all numbers are constant literals. 
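+            # (each filter below keys on the start of the warning text,
+            # since `message` is a start-anchored regex; unrelated tracer
+            # warnings still propagate)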
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="torch.tensor might cause the trace to be incorrect") + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="Converting a tensor to a Python") + + self.compiled[argc] = torch.jit.trace(compiled, params_and_args, check_trace=False) else: unconstrained_params = [pyro.param(name).unconstrained() for name in self._param_names] diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 383dfc2e94..ea9e61a6be 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -33,7 +33,7 @@ def f(x): return y + 1.0 print('Compiling f') - f = torch.jit.trace(f, (y,)) + f = torch.jit.trace(f, (y,), check_trace=False) print('Calling f(y)') assert_equal(f(y), y.new_tensor([2., 2.])) print('Calling f(y)') @@ -49,11 +49,13 @@ def test_multi_output(): def f(x): print('Inside f') - assert x is y + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y return y - 1.0, y + 1.0 print('Compiling f') - f = torch.jit.trace(f, (y,)) + f = torch.jit.trace(f, (y,), check_trace=False) print('Calling f(y)') assert_equal(f(y)[1], y.new_tensor([2., 2.])) print('Calling f(y)') @@ -69,7 +71,9 @@ def test_backward(): def f(x): print('Inside f') - assert x is y + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + assert x is y return (y + 1.0).sum() print('Compiling f') From 0c5024381a1ccaf1b6c5238825d424fe6cc76770 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 13:07:05 -0700 Subject: [PATCH 051/157] Fix a tiny test_jit error --- tests/infer/test_jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index ea9e61a6be..76f88ee46a 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -77,7 +77,7 @@ def f(x): return (y + 1.0).sum() print('Compiling f') - f = torch.jit.trace(f, (y,)) + f = torch.jit.trace(f, (y,), check_trace=False) print('Calling f(y)') f(y).backward() print('Calling f(y)') From 850432f102fe09ba168953fe2ae40e9bf8f85b29 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 13:30:20 -0700 Subject: [PATCH 052/157] Add jit test for OneHotCategorical --- tests/infer/test_jit.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 76f88ee46a..765b02ea5a 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -153,6 +153,23 @@ def f(probs): assert log_prob.shape == shape[-1:] + batch_shape +@pytest.mark.parametrize('expand', [False, True]) +@pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +def test_one_hot_categorical_enumerate(shape, expand): + shape = torch.Size(shape) + probs = torch.ones(shape) + + @pyro.ops.jit.trace + def f(probs): + d = dist.OneHotCategorical(probs) + support = d.enumerate_support(expand=expand) + return d.log_prob(support) + + log_prob = f(probs) + batch_shape = shape[:-1] + assert log_prob.shape == shape[-1:] + batch_shape + + @pytest.mark.parametrize('num_particles', [1, 10]) @pytest.mark.parametrize('Elbo', [ Trace_ELBO, From a03164abbda4ae10cdebd482efcfc9d0e7fc9142 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 11 Sep 2018 14:41:52 -0700 Subject: [PATCH 053/157] fix JIT errors for HMC --- pyro/infer/mcmc/hmc.py | 13 ++++++++++--- tests/infer/mcmc/test_hmc.py | 6 ++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git 
a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index f96892db6d..503a68d7a8 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import math +import warnings from collections import OrderedDict import torch @@ -153,8 +154,7 @@ def _potential_energy_jit(self, z): if self._compiled_potential_fn: return self._compiled_potential_fn(*vals) - @torch.jit.trace(*vals, optimize=True) - def wrapped(*zi): + def compiled(*zi): z_constrained = list(zi) # transform to constrained space. for i, name in enumerate(names): @@ -171,7 +171,14 @@ def wrapped(*zi): potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum() return potential_energy - self._compiled_potential_fn = wrapped + with pyro.validation_enabled(False), warnings.catch_warnings(): + # Ignore jit warnings about promoting Python numbers to tensors, + # assuming all numbers are constant literals. + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="torch.tensor might cause the trace to be incorrect") + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="Converting a tensor to a Python") + self._compiled_potential_fn = torch.jit.trace(compiled, vals) return self._compiled_potential_fn(*vals) def _energy(self, z, r): diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 0e93c8e05c..d5fa3a8a08 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -109,8 +109,7 @@ def rmse(t1, t2): mean_tol=0.05, std_tol=0.05, ), marks=[pytest.mark.xfail(reason="flaky"), - pytest.mark.skipif('CI' in os.environ and os.environ['CI'] == 'true', - reason='Slow test - skip on CI')]), + pytest.mark.skip(reason='Slow test')]), pytest.param(*T( GaussianChain(dim=5, chain_len=9, num_obs=1), num_samples=3000, @@ -122,8 +121,7 @@ def rmse(t1, t2): mean_tol=0.08, std_tol=0.08, ), marks=[pytest.mark.xfail(reason="flaky"), - pytest.mark.skipif('CI' in os.environ and os.environ['CI'] == 'true', - reason='Slow test - skip on CI')]) + pytest.mark.skipif(reason='Slow test')]) ] TEST_IDS = [t[0].id_fn() if type(t).__name__ == 'TestExample' From c36a9aaa608aa9169e55d9e37de02bcdb6f764c6 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 11 Sep 2018 16:12:21 -0700 Subject: [PATCH 054/157] change assert in torch_log --- pyro/ops/einsum/torch_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/ops/einsum/torch_log.py b/pyro/ops/einsum/torch_log.py index 85b715ab19..1dc87f08ec 100644 --- a/pyro/ops/einsum/torch_log.py +++ b/pyro/ops/einsum/torch_log.py @@ -59,7 +59,7 @@ def tensordot(x, y, axes=2): axes = axes[0], (axes[1],) # compute shifts - assert all(dim >= 0 for axis in axes for dim in axes) + assert all(dim >= 0 for axis in axes for dim in axis) x_shift = x y_shift = y for dim in axes[0]: From ecfc9957dbd0a455354854f7df4c3528264381d6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 16:36:15 -0700 Subject: [PATCH 055/157] Work around more jit missing coverage --- pyro/distributions/delta.py | 2 +- pyro/distributions/util.py | 8 +++++++- pyro/primitives.py | 4 ++-- tests/infer/test_jit.py | 24 ++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index 3dab0e009c..e47cf1984c 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -55,7 +55,7 @@ def rsample(self, sample_shape=torch.Size()): def 
log_prob(self, x): v = self.v.expand(self.shape()) - log_prob = x.new_tensor(x == v).log() + log_prob = (x == v).type_as(v).log() log_prob = sum_rightmost(log_prob, self.event_dim) return log_prob + self.log_density diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index 3228dcfd93..39a09bdfb0 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import numbers +import warnings from contextlib import contextmanager import torch @@ -177,7 +178,12 @@ def scale_and_mask(tensor, scale=1.0, mask=None): return tensor * scale tensor, mask = broadcast_all(tensor, mask) tensor = tensor * scale # triggers a copy, avoiding in-place op errors - tensor.masked_fill_(~mask, 0.) + if torch._C._get_tracing_state(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + tensor[~mask] = 0. + else: + tensor.masked_fill_(~mask, 0.) return tensor diff --git a/pyro/primitives.py b/pyro/primitives.py index e000d767bf..1b5444ce40 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -248,7 +248,7 @@ def __enter__(self): self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) if self._wrapped: try: - self._scale_messenger = poutine.scale(scale=self.size / self.subsample_size) + self._scale_messenger = poutine.scale(scale=float(self.size / self.subsample_size)) self._indep_messenger = poutine.indep(name=self.name, size=self.subsample_size, dim=self.dim) self._scale_messenger.__enter__() self._indep_messenger.__enter__() @@ -310,7 +310,7 @@ def __iter__(self): yield i if isinstance(i, numbers.Number) else i.item() else: indep_context = poutine.indep(name=self.name, size=self.subsample_size) - with poutine.scale(scale=self.size / self.subsample_size): + with poutine.scale(scale=float(self.size / self.subsample_size)): for i in subsample: indep_context.next_context() with indep_context: diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 765b02ea5a..1eba28966d 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -120,6 +120,30 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +def test_masked_fill(): + + def f(y, mask): + return y.clone().masked_fill_(mask, 0.) + + x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + y = x / x.unsqueeze(-1) + mask = ~(y == y) + f = torch.jit.trace(f, (y, mask)) + + +def test_masked_fill_workaround(): + + def f(y, mask): + y = y.clone() + y[mask] = 0. 
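+        # index assignment is traceable here, whereas an in-place
+        # .masked_fill_() currently trips the tracer (cf. test_masked_fill)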
+ return y + + x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) + y = x / x.unsqueeze(-1) + mask = ~(y == y) + f = torch.jit.trace(f, (y, mask)) + + @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) def test_bernoulli_enumerate(shape, expand): From c47e2944ac318c33864a37e3ca6f5748c29ef2e0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 11 Sep 2018 18:46:54 -0700 Subject: [PATCH 056/157] Strengthen masked_fill test --- tests/infer/test_jit.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 1eba28966d..43efbe1cfa 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -120,6 +120,7 @@ def f(x, y): f(torch.zeros(2, requires_grad=True), torch.zeros(1, requires_grad=True)) +@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11555") def test_masked_fill(): def f(y, mask): @@ -134,14 +135,19 @@ def f(y, mask): def test_masked_fill_workaround(): def f(y, mask): + return y.clone().masked_fill_(mask, 0.) + + def g(y, mask): y = y.clone() - y[mask] = 0. + y[mask] = 0. # this is much slower than .masked_fill_() return y x = torch.tensor([-float('inf'), -1., 0., 1., float('inf')]) y = x / x.unsqueeze(-1) mask = ~(y == y) - f = torch.jit.trace(f, (y, mask)) + assert_equal(f(y, mask), g(y, mask)) + g = torch.jit.trace(g, (y, mask)) + assert_equal(f(y, mask), g(y, mask)) @pytest.mark.parametrize('expand', [False, True]) From c19830d3302b4981cc48f354da5a009b0b3b8654 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 12 Sep 2018 09:08:43 -0700 Subject: [PATCH 057/157] fix hmc enum test --- tests/infer/mcmc/test_hmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index d5fa3a8a08..b95a59474f 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -300,8 +300,8 @@ def test_gaussian_mixture_model(jit): @poutine.broadcast def gmm(data): + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K))) with pyro.iarange("num_clusters", K): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) From 0c233d33296a1be658e6d8c53e0996aeffb121b7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 12 Sep 2018 17:52:58 -0700 Subject: [PATCH 058/157] Fix failing jit tests --- pyro/infer/mcmc/hmc.py | 2 ++ pyro/ops/jit.py | 2 ++ tests/infer/test_jit.py | 11 +++++++++++ 3 files changed, 15 insertions(+) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 503a68d7a8..06c6e6531b 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -178,6 +178,8 @@ def compiled(*zi): message="torch.tensor might cause the trace to be incorrect") warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, message="Converting a tensor to a Python") + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="torch.tensor results are registered as constants in the trace") self._compiled_potential_fn = torch.jit.trace(compiled, vals) return self._compiled_potential_fn(*vals) diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index f2499f2f92..30e43eaf93 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -58,6 +58,8 @@ def compiled(*params_and_args): 
message="torch.tensor might cause the trace to be incorrect") warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, message="Converting a tensor to a Python") + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, + message="torch.tensor results are registered as constants in the trace") self.compiled[argc] = torch.jit.trace(compiled, params_and_args, check_trace=False) else: diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index a77e9912ed..bcc344aa44 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -151,6 +151,17 @@ def g(y, mask): assert_equal(f(y, mask), g(y, mask)) +@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/11614") +def test_scatter(): + + def make_one_hot(x, i): + return x.new_zeros(x.shape).scatter(-1, i, 1.0) + + x = torch.randn(5, 4, 3) + i = torch.randint(0, 3, torch.Size((5, 4, 1))) + torch.jit.trace(make_one_hot, (x, i)) + + @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) def test_bernoulli_enumerate(shape, expand): From dceaf9af0c608b3c30d2fd399ee1b661e1325b8a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 13 Sep 2018 09:46:49 -0700 Subject: [PATCH 059/157] Add test for .scatter_() workaround --- tests/infer/test_jit.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index bcc344aa44..77b5a09092 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -155,13 +155,30 @@ def g(y, mask): def test_scatter(): def make_one_hot(x, i): - return x.new_zeros(x.shape).scatter(-1, i, 1.0) + return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) x = torch.randn(5, 4, 3) - i = torch.randint(0, 3, torch.Size((5, 4, 1))) + i = torch.randint(0, 3, torch.Size((5, 4))) torch.jit.trace(make_one_hot, (x, i)) +def test_scatter_workaround(): + + def make_one_hot_expected(x, i): + return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) + + def make_one_hot_actual(x, i): + eye = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device) + return eye[i].clone() + + x = torch.randn(5, 4, 3) + i = torch.randint(0, 3, torch.Size((5, 4))) + torch.jit.trace(make_one_hot_actual, (x, i)) + expected = make_one_hot_expected(x, i) + actual = make_one_hot_actual(x, i) + assert_equal(actual, expected) + + @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) def test_bernoulli_enumerate(shape, expand): From e5cd034cc2def99544c88588ec1f18c71593ff2e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 13 Sep 2018 10:18:03 -0700 Subject: [PATCH 060/157] add expand for MaskedDistribution --- pyro/distributions/torch_distribution.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 05d3702800..5509e5fe2d 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -237,6 +237,15 @@ def __init__(self, base_dist, mask): self._mask = mask.byte() super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MaskedDistribution, _instance) + batch_shape = torch.Size(batch_shape) + new.base_dist = self.base_dist.expand(batch_shape) + new._mask = self._mask.expand(batch_shape) + super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, 
validate_args=False) + new._validate_args = self._validate_args + return new + @property def has_rsample(self): return self.base_dist.has_rsample From 6ce99251e92b6ab68eb42d7fac7aef437e548007 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 13 Sep 2018 16:07:12 -0700 Subject: [PATCH 061/157] remove binomial and half cauchy --- docs/source/distributions.rst | 15 ---- pyro/distributions/__init__.py | 4 - pyro/distributions/binomial.py | 135 ------------------------------ pyro/distributions/half_cauchy.py | 57 ------------- pyro/distributions/torch.py | 2 - 5 files changed, 213 deletions(-) delete mode 100644 pyro/distributions/binomial.py delete mode 100644 pyro/distributions/half_cauchy.py diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 8bcf8f2aa7..b38bd7f633 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -56,14 +56,6 @@ AVFMultivariateNormal :undoc-members: :show-inheritance: -Binomial --------- - -.. autoclass:: pyro.distributions.Binomial - :members: - :undoc-members: - :show-inheritance: - Delta ----- .. autoclass:: pyro.distributions.Delta @@ -85,13 +77,6 @@ GaussianScaleMixture :undoc-members: :show-inheritance: -HalfCauchy ----------- -.. autoclass:: pyro.distributions.HalfCauchy - :members: - :undoc-members: - :show-inheritance: - LowRankMultivariateNormal ------------------------- .. autoclass:: pyro.distributions.LowRankMultivariateNormal diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 3f39b98fc6..2d30709ba7 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,14 +2,12 @@ import pyro.distributions.torch_patch # noqa F403 from pyro.distributions.avf_mvn import AVFMultivariateNormal -from pyro.distributions.binomial import Binomial from pyro.distributions.delta import Delta from pyro.distributions.diag_normal_mixture_shared_cov import MixtureOfDiagNormalsSharedCovariance from pyro.distributions.diag_normal_mixture import MixtureOfDiagNormals from pyro.distributions.distribution import Distribution from pyro.distributions.empirical import Empirical from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture -from pyro.distributions.half_cauchy import HalfCauchy from pyro.distributions.iaf import InverseAutoregressiveFlow from pyro.distributions.mixture import MaskedMixture from pyro.distributions.omt_mvn import OMTMultivariateNormal @@ -29,12 +27,10 @@ "is_validation_enabled", "validation_enabled", "AVFMultivariateNormal", - "Binomial", "Delta", "Distribution", "Empirical", "GaussianScaleMixture", - "HalfCauchy", "InverseAutoregressiveFlow", "MaskedMixture", "MixtureOfDiagNormalsSharedCovariance", diff --git a/pyro/distributions/binomial.py b/pyro/distributions/binomial.py deleted file mode 100644 index 13d164a3e5..0000000000 --- a/pyro/distributions/binomial.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import absolute_import, division, print_function - -from numbers import Number - -import torch -from torch.distributions import constraints -from torch.distributions.utils import broadcast_all, lazy_property, logits_to_probs, probs_to_logits - -from pyro.distributions.torch_distribution import TorchDistributionMixin - - -class Binomial(torch.distributions.Distribution, TorchDistributionMixin): - r""" - Creates a Binomial distribution parameterized by `total_count` and - either `probs` or `logits` (but not both). `total_count` must be - broadcastable with `probs`/`logits`. 
- - This is adapted from :class:`torch.distributions.binomial.Binomial`, - with the important difference that `total_count` is not limited to - being a single `int`, but can be a `torch.Tensor`. - - Example:: - - >>> m = Binomial(100, torch.Tensor([0 , .2, .8, 1])) - >>> m.sample() # doctest: +SKIP - 0 - 22 - 71 - 100 - [torch.FloatTensor of size 4]] - - >>> m = Binomial(torch.Tensor([[5.], [10.]]), torch.Tensor([0.5, 0.8])) - >>> m.sample() # doctest: +SKIP - 4 5 - 7 6 - [torch.FloatTensor of size (2,2)] - - :param (Tensor) total_count: number of Bernoulli trials - :param (Tensor) probs: Event probabilities - :param (Tensor) logits: Event log-odds - """ - arg_constraints = {'total_count': constraints.nonnegative_integer, - 'probs': constraints.unit_interval} - has_enumerate_support = True - - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): - if (probs is None) == (logits is None): - raise ValueError("Either `probs` or `logits` must be specified, but not both.") - if probs is not None: - self.total_count, self.probs, = broadcast_all(total_count, probs) - is_scalar = isinstance(self.probs, Number) - else: - self.total_count, self.logits, = broadcast_all(total_count, logits) - is_scalar = isinstance(self.logits, Number) - - self._param = self.probs if probs is not None else self.logits - if is_scalar: - batch_shape = torch.Size() - else: - batch_shape = self._param.shape - super(Binomial, self).__init__(batch_shape, validate_args=validate_args) - - def _new(self, *args, **kwargs): - return self._param.new(*args, **kwargs) - - @constraints.dependent_property - def support(self): - return constraints.integer_interval(0, self.total_count) - - @property - def mean(self): - return self.total_count * self.probs - - @property - def variance(self): - return self.total_count * self.probs * (1 - self.probs) - - @lazy_property - def logits(self): - return probs_to_logits(self.probs, is_binary=True) - - @lazy_property - def probs(self): - return logits_to_probs(self.logits, is_binary=True) - - @property - def param_shape(self): - return self._param.shape - - def sample(self, sample_shape=torch.Size()): - with torch.no_grad(): - max_count = max(int(self.total_count.max()), 1) - shape = self._extended_shape(sample_shape) + (max_count,) - bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)) - if self.total_count.min() != max_count: - arange = torch.arange(max_count, out=self.total_count.new_empty(max_count)) - mask = arange >= self.total_count.unsqueeze(-1) - bernoullis.masked_fill_(mask, 0.) 
- return bernoullis.sum(dim=-1) - - def log_prob(self, value): - if self._validate_args: - self._validate_sample(value) - log_factorial_n = torch.lgamma(self.total_count + 1) - log_factorial_k = torch.lgamma(value + 1) - log_factorial_nmk = torch.lgamma(self.total_count - value + 1) - max_val = (-self.logits).clamp(min=0.0) - # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp())) - return (log_factorial_n - log_factorial_k - log_factorial_nmk + - value * self.logits + self.total_count * max_val - - self.total_count * torch.log1p((self.logits + 2 * max_val).exp())) - - def enumerate_support(self, expand=True): - total_count = int(self.total_count.max()) - if not self.total_count.min() == total_count: - raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.") - values = self._new(1 + total_count,) - torch.arange(1 + total_count, out=values) - values = values.view((-1,) + (1,) * len(self._batch_shape)) - if expand: - values = values.expand((-1,) + self._batch_shape) - return values - - def expand(self, batch_shape): - try: - return super(Binomial, self).expand(batch_shape) - except NotImplementedError: - validate_args = self.__dict__.get('_validate_args') - total_count = self.total_count.expand(batch_shape) - if 'probs' in self.__dict__: - probs = self.probs.expand(batch_shape) - return type(self)(total_count, probs=probs, validate_args=validate_args) - else: - logits = self.logits.expand(batch_shape) - return type(self)(total_count, logits=logits, validate_args=validate_args) diff --git a/pyro/distributions/half_cauchy.py b/pyro/distributions/half_cauchy.py deleted file mode 100644 index a4b8f8f0eb..0000000000 --- a/pyro/distributions/half_cauchy.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import absolute_import, division, print_function - -import math - -from torch.distributions import constraints -from torch.distributions.transforms import AbsTransform, AffineTransform -from torch.distributions.utils import broadcast_all - -from pyro.distributions.torch import Cauchy, TransformedDistribution - - -class HalfCauchy(TransformedDistribution): - r""" - Half-Cauchy distribution. - - This is a continuous distribution with lower-bounded domain (`x > loc`). - See also the :class:`~pyro.distributions.torch.Cauchy` distribution. - - :param torch.Tensor loc: lower bound of the distribution. - :param torch.Tensor scale: half width at half maximum. 
- """ - arg_constraints = Cauchy.arg_constraints - support = Cauchy.support - - def __init__(self, loc=0, scale=1): - loc, scale = broadcast_all(loc, scale) - base_dist = Cauchy(0, scale) - transforms = [AbsTransform(), AffineTransform(loc, 1)] - super(HalfCauchy, self).__init__(base_dist, transforms) - - @property - def loc(self): - return self.transforms[1].loc - - @property - def scale(self): - return self.base_dist.scale - - @constraints.dependent_property - def support(self): - return constraints.greater_than(self.loc) - - def log_prob(self, value): - log_prob = self.base_dist.log_prob(value - self.loc) + math.log(2) - log_prob[value < self.loc] = -float('inf') - return log_prob - - def entropy(self): - return self.base_dist.entropy() - math.log(2) - - def expand(self, batch_shape): - try: - return super(HalfCauchy, self).expand(batch_shape) - except NotImplementedError: - loc = self.loc.expand(batch_shape) - scale = self.scale.expand(batch_shape) - return type(self)(loc, scale) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index d996cf010d..ef6e903f40 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -13,8 +13,6 @@ class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributi # Programmatically load all distributions from PyTorch. __all__ = [] for _name, _Dist in torch.distributions.__dict__.items(): - if _name == 'Binomial': - continue if not isinstance(_Dist, type): continue if not issubclass(_Dist, torch.distributions.Distribution): From 68c116889e58ba0b72aef519d7ce7c7522721d77 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 14 Sep 2018 09:54:19 -0700 Subject: [PATCH 062/157] reinstate Independent constraint --- pyro/distributions/torch.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index ef6e903f40..c2c8b593f7 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -10,6 +10,20 @@ class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributi support = IndependentConstraint(constraints.real, 1) # TODO move upstream +class Independent(torch.distributions.Independent, TorchDistributionMixin): + @constraints.dependent_property + def support(self): + return IndependentConstraint(self.base_dist.support, self.reinterpreted_batch_ndims) + + @property + def _validate_args(self): + return self.base_dist._validate_args + + @_validate_args.setter + def _validate_args(self, value): + self.base_dist._validate_args = value + + # Programmatically load all distributions from PyTorch. 
__all__ = [] for _name, _Dist in torch.distributions.__dict__.items(): From 657fc56e366200c908fe71fb66a2562c34fa1def Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 14 Sep 2018 12:20:26 -0700 Subject: [PATCH 063/157] add expand methods to more distributions --- pyro/distributions/delta.py | 12 +++-- pyro/distributions/diag_normal_mixture.py | 17 ++++++- .../diag_normal_mixture_shared_cov.py | 15 ++++++ .../distributions/relaxed_straight_through.py | 8 ---- pyro/distributions/testing/rejection_gamma.py | 47 +++++++++++++++++++ 5 files changed, 85 insertions(+), 14 deletions(-) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index e47cf1984c..0960725958 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -42,12 +42,14 @@ def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): self.log_density = log_density super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args) - def expand(self, batch_shape): - validate_args = self.__dict__.get('_validate_args') + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Delta, _instance) batch_shape = torch.Size(batch_shape) - v = self.v.expand(batch_shape + self.event_shape) - log_density = self.log_density.expand(batch_shape) - return Delta(v, log_density, self.event_dim, validate_args=validate_args) + new.v = self.v.expand(batch_shape + self.event_shape) + new.log_density = self.log_density.expand(batch_shape) + super(Delta, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new def rsample(self, sample_shape=torch.Size()): shape = sample_shape + self.v.shape diff --git a/pyro/distributions/diag_normal_mixture.py b/pyro/distributions/diag_normal_mixture.py index 0ce30a1e8a..521905a679 100644 --- a/pyro/distributions/diag_normal_mixture.py +++ b/pyro/distributions/diag_normal_mixture.py @@ -67,7 +67,22 @@ def __init__(self, locs, coord_scale, component_logits): self.dim = locs.size(-1) self.categorical = Categorical(logits=component_logits) self.probs = self.categorical.probs - super(MixtureOfDiagNormals, self).__init__(batch_shape=batch_shape, event_shape=(self.dim,)) + super(MixtureOfDiagNormals, self).__init__(batch_shape=torch.Size(batch_shape), + event_shape=torch.Size((self.dim,))) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MixtureOfDiagNormals, _instance) + new.batch_mode = True + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) + new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[-2:]) + new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.categorical = self.categorical.expand(batch_shape) + new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) + super(MixtureOfDiagNormals, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new def log_prob(self, value): epsilon = (value.unsqueeze(-2) - self.locs) / self.coord_scale # L B K D diff --git a/pyro/distributions/diag_normal_mixture_shared_cov.py b/pyro/distributions/diag_normal_mixture_shared_cov.py index 4361d5790c..fd1b44a8b9 100644 --- a/pyro/distributions/diag_normal_mixture_shared_cov.py +++ b/pyro/distributions/diag_normal_mixture_shared_cov.py @@ -68,6 +68,21 @@ def __init__(self, locs, coord_scale, component_logits): self.probs = 
self.categorical.probs super(MixtureOfDiagNormalsSharedCovariance, self).__init__(batch_shape=batch_shape, event_shape=(self.dim,)) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MixtureOfDiagNormalsSharedCovariance, _instance) + new.batch_mode = True + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.locs = self.locs.expand(batch_shape + self.locs.shape[-2:]) + coord_scale_shape = -1 if self.batch_mode else -2 + new.coord_scale = self.coord_scale.expand(batch_shape + self.coord_scale.shape[coord_scale_shape:]) + new.component_logits = self.component_logits.expand(batch_shape + self.component_logits.shape[-1:]) + new.categorical = self.categorical.expand(batch_shape) + new.probs = self.probs.expand(batch_shape + self.probs.shape[-1:]) + super(MixtureOfDiagNormalsSharedCovariance, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + def log_prob(self, value): # TODO: use torch.logsumexp once it's in PyTorch release coord_scale = self.coord_scale.unsqueeze(-2) if self.batch_mode else self.coord_scale diff --git a/pyro/distributions/relaxed_straight_through.py b/pyro/distributions/relaxed_straight_through.py index f0f7739b06..5d57cdf1f5 100644 --- a/pyro/distributions/relaxed_straight_through.py +++ b/pyro/distributions/relaxed_straight_through.py @@ -29,10 +29,6 @@ class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical): [2] Categorical Reparameterization with Gumbel-Softmax, Eric Jang, Shixiang Gu, Ben Poole """ - def __init__(self, temperature, probs=None, logits=None, validate_args=None): - super(RelaxedOneHotCategoricalStraightThrough, self).__init__(temperature=temperature, probs=probs, - logits=logits, validate_args=validate_args) - def rsample(self, sample_shape=torch.Size()): soft_sample = super(RelaxedOneHotCategoricalStraightThrough, self).rsample(sample_shape) soft_sample = clamp_probs(soft_sample) @@ -81,10 +77,6 @@ class RelaxedBernoulliStraightThrough(RelaxedBernoulli): [2] Categorical Reparameterization with Gumbel-Softmax, Eric Jang, Shixiang Gu, Ben Poole """ - def __init__(self, temperature, probs=None, logits=None, validate_args=None): - super(RelaxedBernoulliStraightThrough, self).__init__(temperature=temperature, probs=probs, - logits=logits, validate_args=validate_args) - def rsample(self, sample_shape=torch.Size()): soft_sample = super(RelaxedBernoulliStraightThrough, self).rsample(sample_shape) soft_sample = clamp_probs(soft_sample) diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index f815e0e323..dde109a94e 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -27,6 +27,20 @@ def __init__(self, concentration): log_scale = self.propose_log_prob(x) + self.log_prob_accept(x) - self.log_prob(x) super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept, log_scale) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RejectionStandardGamma, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new._standard_gamma = self._standard_gamma.expand(batch_shape) + new._d = self._d.expand(batch_shape) + new._c = self._c.expand(batch_shape) + # Compute log scale using Gamma.log_prob(). + x = new._d.detach() # just an arbitrary x. 
+ log_scale = new.propose_log_prob(x) + new.log_prob_accept(x) - new.log_prob(x) + super(RejectionStandardGamma, new).__init__(new.propose, new.log_prob_accept, log_scale) + new._validate_args = self._validate_args + return new + def propose(self, sample_shape=torch.Size()): # Marsaglia & Tsang's x == Naesseth's epsilon x = self.concentration.new_empty(sample_shape + self.concentration.shape).normal_() @@ -65,6 +79,13 @@ def __init__(self, concentration, rate, validate_args=None): self._standard_gamma = RejectionStandardGamma(concentration) self.rate = rate + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RejectionGamma, _instance) + new = super(RejectionGamma, self).expand(batch_shape, new) + new._standard_gamma = self._standard_gamma.expand(batch_shape) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): return self._standard_gamma.rsample(sample_shape) / self.rate @@ -94,6 +115,16 @@ def __init__(self, concentration, rate, boost=1, validate_args=None): self._rejection_gamma = RejectionGamma(concentration + boost, rate) self._unboost_x_cache = None, None + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedGamma, _instance) + new = super(ShapeAugmentedGamma, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new._boost = self._boost + new._rejection_gamma = self._rejection_gamma.expand(batch_shape) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): x = self._rejection_gamma.rsample(sample_shape) boosted_x = x.clone() @@ -124,6 +155,14 @@ def __init__(self, concentration, boost=1, validate_args=None): super(ShapeAugmentedDirichlet, self).__init__(concentration, validate_args=validate_args) self._gamma = ShapeAugmentedGamma(concentration, torch.ones_like(concentration), boost) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedDirichlet, _instance) + new = super(ShapeAugmentedDirichlet, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): gammas = self._gamma.rsample(sample_shape) return gammas / gammas.sum(-1, True) @@ -142,6 +181,14 @@ def __init__(self, concentration1, concentration0, boost=1, validate_args=None): alpha_beta = torch.stack([concentration1, concentration0], -1) self._gamma = ShapeAugmentedGamma(alpha_beta, torch.ones_like(alpha_beta), boost) + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ShapeAugmentedBeta, _instance) + new = super(ShapeAugmentedBeta, self).expand(batch_shape, new) + batch_shape = torch.Size(batch_shape) + new._gamma = self._gamma.expand(batch_shape + self._gamma.concentration.shape[-1:]) + new._validate_args = self._validate_args + return new + def rsample(self, sample_shape=torch.Size()): gammas = self._gamma.rsample(sample_shape) probs = gammas / gammas.sum(-1, True) From d44f90b7925fd7a745b6786bbf8fd891773a69fb Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 14 Sep 2018 12:34:42 -0700 Subject: [PATCH 064/157] Fix CUDA tests in test_eig.py --- tests/contrib/oed/test_eig.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/contrib/oed/test_eig.py 
b/tests/contrib/oed/test_eig.py index 8d7f119b57..dd4384c785 100644 --- a/tests/contrib/oed/test_eig.py +++ b/tests/contrib/oed/test_eig.py @@ -76,6 +76,7 @@ def bernoulli_ground_truth(model, design, observation_labels, target_labels, eig def h(p): + p = p.cpu().numpy() return -(sc.xlogy(p, p) + sc.xlog1py(1 - p, -p)) @@ -228,11 +229,11 @@ def h(p): def test_eig_lm(model, design, observation_labels, target_labels, estimator, args, eig, allow_error): pyro.set_rng_seed(42) pyro.clear_param_store() - y = estimator(model, design, observation_labels, target_labels, *args).cpu() + y = estimator(model, design, observation_labels, target_labels, *args) if model is bernoulli_model: - y_true = bernoulli_ground_truth(model, design, observation_labels, target_labels, eig=eig).cpu() + y_true = bernoulli_ground_truth(model, design, observation_labels, target_labels, eig=eig) else: - y_true = linear_model_ground_truth(model, design, observation_labels, target_labels, eig=eig).cpu() + y_true = linear_model_ground_truth(model, design, observation_labels, target_labels, eig=eig) print() print(estimator.__name__) print(y) From c2c4b724b913425761f4bccd778236b954511294 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 14 Sep 2018 12:41:22 -0700 Subject: [PATCH 065/157] remove standard gamma patch --- pyro/distributions/torch_patch.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 56727005b9..d144b384c9 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -22,14 +22,6 @@ def decorator(new_fn): return decorator -@_patch('torch._standard_gamma') -def _torch_standard_gamma(concentration): - unpatched_fn = _torch_standard_gamma._pyro_unpatched - if concentration.is_cuda: - return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) - return unpatched_fn(concentration) - - @_patch('torch._dirichlet_grad') def _torch_dirichlet_grad(x, concentration, total): unpatched_fn = _torch_dirichlet_grad._pyro_unpatched From 5532c3feda7a7238b8ecfcdd3b5fde6b70402559 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 19 Sep 2018 13:11:26 -0400 Subject: [PATCH 066/157] Work-around to allow JIT compiler to infer batch size in iarange (#1392) --- examples/air/main.py | 2 +- pyro/distributions/util.py | 13 ++++----- pyro/infer/elbo.py | 6 ++++- pyro/infer/enum.py | 7 +++-- pyro/infer/mcmc/hmc.py | 19 +++++-------- pyro/infer/mcmc/nuts.py | 2 ++ pyro/infer/trace_elbo.py | 3 ++- pyro/infer/traceenum_elbo.py | 3 ++- pyro/infer/tracegraph_elbo.py | 2 +- pyro/ops/jit.py | 22 +++++----------- pyro/ops/sumproduct.py | 6 +++-- pyro/poutine/broadcast_messenger.py | 6 +++-- pyro/poutine/indep_messenger.py | 6 +++-- pyro/primitives.py | 41 +++++++++++++---------------- pyro/util.py | 34 ++++++++++++++++++++++++ tests/infer/mcmc/test_hmc.py | 16 ++++++----- tests/infer/mcmc/test_nuts.py | 17 +++++++----- tests/infer/test_enum.py | 4 +-- tests/infer/test_jit.py | 29 +++++++++++++++----- 19 files changed, 142 insertions(+), 96 deletions(-) diff --git a/examples/air/main.py b/examples/air/main.py index aae26828de..7a76d84e9b 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -287,7 +287,7 @@ def per_param_optim_args(module_name, param_name): help='number of steps between parameter saves') parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') - parser.add_argument('--jit', action='store_true', default=False, + parser.add_argument('--jit', action='store_true', 
default=True,
                         help='use PyTorch jit')
     parser.add_argument('-t', '--model-steps', type=int, default=3,
                         help='number of time steps')
diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py
index 39a09bdfb0..fa8576bc60 100644
--- a/pyro/distributions/util.py
+++ b/pyro/distributions/util.py
@@ -1,7 +1,6 @@
 from __future__ import absolute_import, division, print_function
 
 import numbers
-import warnings
 from contextlib import contextmanager
 
 import torch
@@ -9,6 +8,7 @@
 from torch import logsumexp
 from torch.distributions.utils import broadcast_all
 
+
 _VALIDATION_ENABLED = False
 log_sum_exp = logsumexp  # DEPRECATED
@@ -170,18 +170,15 @@ def scale_and_mask(tensor, scale=1.0, mask=None):
     :param mask: an optional masking tensor
     :type mask: torch.ByteTensor or None
     """
-    if is_identically_zero(tensor):
-        return tensor
-    if mask is None:
-        if is_identically_one(scale):
+    if not torch._C._get_tracing_state():
+        if is_identically_zero(tensor) or (mask is None and is_identically_one(scale)):
             return tensor
+    if mask is None:
         return tensor * scale
     tensor, mask = broadcast_all(tensor, mask)
     tensor = tensor * scale  # triggers a copy, avoiding in-place op errors
     if torch._C._get_tracing_state():
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
-            tensor[~mask] = 0.
+        tensor[~mask] = 0.
     else:
         tensor.masked_fill_(~mask, 0.)
     return tensor
diff --git a/pyro/infer/elbo.py b/pyro/infer/elbo.py
index 5717a66137..cdb15cb8fc 100644
--- a/pyro/infer/elbo.py
+++ b/pyro/infer/elbo.py
@@ -36,6 +36,8 @@ class ELBO(object):
         misuse of enumeration, i.e. that
         :class:`pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there
        are enumerated sample sites.
+    :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT
+        tracer, when applicable. All :class:`torch.jit.TracerWarning` will be ignored.
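+
+        For illustration only, a minimal sketch of this flag in use
+        (``model``, ``guide`` and the optimizer settings are stand-ins,
+        not part of this change)::
+
+            elbo = JitTrace_ELBO(num_particles=1, ignore_jit_warnings=True)
+            svi = SVI(model, guide, Adam({"lr": 0.001}), elbo)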
References @@ -50,7 +52,8 @@ def __init__(self, num_particles=1, max_iarange_nesting=float('inf'), vectorize_particles=False, - strict_enumeration_warning=True): + strict_enumeration_warning=True, + ignore_jit_warnings=False): self.num_particles = num_particles self.max_iarange_nesting = max_iarange_nesting self.vectorize_particles = vectorize_particles @@ -61,6 +64,7 @@ def __init__(self, "a finite value for `max_iarange_nesting` arg.") self.max_iarange_nesting += 1 self.strict_enumeration_warning = strict_enumeration_warning + self.ignore_jit_warnings = ignore_jit_warnings def _vectorized_num_particles(self, fn): """ diff --git a/pyro/infer/enum.py b/pyro/infer/enum.py index 89577f23b1..c220450529 100644 --- a/pyro/infer/enum.py +++ b/pyro/infer/enum.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function import numbers -import warnings from six.moves.queue import LifoQueue @@ -9,7 +8,7 @@ from pyro.infer.util import is_validation_enabled from pyro.poutine import Trace from pyro.poutine.util import prune_subsample_sites -from pyro.util import check_model_guide_match, check_site_shape +from pyro.util import check_model_guide_match, check_site_shape, ignore_jit_warnings def iter_discrete_escape(trace, msg): @@ -22,8 +21,8 @@ def iter_discrete_escape(trace, msg): def iter_discrete_extend(trace, site, **ignored): values = site["fn"].enumerate_support(expand=site["infer"].get("expand", False)) enum_total = values.shape[0] - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="Iterating over a tensor") + with ignore_jit_warnings(["Converting a tensor to a Python index", + ("Iterating over a tensor", RuntimeWarning)]): values = iter(values) for i, value in enumerate(values): extended_site = site.copy() diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 06c6e6531b..994a378ece 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function import math -import warnings from collections import OrderedDict import torch @@ -16,7 +15,7 @@ from pyro.ops.dual_averaging import DualAveraging from pyro.ops.integrator import single_step_velocity_verlet, velocity_verlet from pyro.primitives import _Subsample -from pyro.util import torch_isinf, torch_isnan, optional +from pyro.util import torch_isinf, torch_isnan, optional, ignore_jit_warnings class HMC(TraceKernel): @@ -54,6 +53,8 @@ class HMC(TraceKernel): :param bool jit_compile: Optional parameter denoting whether to use the PyTorch JIT to trace the log density computation, and use this optimized executable trace in the integrator. + :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT + tracer when ``jit_compile=True``. Default is False. :param bool experimental_use_einsum: Whether to use an einsum operation to evaluate log pdf for the model trace. No-op unless the trace has discrete sample sites. This flag is experimental and will most likely @@ -88,6 +89,7 @@ def __init__(self, transforms=None, max_iarange_nesting=float("inf"), jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): # Wrap model in `poutine.enum` to enumerate over discrete latent sites. # No-op if model does not have any discrete latents. 
@@ -105,6 +107,7 @@ def __init__(self, self.num_steps = max(1, int(self.trajectory_length / self.step_size)) self.adapt_step_size = adapt_step_size self._jit_compile = jit_compile + self._ignore_jit_warnings = ignore_jit_warnings self.use_einsum = experimental_use_einsum self._target_accept_prob = 0.8 # from Stan @@ -171,16 +174,8 @@ def compiled(*zi): potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum() return potential_energy - with pyro.validation_enabled(False), warnings.catch_warnings(): - # Ignore jit warnings about promoting Python numbers to tensors, - # assuming all numbers are constant literals. - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="torch.tensor might cause the trace to be incorrect") - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="Converting a tensor to a Python") - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="torch.tensor results are registered as constants in the trace") - self._compiled_potential_fn = torch.jit.trace(compiled, vals) + with pyro.validation_enabled(False), optional(ignore_jit_warnings(), self._ignore_jit_warnings): + self._compiled_potential_fn = torch.jit.trace(compiled, vals, check_trace=False) return self._compiled_potential_fn(*vals) def _energy(self, z, r): diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index 7ceca7a843..75a1ded9ee 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -91,6 +91,7 @@ def __init__(self, transforms=None, max_iarange_nesting=float("inf"), jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): super(NUTS, self).__init__(model, step_size, @@ -98,6 +99,7 @@ def __init__(self, transforms=transforms, max_iarange_nesting=max_iarange_nesting, jit_compile=jit_compile, + ignore_jit_warnings=ignore_jit_warnings, experimental_use_einsum=experimental_use_einsum) self._max_tree_depth = 10 # from Stan diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py index b1e4147029..a0c05bd9c0 100644 --- a/pyro/infer/trace_elbo.py +++ b/pyro/infer/trace_elbo.py @@ -155,12 +155,13 @@ class JitTrace_ELBO(Trace_ELBO): .. warning:: Experimental. Interface subject to change. """ + def loss_and_grads(self, model, guide, *args, **kwargs): if getattr(self, '_loss_and_surrogate_loss', None) is None: # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py index 3715e3f340..510117fb4a 100644 --- a/pyro/infer/traceenum_elbo.py +++ b/pyro/infer/traceenum_elbo.py @@ -410,12 +410,13 @@ class JitTraceEnum_ELBO(TraceEnum_ELBO): .. warning:: Experimental. Interface subject to change. 
""" + def loss_and_grads(self, model, guide, *args, **kwargs): if getattr(self, '_differentiable_loss', None) is None: weakself = weakref.ref(self) - @pyro.ops.jit.trace + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def differentiable_loss(*args): self = weakself() elbo = 0.0 diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index 47eb8c42aa..a1f1f5317a 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -275,7 +275,7 @@ def loss_and_grads(self, model, guide, *args, **kwargs): # build a closure for loss_and_surrogate_loss weakself = weakref.ref(self) - @pyro.ops.jit.trace + @pyro.ops.jit.trace(ignore_warnings=self.ignore_jit_warnings) def loss_and_surrogate_loss(*args): self = weakself() loss = 0.0 diff --git a/pyro/ops/jit.py b/pyro/ops/jit.py index 30e43eaf93..3a91083cea 100644 --- a/pyro/ops/jit.py +++ b/pyro/ops/jit.py @@ -1,12 +1,12 @@ from __future__ import absolute_import, division, print_function -import warnings import weakref import torch import pyro import pyro.poutine as poutine +from pyro.util import ignore_jit_warnings, optional class CompiledFunction(object): @@ -19,9 +19,10 @@ class CompiledFunction(object): The actual PyTorch compilation artifact is stored in :attr:`compiled`. Call diagnostic methods on this attribute. """ - def __init__(self, fn): + def __init__(self, fn, ignore_warnings=False): self.fn = fn self.compiled = {} # len(args) -> callable + self.ignore_warnings = ignore_warnings self._param_names = None def __call__(self, *args, **kwargs): @@ -51,16 +52,7 @@ def compiled(*params_and_args): constrained_params[name] = constrained_param return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs) - with pyro.validation_enabled(False), warnings.catch_warnings(): - # Ignore jit warnings about promoting Python numbers to tensors, - # assuming all numbers are constant literals. - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="torch.tensor might cause the trace to be incorrect") - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="Converting a tensor to a Python") - warnings.filterwarnings("ignore", category=torch.jit.TracerWarning, - message="torch.tensor results are registered as constants in the trace") - + with pyro.validation_enabled(False), optional(ignore_jit_warnings(), self.ignore_warnings): self.compiled[argc] = torch.jit.trace(compiled, params_and_args, check_trace=False) else: unconstrained_params = [pyro.param(name).unconstrained() @@ -79,7 +71,7 @@ def compiled(*params_and_args): return ret -def trace(fn=None): +def trace(fn=None, ignore_warnings=False): """ Lazy replacement for :func:`torch.jit.trace` that works with Pyro functions that call :func:`pyro.param`. 
@@ -100,5 +92,5 @@ def model_log_prob_fn(x, y): return tr.log_prob_sum() """ if fn is None: - return lambda fn: trace(fn) - return CompiledFunction(fn) + return lambda fn: trace(fn, ignore_warnings=ignore_warnings) + return CompiledFunction(fn, ignore_warnings=ignore_warnings) diff --git a/pyro/ops/sumproduct.py b/pyro/ops/sumproduct.py index 38ad671243..1bb907f629 100644 --- a/pyro/ops/sumproduct.py +++ b/pyro/ops/sumproduct.py @@ -9,6 +9,7 @@ from pyro.distributions.util import broadcast_shape from pyro.ops.einsum import contract +from pyro.util import ignore_jit_warnings def zip_align_right(xs, ys): @@ -63,8 +64,9 @@ def sumproduct(factors, target_shape=(), optimize=True, device=None): for t in factors: (numbers if isinstance(t, Number) else tensors).append(t) if not tensors: - return torch.tensor(float(reduce(operator.mul, numbers, 1.)), - device=device).expand(target_shape) + with ignore_jit_warnings(["torch.tensor results are registered as constants"]): + return torch.tensor(float(reduce(operator.mul, numbers, 1.)), + device=device).expand(target_shape) if numbers: number_part = reduce(operator.mul, numbers, 1.) tensor_part = sumproduct(tensors, target_shape, optimize=optimize) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 03f2748d5e..20ba466f23 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +from pyro.util import ignore_jit_warnings from .messenger import Messenger @@ -11,6 +12,7 @@ class BroadcastMessenger(Messenger): broadcastable with the size of the :class:`~pyro.iarange` contexts installed in the `cond_indep_stack`. """ + @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) def _pyro_sample(self, msg): """ :param msg: current message at a trace site. 
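
As a reminder of the behavior being decorated here, a small sketch (the
site name and shapes are illustrative):

    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine

    @poutine.broadcast
    def model():
        with pyro.iarange("data", 10):
            # batch_shape () is expanded to (10,) to match the iarange
            x = pyro.sample("x", dist.Normal(0., 1.))
            assert x.shape == (10,)

    model()
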
@@ -21,14 +23,14 @@ def _pyro_sample(self, msg): dist = msg["fn"] actual_batch_shape = getattr(dist, "batch_shape", None) if actual_batch_shape is not None: - target_batch_shape = [None if size == 1 else int(size) # int() is required by jit + target_batch_shape = [None if size == 1 else size for size in actual_batch_shape] for f in msg["cond_indep_stack"]: if f.dim is None or f.size == -1: continue assert f.dim < 0 target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape - if target_batch_shape[f.dim] not in (None, f.size): + if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size: raise ValueError("Shape mismatch inside iarange('{}') at site {} dim {}, {} vs {}".format( f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim])) target_batch_shape[f.dim] = f.size diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 691bbf4b12..306a6d5e64 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -4,6 +4,7 @@ import torch +from pyro.util import ignore_jit_warnings from .messenger import Messenger @@ -13,8 +14,9 @@ def vectorized(self): return self.dim is not None def _key(self): - size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size - return self.name, self.dim, size, self.counter + with ignore_jit_warnings(["Converting a tensor to a Python number"]): + size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size + return self.name, self.dim, size, self.counter def __eq__(self, other): return type(self) == type(other) and self._key() == other._key() diff --git a/pyro/primitives.py b/pyro/primitives.py index 1b5444ce40..45330ef244 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function import copy -import numbers import warnings from collections import OrderedDict from contextlib import contextmanager @@ -15,7 +14,7 @@ from pyro.distributions.distribution import Distribution from pyro.params import param_with_module_name from pyro.poutine.runtime import _DIM_ALLOCATOR, _MODULE_NAMESPACE_DIVIDER, _PYRO_PARAM_STORE, am_i_wrapped, apply_stack -from pyro.util import deep_getattr, set_rng_seed # noqa: F401 +from pyro.util import deep_getattr, ignore_jit_warnings, torch_float, jit_compatible_arange # noqa: F401 def get_param_store(): @@ -115,19 +114,16 @@ def sample(self, sample_shape=torch.Size()): if sample_shape: raise NotImplementedError subsample_size = self.subsample_size - if subsample_size is None or subsample_size > self.size: - subsample_size = self.size - if subsample_size == self.size: - result = torch.LongTensor(list(range(self.size))) + if subsample_size is None or subsample_size >= self.size: + result = jit_compatible_arange(self.size) else: - # torch.randperm does not have a CUDA implementation - result = torch.randperm(self.size, device=torch.device('cpu'))[:self.subsample_size] + result = torch.multinomial(torch.ones(self.size), self.subsample_size, replacement=False) return result.cuda() if self.use_cuda else result def log_prob(self, x): # This is zero so that iarange can provide an unbiased estimate of # the non-subsampled log_prob. - result = torch.zeros(1) + result = torch.tensor(0.) 
return result.cuda() if self.use_cuda else result @@ -143,12 +139,13 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No elif subsample is None: subsample = sample(name, _Subsample(size, subsample_size, use_cuda)) - if subsample_size is None: - subsample_size = subsample.shape[0] - elif subsample is not None and subsample_size != len(subsample): - raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, len(subsample)) + - " Did you accidentally use different subsample_size in the model and guide?") + with ignore_jit_warnings(): + if subsample_size is None: + subsample_size = subsample.shape[0] if torch._C._get_tracing_state() else len(subsample) + elif subsample is not None and subsample_size != len(subsample): + raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( + subsample_size, len(subsample)) + + " Did you accidentally use different subsample_size in the model and guide?") return size, subsample_size, subsample @@ -248,7 +245,7 @@ def __enter__(self): self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) if self._wrapped: try: - self._scale_messenger = poutine.scale(scale=float(self.size / self.subsample_size)) + self._scale_messenger = poutine.scale(scale=torch_float(self.size) / self.subsample_size) self._indep_messenger = poutine.indep(name=self.name, size=self.subsample_size, dim=self.dim) self._scale_messenger.__enter__() self._indep_messenger.__enter__() @@ -302,21 +299,19 @@ def __init__(self, name, size, subsample_size=None, subsample=None, use_cuda=Non self.size, self.subsample_size, self.subsample = _subsample(name, size, subsample_size, subsample, use_cuda) def __iter__(self): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="Iterating over a tensor") + with ignore_jit_warnings(["Converting a tensor to a Python index", + ("Iterating over a tensor", RuntimeWarning)]): subsample = iter(self.subsample) if not am_i_wrapped(): for i in subsample: - yield i if isinstance(i, numbers.Number) else i.item() + yield i else: indep_context = poutine.indep(name=self.name, size=self.subsample_size) - with poutine.scale(scale=float(self.size / self.subsample_size)): + with poutine.scale(scale=torch_float(self.size) / self.subsample_size): for i in subsample: indep_context.next_context() with indep_context: - # convert to python numeric type as functions like torch.ones(*args) - # do not work with dim 0 torch.Tensor instances. - yield i if isinstance(i, numbers.Number) else i.item() + yield i # XXX this should have the same call signature as torch.Tensor constructors diff --git a/pyro/util.py b/pyro/util.py index 6b49e0d2b5..6c8a94ff5a 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -324,6 +324,30 @@ def check_if_enumerated(guide_trace): 'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'])) +@contextmanager +def ignore_jit_warnings(filter=None): + """ + Ignore JIT tracer warnings with messages that match `filter`. If + `filter` is not specified all tracer warnings are ignored. + + :param filter: A list containing either warning message (str), + or tuple consisting of (warning message (str), Warning class). 
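+
+    Example (illustrative; the filter entries below are samples, not an
+    exhaustive list)::
+
+        with ignore_jit_warnings(["Converting a tensor to a Python index",
+                                  ("Iterating over a tensor", RuntimeWarning)]):
+            for i in torch.arange(3):
+                pass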
+ """ + with warnings.catch_warnings(): + if filter is None: + warnings.filterwarnings("ignore", + category=torch.jit.TracerWarning) + else: + for msg in filter: + category = torch.jit.TracerWarning + if isinstance(msg, tuple): + msg, category = msg + warnings.filterwarnings("ignore", + category=category, + message=msg) + yield + + @contextmanager def optional(context_manager, condition): """ @@ -342,3 +366,13 @@ def deep_getattr(obj, name): Throws an AttributeError if bad attribute """ return functools.reduce(getattr, name.split("."), obj) + + +# work around https://github.com/pytorch/pytorch/issues/11829 +def jit_compatible_arange(end, dtype=None, device=None): + dtype = torch.long if dtype is None else dtype + return torch.cumsum(torch.ones(end, dtype=dtype, device=device), dim=0) - 1 + + +def torch_float(x): + return x.float() if isinstance(x, torch.Tensor) else float(x) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index b95a59474f..1df03a8499 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -181,7 +181,7 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, step_size=0.0855, num_steps=4, jit_compile=jit) + hmc_kernel = HMC(model, step_size=0.0855, num_steps=4, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) beta_posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, beta_posterior.mean).item(), 0.0, prec=0.1) @@ -198,7 +198,7 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, step_size=0.02, num_steps=3, jit_compile=jit) + hmc_kernel = HMC(model, step_size=0.02, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) @@ -215,7 +215,7 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) @@ -231,7 +231,7 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit) + hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) @@ -250,7 +250,8 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, jit_compile=jit) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') 
assert_equal(rmse(posterior.mean, true_coefs).item(), 0.0, prec=0.1) @@ -269,7 +270,7 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2, - jit_compile=jit) + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.05) @@ -338,7 +339,8 @@ def model(data): z = dist.Bernoulli(0.65 * y + 0.1).sample() data = dist.Normal(2. * z, 1.0).sample() hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1, - jit_compile=jit, experimental_use_einsum=use_einsum) + jit_compile=jit, ignore_jit_warnings=True, + experimental_use_einsum=use_einsum) mcmc_run = MCMC(hmc_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean assert_equal(posterior, y_prob, prec=0.05) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 35b0874818..fd3896e8bc 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -12,6 +12,7 @@ from pyro.infer.mcmc.mcmc import MCMC from pyro.infer.mcmc.nuts import NUTS import pyro.poutine as poutine +from pyro.util import ignore_jit_warnings from tests.common import assert_equal from .test_hmc import TEST_CASES, TEST_IDS, T, rmse @@ -272,14 +273,16 @@ def model(data): emission_loc = pyro.sample("emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))) emission_scale = pyro.sample("emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))) x = None - for t, y in enumerate(data): - x = pyro.sample("x_{}".format(t), dist.Categorical(initialize if x is None else transition[x])) - pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y) - # check shape - effective_dim = sum(1 for size in x.shape if size > 1) - assert effective_dim == 1 + with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): + for t, y in enumerate(data): + x = pyro.sample("x_{}".format(t), dist.Categorical(initialize if x is None else transition[x])) + pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y) + # check shape + effective_dim = sum(1 for size in x.shape if size > 1) + assert effective_dim == 1 data = torch.ones(num_steps) - nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=0, jit_compile=jit, + nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=0, + jit_compile=jit, ignore_jit_warnings=True, experimental_use_einsum=use_einsum) MCMC(nuts_kernel, num_samples=5, warmup_steps=5).run(data) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 7b1a5ba3b3..637abfe985 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -151,7 +151,7 @@ def gmm_model(data, verbose=False): z = pyro.sample("z_{}".format(i), dist.Bernoulli(p)) z = z.long() if verbose: - logger.debug("M{} z_{} = {}".format(" " * i, i, z.cpu().numpy())) + logger.debug("M{} z_{} = {}".format(" " * int(i), int(i), z.cpu().numpy())) pyro.sample("x_{}".format(i), dist.Normal(mus[z], scale), obs=data[i]) @@ -161,7 +161,7 @@ def gmm_guide(data, verbose=False): z = pyro.sample("z_{}".format(i), dist.Bernoulli(p)) z = z.long() if verbose: - logger.debug("G{} z_{} = {}".format(" " * i, i, z.cpu().numpy())) + logger.debug("G{} z_{} = 
{}".format(" " * int(i), int(i), z.cpu().numpy())) @pytest.mark.parametrize("data_size", [1, 2, 3]) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 77b5a09092..427d017b30 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -10,11 +10,12 @@ import pyro import pyro.distributions as dist import pyro.ops.jit +import pyro.poutine as poutine from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) from pyro.optim import Adam from pyro.poutine.indep_messenger import CondIndepStackFrame -from tests.common import assert_equal, xfail_param +from tests.common import assert_equal def constant(*args, **kwargs): @@ -162,6 +163,7 @@ def make_one_hot(x, i): torch.jit.trace(make_one_hot, (x, i)) +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') def test_scatter_workaround(): def make_one_hot_expected(x, i): @@ -181,6 +183,7 @@ def make_one_hot_actual(x, i): @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(), (4,), (5, 4)]) +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python boolean') def test_bernoulli_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.empty(shape).fill_(0.25) @@ -214,6 +217,7 @@ def f(probs): @pytest.mark.parametrize('expand', [False, True]) @pytest.mark.parametrize('shape', [(3,), (4, 3), (5, 4, 3)]) +@pytest.mark.filterwarnings('ignore:Converting a tensor to a Python integer') def test_one_hot_categorical_enumerate(shape, expand): shape = torch.Size(shape) probs = torch.ones(shape) @@ -285,7 +289,8 @@ def guide(): outer_particles = num_particles // inner_particles elbo = Elbo(max_iarange_nesting=0, strict_enumeration_warning=any([enumerate1, enumerate2]), - num_particles=inner_particles) + num_particles=inner_particles, + ignore_jit_warnings=True) actual_loss = sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles)) / outer_particles actual_grad = q.unconstrained().grad / outer_particles @@ -329,7 +334,7 @@ def guide(data): constraint=constraints.positive) pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False) + elbo = Elbo(num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): @@ -338,15 +343,16 @@ def guide(data): @pytest.mark.parametrize('Elbo', [ Trace_ELBO, - xfail_param(JitTrace_ELBO, reason="https://github.com/uber/pyro/issues/1358"), + JitTrace_ELBO, TraceGraph_ELBO, - xfail_param(JitTraceGraph_ELBO, reason="https://github.com/uber/pyro/issues/1358"), + JitTraceGraph_ELBO, TraceEnum_ELBO, - xfail_param(JitTraceEnum_ELBO, reason="https://github.com/uber/pyro/issues/1358"), + JitTraceEnum_ELBO, ]) def test_svi_irregular_batch_size(Elbo): pyro.clear_param_store() + @poutine.broadcast def model(data): loc = pyro.param("loc", constant(0.0)) scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) @@ -390,7 +396,7 @@ def guide(data): constraint=constraints.positive) pyro.sample("latent_fairness", dist.Dirichlet(concentration_q)) - elbo = Elbo(num_particles=7, strict_enumeration_warning=False) + elbo = Elbo(num_particles=7, strict_enumeration_warning=False, ignore_jit_warnings=True) optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)}) svi = SVI(model, guide, optim, elbo) for step in range(40): @@ -405,3 +411,12 @@ def test_cond_indep_equality(x, 
y):
     assert x == y
     assert not x != y
     assert hash(x) == hash(y)
+
+
+def test_jit_arange_workaround():
+    def fn(x):
+        y = torch.ones(x.shape[0], dtype=torch.long, device=x.device)
+        return torch.cumsum(y, 0) - 1
+
+    compiled = torch.jit.trace(fn, torch.ones(3))
+    assert_equal(compiled(torch.ones(10)), torch.arange(10))

From 021707a2471555d269a5da1baa2cb155643175ca Mon Sep 17 00:00:00 2001
From: Neeraj Pradhan
Date: Wed, 19 Sep 2018 17:00:50 -0700
Subject: [PATCH 067/157] Remove deprecated new_tensor invocation

---
 examples/bayesian_regression.py | 8 ++++----
 examples/contrib/oed/gp_bayes_opt.py | 3 +--
 pyro/contrib/gp/models/sgpr.py | 2 +-
 pyro/distributions/delta.py | 2 +-
 pyro/distributions/iaf.py | 2 +-
 pyro/distributions/testing/rejection_gamma.py | 2 +-
 pyro/distributions/torch.py | 2 +-
 pyro/ops/newton.py | 4 ++--
 tests/contrib/gp/test_likelihoods.py | 4 ++--
 tests/contrib/gp/test_models.py | 4 ++--
 10 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/examples/bayesian_regression.py b/examples/bayesian_regression.py
index 2ca58da1df..e258ad42b0 100644
--- a/examples/bayesian_regression.py
+++ b/examples/bayesian_regression.py
@@ -75,10 +75,10 @@ def model(data):
 
 
 def guide(data):
-    w_loc = data.new_tensor(torch.randn(1, p))
-    w_log_sig = data.new_tensor(-3.0 * torch.ones(1, p) + 0.05 * torch.randn(1, p))
-    b_loc = data.new_tensor(torch.randn(1))
-    b_log_sig = data.new_tensor(-3.0 * torch.ones(1) + 0.05 * torch.randn(1))
+    w_loc = torch.randn(1, p, dtype=data.dtype, device=data.device)
+    w_log_sig = -3 + 0.05 * torch.randn(1, p, dtype=data.dtype, device=data.device)
+    b_loc = torch.randn(1, dtype=data.dtype, device=data.device)
+    b_log_sig = -3 + 0.05 * torch.randn(1, dtype=data.dtype, device=data.device)
     # register learnable params in the param store
     mw_param = pyro.param("guide_mean_weight", w_loc)
     sw_param = softplus(pyro.param("guide_log_scale_weight", w_log_sig))
diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py
index 6eb10953c3..0ec356c5db 100644
--- a/examples/contrib/oed/gp_bayes_opt.py
+++ b/examples/contrib/oed/gp_bayes_opt.py
@@ -46,8 +46,7 @@ def find_a_candidate(self, differentiable, x_init):
         """
         # transform x to an unconstrained domain
         unconstrained_x_init = transform_to(self.constraints).inv(x_init)
-        unconstrained_x = unconstrained_x_init.new_tensor(
-            unconstrained_x_init, requires_grad=True)
+        unconstrained_x = unconstrained_x_init.detach().clone().requires_grad_(True)
         # TODO: Use LBFGS with line search once pytorch #8824 is merged
         minimizer = optim.LBFGS([unconstrained_x], max_eval=20)
 
diff --git a/pyro/contrib/gp/models/sgpr.py b/pyro/contrib/gp/models/sgpr.py
index a402632060..0d75db6730 100644
--- a/pyro/contrib/gp/models/sgpr.py
+++ b/pyro/contrib/gp/models/sgpr.py
@@ -152,7 +152,7 @@ def model(self):
         if self.approx == "VFE":
             trace_term_name = param_with_module_name(self.name, "trace_term")
             pyro.sample(trace_term_name, dist.Bernoulli(probs=torch.exp(-trace_term / 2.)),
-                        obs=trace_term.new_tensor(1.))
+                        obs=torch.tensor(1., dtype=trace_term.dtype, device=trace_term.device))
 
         y_name = param_with_module_name(self.name, "y")
         return pyro.sample(y_name,
diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py
index 3dab0e009c..55b9c26169 100644
--- a/pyro/distributions/delta.py
+++ b/pyro/distributions/delta.py
@@ -55,7 +55,7 @@ def rsample(self, sample_shape=torch.Size()):
 
     def log_prob(self, x):
         v = self.v.expand(self.shape())
-        log_prob = x.new_tensor(x == v).log()
+        log_prob = (x == 
v).type(x.dtype).log().to(x.device) log_prob = sum_rightmost(log_prob, self.event_dim) return log_prob + self.log_density diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index 23ffc19824..9d679d7860 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -79,7 +79,7 @@ def _call(self, x): sample from the base distribution (or the output of a previous flow) """ mean, scale = self.module.arn(x) - scale = self.module.sigmoid(scale + scale.new_tensor(self.module.sigmoid_bias)) + scale = self.module.sigmoid(scale + self.module.sigmoid_bias.type(x.dtype).to(x.device)) y = scale * x + (1 - scale) * mean self._add_intermediate_to_cache(x, y, 'x') diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index f815e0e323..8d069da17d 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -18,7 +18,7 @@ def __init__(self, concentration): if concentration.data.min() < 1: raise NotImplementedError('concentration < 1 is not supported') self.concentration = concentration - self._standard_gamma = Gamma(concentration, concentration.new_tensor([1.]).squeeze().expand_as(concentration)) + self._standard_gamma = Gamma(concentration, concentration.new([1.]).squeeze().expand_as(concentration)) # The following are Marsaglia & Tsang's variable names. self._d = self.concentration - 1.0 / 3.0 self._c = 1.0 / torch.sqrt(9.0 * self._d) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index ea82e5504b..4590161ecc 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -20,7 +20,7 @@ def expand(self, batch_shape): return type(self)(logits=logits, validate_args=validate_args) def enumerate_support(self, expand=True): - values = self._param.new_tensor([0., 1.]) + values = self._param.new([0., 1.]) values = values.reshape((2,) + (1,) * len(self.batch_shape)) if expand: values = values.expand((2,) + self.batch_shape) diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py index ead58e9131..2e0fa6b44e 100644 --- a/pyro/ops/newton.py +++ b/pyro/ops/newton.py @@ -150,7 +150,7 @@ def newton_step_2d(loss, x, trust_radius=None): min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt() regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * H.new_tensor(torch.eye(2)) + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(2, dtype=H.dtype, device=H.device) # compute newton update Hinv = rinverse(H, sym=True) @@ -201,7 +201,7 @@ def newton_step_3d(loss, x, trust_radius=None): min_eig, _, _ = eig_3d(H) regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * H.new_tensor(torch.eye(3)) + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(3, dtype=H.dtype, device=H.device) # compute newton update Hinv = rinverse(H, sym=True) diff --git a/tests/contrib/gp/test_likelihoods.py b/tests/contrib/gp/test_likelihoods.py index 303ca6f3f4..5aae016995 100644 --- a/tests/contrib/gp/test_likelihoods.py +++ b/tests/contrib/gp/test_likelihoods.py @@ -120,7 +120,7 @@ def test_forward(model_class, X, y, kernel, likelihood): gp = model_class(X, y, kernel, likelihood, latent_shape=latent_shape) Xnew_shape = (X.shape[0] * 2,) + X.shape[1:] - Xnew = X.new_tensor(torch.rand(Xnew_shape)) + Xnew = torch.rand(Xnew_shape, 
dtype=X.dtype, device=X.device)
     f_loc, f_var = gp(Xnew)
     ynew = gp.likelihood(f_loc, f_var)
 
@@ -139,7 +139,7 @@ def test_forward_with_empty_latent_shape(model_class, X, y, kernel, likelihood):
     gp = model_class(X, y, kernel, likelihood, latent_shape=latent_shape)
 
     Xnew_shape = (X.shape[0] * 2,) + X.shape[1:]
-    Xnew = X.new_tensor(torch.rand(Xnew_shape))
+    Xnew = torch.rand(Xnew_shape, dtype=X.dtype, device=X.device)
     f_loc, f_var = gp(Xnew)
     ynew = gp.likelihood(f_loc, f_var)
 
diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py
index 6d2fddba52..52b81c6183 100644
--- a/tests/contrib/gp/test_models.py
+++ b/tests/contrib/gp/test_models.py
@@ -331,9 +331,9 @@ def f(x):
         return 2 * x + 3 + 5 * torch.sin(7 * x)
 
     tensor_holder = torch.tensor([])
-    X = tensor_holder.new_tensor(torch.arange(100.))
+    X = tensor_holder.new_tensor(range(100))
     y = f(X)
-    Xnew = tensor_holder.new_tensor(torch.arange(100., 150.))
+    Xnew = tensor_holder.new_tensor(range(100, 150))
     ynew = f(Xnew)
 
     kernel = Cosine(input_dim=1)

From a36f25ad57dbd21920af0c160fcf0625d6677c8a Mon Sep 17 00:00:00 2001
From: Neeraj Pradhan
Date: Wed, 19 Sep 2018 17:00:50 -0700
Subject: [PATCH 068/157] Remove deprecated new_tensor invocation

---
 examples/bayesian_regression.py | 8 ++++----
 examples/contrib/oed/gp_bayes_opt.py | 3 +--
 pyro/contrib/gp/models/sgpr.py | 2 +-
 pyro/distributions/delta.py | 2 +-
 pyro/distributions/iaf.py | 2 +-
 pyro/distributions/testing/rejection_gamma.py | 2 +-
 pyro/ops/newton.py | 4 ++--
 pyro/primitives.py | 3 ++-
 tests/contrib/gp/test_likelihoods.py | 4 ++--
 tests/contrib/gp/test_models.py | 4 ++--
 10 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/examples/bayesian_regression.py b/examples/bayesian_regression.py
index 2ca58da1df..e258ad42b0 100644
--- a/examples/bayesian_regression.py
+++ b/examples/bayesian_regression.py
@@ -75,10 +75,10 @@ def model(data):
 
 
 def guide(data):
-    w_loc = data.new_tensor(torch.randn(1, p))
-    w_log_sig = data.new_tensor(-3.0 * torch.ones(1, p) + 0.05 * torch.randn(1, p))
-    b_loc = data.new_tensor(torch.randn(1))
-    b_log_sig = data.new_tensor(-3.0 * torch.ones(1) + 0.05 * torch.randn(1))
+    w_loc = torch.randn(1, p, dtype=data.dtype, device=data.device)
+    w_log_sig = -3 + 0.05 * torch.randn(1, p, dtype=data.dtype, device=data.device)
+    b_loc = torch.randn(1, dtype=data.dtype, device=data.device)
+    b_log_sig = -3 + 0.05 * torch.randn(1, dtype=data.dtype, device=data.device)
     # register learnable params in the param store
     mw_param = pyro.param("guide_mean_weight", w_loc)
     sw_param = softplus(pyro.param("guide_log_scale_weight", w_log_sig))
diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py
index 6eb10953c3..0ec356c5db 100644
--- a/examples/contrib/oed/gp_bayes_opt.py
+++ b/examples/contrib/oed/gp_bayes_opt.py
@@ -46,8 +46,7 @@ def find_a_candidate(self, differentiable, x_init):
         """
         # transform x to an unconstrained domain
         unconstrained_x_init = transform_to(self.constraints).inv(x_init)
-        unconstrained_x = unconstrained_x_init.new_tensor(
-            unconstrained_x_init, requires_grad=True)
+        unconstrained_x = unconstrained_x_init.detach().clone().requires_grad_(True)
         # TODO: Use LBFGS with line search once pytorch #8824 is merged
         minimizer = optim.LBFGS([unconstrained_x], max_eval=20)
 
diff --git a/pyro/contrib/gp/models/sgpr.py b/pyro/contrib/gp/models/sgpr.py
index a402632060..0d75db6730 100644
--- a/pyro/contrib/gp/models/sgpr.py
+++ b/pyro/contrib/gp/models/sgpr.py
@@ -152,7 +152,7 @@ def model(self):
         if self.approx == "VFE":
trace_term_name = param_with_module_name(self.name, "trace_term") pyro.sample(trace_term_name, dist.Bernoulli(probs=torch.exp(-trace_term / 2.)), - obs=trace_term.new_tensor(1.)) + obs=torch.tensor(1., dtype=trace_term.dtype, device=trace_term.device)) y_name = param_with_module_name(self.name, "y") return pyro.sample(y_name, diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index 0960725958..4b541d59ac 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -57,7 +57,7 @@ def rsample(self, sample_shape=torch.Size()): def log_prob(self, x): v = self.v.expand(self.shape()) - log_prob = (x == v).type_as(v).log() + log_prob = (x == v).type(x.dtype).log().to(x.device) log_prob = sum_rightmost(log_prob, self.event_dim) return log_prob + self.log_density diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index 23ffc19824..9d679d7860 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -79,7 +79,7 @@ def _call(self, x): sample from the base distribution (or the output of a previous flow) """ mean, scale = self.module.arn(x) - scale = self.module.sigmoid(scale + scale.new_tensor(self.module.sigmoid_bias)) + scale = self.module.sigmoid(scale + self.module.sigmoid_bias.type(x.dtype).to(x.device)) y = scale * x + (1 - scale) * mean self._add_intermediate_to_cache(x, y, 'x') diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index dde109a94e..137c39334a 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -18,7 +18,7 @@ def __init__(self, concentration): if concentration.data.min() < 1: raise NotImplementedError('concentration < 1 is not supported') self.concentration = concentration - self._standard_gamma = Gamma(concentration, concentration.new_tensor([1.]).squeeze().expand_as(concentration)) + self._standard_gamma = Gamma(concentration, concentration.new([1.]).squeeze().expand_as(concentration)) # The following are Marsaglia & Tsang's variable names. self._d = self.concentration - 1.0 / 3.0 self._c = 1.0 / torch.sqrt(9.0 * self._d) diff --git a/pyro/ops/newton.py b/pyro/ops/newton.py index ead58e9131..2e0fa6b44e 100644 --- a/pyro/ops/newton.py +++ b/pyro/ops/newton.py @@ -150,7 +150,7 @@ def newton_step_2d(loss, x, trust_radius=None): min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt() regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * H.new_tensor(torch.eye(2)) + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(2, dtype=H.dtype, device=H.device) # compute newton update Hinv = rinverse(H, sym=True) @@ -201,7 +201,7 @@ def newton_step_3d(loss, x, trust_radius=None): min_eig, _, _ = eig_3d(H) regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(min=1e-8) warn_if_nan(regularizer, 'regularizer') - H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * H.new_tensor(torch.eye(3)) + H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(3, dtype=H.dtype, device=H.device) # compute newton update Hinv = rinverse(H, sym=True) diff --git a/pyro/primitives.py b/pyro/primitives.py index 1ac791cc30..963a0e2a9f 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -107,7 +107,8 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None): if self.use_cuda ^ (device != "cpu"): raise ValueError("Incompatible arg values use_cuda={}, device={}." 
.format(use_cuda, device)) - self.device = torch.Tensor().device if not device else device + with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): + self.device = torch.Tensor().device if not device else device def sample(self, sample_shape=torch.Size()): """ diff --git a/tests/contrib/gp/test_likelihoods.py b/tests/contrib/gp/test_likelihoods.py index 303ca6f3f4..5aae016995 100644 --- a/tests/contrib/gp/test_likelihoods.py +++ b/tests/contrib/gp/test_likelihoods.py @@ -120,7 +120,7 @@ def test_forward(model_class, X, y, kernel, likelihood): gp = model_class(X, y, kernel, likelihood, latent_shape=latent_shape) Xnew_shape = (X.shape[0] * 2,) + X.shape[1:] - Xnew = X.new_tensor(torch.rand(Xnew_shape)) + Xnew = torch.rand(Xnew_shape, dtype=X.dtype, device=X.device) f_loc, f_var = gp(Xnew) ynew = gp.likelihood(f_loc, f_var) @@ -139,7 +139,7 @@ def test_forward_with_empty_latent_shape(model_class, X, y, kernel, likelihood): gp = model_class(X, y, kernel, likelihood, latent_shape=latent_shape) Xnew_shape = (X.shape[0] * 2,) + X.shape[1:] - Xnew = X.new_tensor(torch.rand(Xnew_shape)) + Xnew = torch.rand(Xnew_shape, dtype=X.dtype, device=X.device) f_loc, f_var = gp(Xnew) ynew = gp.likelihood(f_loc, f_var) diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index 6d2fddba52..52b81c6183 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -331,9 +331,9 @@ def f(x): return 2 * x + 3 + 5 * torch.sin(7 * x) tensor_holder = torch.tensor([]) - X = tensor_holder.new_tensor(torch.arange(100.)) + X = tensor_holder.new_tensor(range(100)) y = f(X) - Xnew = tensor_holder.new_tensor(torch.arange(100., 150.)) + Xnew = tensor_holder.new_tensor(range(100, 150)) ynew = f(Xnew) kernel = Cosine(input_dim=1) From 693fcd9046c230d9b01fd66cd14330bb3faf9dac Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 19 Sep 2018 18:00:51 -0700 Subject: [PATCH 069/157] remove .new --- pyro/distributions/testing/rejection_gamma.py | 2 +- pyro/distributions/torch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py index 8d069da17d..f815e0e323 100644 --- a/pyro/distributions/testing/rejection_gamma.py +++ b/pyro/distributions/testing/rejection_gamma.py @@ -18,7 +18,7 @@ def __init__(self, concentration): if concentration.data.min() < 1: raise NotImplementedError('concentration < 1 is not supported') self.concentration = concentration - self._standard_gamma = Gamma(concentration, concentration.new([1.]).squeeze().expand_as(concentration)) + self._standard_gamma = Gamma(concentration, concentration.new_tensor([1.]).squeeze().expand_as(concentration)) # The following are Marsaglia & Tsang's variable names. 
self._d = self.concentration - 1.0 / 3.0 self._c = 1.0 / torch.sqrt(9.0 * self._d) diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 4590161ecc..ea82e5504b 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -20,7 +20,7 @@ def expand(self, batch_shape): return type(self)(logits=logits, validate_args=validate_args) def enumerate_support(self, expand=True): - values = self._param.new([0., 1.]) + values = self._param.new_tensor([0., 1.]) values = values.reshape((2,) + (1,) * len(self.batch_shape)) if expand: values = values.expand((2,) + self.batch_shape) From 31290592ff73dc9c440bf0c37f99b463ab8599de Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 08:00:21 -0700 Subject: [PATCH 070/157] address comments --- pyro/distributions/delta.py | 2 +- pyro/distributions/iaf.py | 2 +- tests/contrib/gp/test_models.py | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index 55b9c26169..a4d1c7bea7 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -55,7 +55,7 @@ def rsample(self, sample_shape=torch.Size()): def log_prob(self, x): v = self.v.expand(self.shape()) - log_prob = (x == v).type(x.dtype).log().to(x.device) + log_prob = (x == v).type(x.dtype).log() log_prob = sum_rightmost(log_prob, self.event_dim) return log_prob + self.log_density diff --git a/pyro/distributions/iaf.py b/pyro/distributions/iaf.py index 9d679d7860..e6fe19de77 100644 --- a/pyro/distributions/iaf.py +++ b/pyro/distributions/iaf.py @@ -79,7 +79,7 @@ def _call(self, x): sample from the base distribution (or the output of a previous flow) """ mean, scale = self.module.arn(x) - scale = self.module.sigmoid(scale + self.module.sigmoid_bias.type(x.dtype).to(x.device)) + scale = self.module.sigmoid(scale + self.module.sigmoid_bias.type(x.dtype)) y = scale * x + (1 - scale) * mean self._add_intermediate_to_cache(x, y, 'x') diff --git a/tests/contrib/gp/test_models.py b/tests/contrib/gp/test_models.py index 52b81c6183..f6fd8c7e40 100644 --- a/tests/contrib/gp/test_models.py +++ b/tests/contrib/gp/test_models.py @@ -330,10 +330,9 @@ def _pre_test_mean_function(): def f(x): return 2 * x + 3 + 5 * torch.sin(7 * x) - tensor_holder = torch.tensor([]) - X = tensor_holder.new_tensor(range(100)) + X = torch.arange(100, dtype=torch.Tensor().dtype) y = f(X) - Xnew = tensor_holder.new_tensor(range(100, 150)) + Xnew = torch.arange(100, 150, dtype=torch.Tensor().dtype) ynew = f(Xnew) kernel = Cosine(input_dim=1) From a648968cecf9b759f469cfa0d614bf48a55bfb8f Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 11:16:57 -0700 Subject: [PATCH 071/157] fix test_hessian --- tests/contrib/autoguide/test_hessian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contrib/autoguide/test_hessian.py b/tests/contrib/autoguide/test_hessian.py index 7e3f255490..f26a124db0 100644 --- a/tests/contrib/autoguide/test_hessian.py +++ b/tests/contrib/autoguide/test_hessian.py @@ -9,7 +9,7 @@ def test_mvn(): tmp = torch.randn(3, 10) - cov = torch.tensor(torch.matmul(tmp, tmp.t())) + cov = torch.matmul(tmp, tmp.t()) mvn = dist.MultivariateNormal(cov.new_zeros(3), cov) x = torch.randn(3, requires_grad=True) From c1c9e82961ca9d7eb7719c607338f3b12fb80285 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 13:50:31 -0700 Subject: [PATCH 072/157] fix more tests --- pyro/distributions/omt_mvn.py | 2 +- tests/distributions/test_omt_mvn.py | 6 +++--- 
tests/infer/test_enum.py | 2 +- tests/infer/test_inference.py | 19 ++++++------------- 4 files changed, 11 insertions(+), 18 deletions(-) diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index 8fed96f85f..8265e52dd5 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -51,7 +51,7 @@ def backward(ctx, grad_output): g = grad_output loc_grad = sum_leftmost(grad_output, -1) - identity = torch.eye(dim, out=torch.tensor(g.new_empty(dim, dim))) + identity = torch.eye(dim, out=g.new_empty(dim, dim)) R_inv = torch.trtrs(identity, L.t(), transpose=False, upper=True)[0] z_ja = z.unsqueeze(-1) diff --git a/tests/distributions/test_omt_mvn.py b/tests/distributions/test_omt_mvn.py index f8f41e40f3..839521af45 100644 --- a/tests/distributions/test_omt_mvn.py +++ b/tests/distributions/test_omt_mvn.py @@ -35,7 +35,7 @@ def test_mean_gradient(mvn_dist, k, sample_shape, L21, omega1, L11, L22=0.8, L33 if mvn_dist == 'OMTMultivariateNormal': dist = OMTMultivariateNormal(loc, L) elif mvn_dist == 'AVFMultivariateNormal': - CV = torch.tensor(1.1 * torch.rand(2, k, 3), requires_grad=True) + CV = (1.1 * torch.rand(2, k, 3)).requires_grad_(True) dist = AVFMultivariateNormal(loc, L, CV) z = dist.rsample(sample_shape) @@ -67,7 +67,7 @@ def test_mean_single_gradient(mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, o if mvn_dist == 'OMTMultivariateNormal': dist = OMTMultivariateNormal(loc, L) elif mvn_dist == 'AVFMultivariateNormal': - CV = torch.tensor(0.2 * torch.rand(2, k, 3), requires_grad=True) + CV = (0.2 * torch.rand(2, k, 3)).requires_grad_(True) dist = AVFMultivariateNormal(loc, L, CV) computed_grads = [] @@ -103,6 +103,6 @@ def test_log_prob(mvn_dist): if mvn_dist == OMTMultivariateNormal: mvn_prime = OMTMultivariateNormal(loc, L) elif mvn_dist == AVFMultivariateNormal: - CV = torch.tensor(0.2 * torch.rand(2, 2, 5)) + CV = 0.2 * torch.rand(2, 2, 5) mvn_prime = AVFMultivariateNormal(loc, L, CV) assert_equal(mvn.log_prob(x), mvn_prime.log_prob(x)) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 7b1a5ba3b3..24faac30f1 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -193,7 +193,7 @@ def gmm_batch_model(data): def gmm_batch_guide(data): with pyro.iarange("data", len(data)) as batch: n = len(batch) - probs = pyro.param("probs", torch.tensor(torch.ones(n, 1) * 0.6, requires_grad=True)) + probs = pyro.param("probs", torch.ones(n, 1) * 0.6) probs = torch.cat([probs, 1 - probs], dim=1) z = pyro.sample("z", dist.OneHotCategorical(probs)) assert z.shape[-1] == 2 diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 91efdb8399..448b524529 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -72,11 +72,8 @@ def model(): return loc_latent def guide(): - loc_q = pyro.param("loc_q", torch.tensor(self.analytic_loc_n.data + 0.134 * torch.ones(2), - requires_grad=True)) - log_sig_q = pyro.param("log_sig_q", torch.tensor( - self.analytic_log_sig_n.data - 0.14 * torch.ones(2), - requires_grad=True)) + loc_q = pyro.param("loc_q", (self.analytic_loc_n.detach().clone() + 0.134)) + log_sig_q = pyro.param("log_sig_q", (self.analytic_log_sig_n.data.detach().clone() - 0.14)) sig_q = torch.exp(log_sig_q) Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal pyro.sample("loc_latent", Normal(loc_q, sig_q).independent(1)) @@ -107,11 +104,9 @@ def do_test_fixedness(self, fixed_parts): def model(): alpha_p_log = pyro.param( - "alpha_p_log", torch.tensor( - 
self.alpha_p_log_0.clone())) + "alpha_p_log", self.alpha_p_log_0.clone()) beta_p_log = pyro.param( - "beta_p_log", torch.tensor( - self.beta_p_log_0.clone())) + "beta_p_log", self.beta_p_log_0.clone()) alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log) lambda_latent = pyro.sample("lambda_latent", dist.Gamma(alpha_p, beta_p)) pyro.sample("obs", dist.Poisson(lambda_latent), obs=self.data) @@ -119,11 +114,9 @@ def model(): def guide(): alpha_q_log = pyro.param( - "alpha_q_log", torch.tensor( - self.alpha_q_log_0.clone())) + "alpha_q_log", self.alpha_q_log_0.clone()) beta_q_log = pyro.param( - "beta_q_log", torch.tensor( - self.beta_q_log_0.clone())) + "beta_q_log", self.beta_q_log_0.clone()) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", dist.Gamma(alpha_q, beta_q)) From edf000ec6bb9b00b75e0919abf282c8e126cd32f Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 13:52:57 -0700 Subject: [PATCH 073/157] remove redundant parens --- tests/infer/test_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 448b524529..2ae0fa78da 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -72,8 +72,8 @@ def model(): return loc_latent def guide(): - loc_q = pyro.param("loc_q", (self.analytic_loc_n.detach().clone() + 0.134)) - log_sig_q = pyro.param("log_sig_q", (self.analytic_log_sig_n.data.detach().clone() - 0.14)) + loc_q = pyro.param("loc_q", self.analytic_loc_n.detach().clone() + 0.134) + log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.data.detach().clone() - 0.14) sig_q = torch.exp(log_sig_q) Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal pyro.sample("loc_latent", Normal(loc_q, sig_q).independent(1)) From c2d3de861f983996ed0c5b5bddc30909b0ad38c1 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 14:01:00 -0700 Subject: [PATCH 074/157] fix test_elbo_mapdata --- tests/infer/test_elbo_mapdata.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/infer/test_elbo_mapdata.py b/tests/infer/test_elbo_mapdata.py index 47783d66ed..912da1818b 100644 --- a/tests/infer/test_elbo_mapdata.py +++ b/tests/infer/test_elbo_mapdata.py @@ -71,10 +71,8 @@ def model(): return loc_latent def guide(): - loc_q = pyro.param("loc_q", torch.tensor( - analytic_loc_n.data + torch.tensor([-0.18, 0.23]), requires_grad=True)) - log_sig_q = pyro.param("log_sig_q", torch.tensor( - analytic_log_sig_n.data - torch.tensor([-0.18, 0.23]), requires_grad=True)) + loc_q = pyro.param("loc_q", analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23])) + log_sig_q = pyro.param("log_sig_q", analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23])) sig_q = torch.exp(log_sig_q) pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).independent(1)) if map_type == "irange" or map_type is None: From bf85894190caf6f6eeca00069d92dcbb9b579984 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 14:06:51 -0700 Subject: [PATCH 075/157] fix test_conj_gaussian --- tests/infer/test_inference.py | 4 ++-- .../test_conjugate_gaussian_models.py | 15 +++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 2ae0fa78da..78cd6998c0 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -72,8 +72,8 @@ def model(): return loc_latent def guide(): - 
loc_q = pyro.param("loc_q", self.analytic_loc_n.detach().clone() + 0.134) - log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.data.detach().clone() - 0.14) + loc_q = pyro.param("loc_q", self.analytic_loc_n.detach() + 0.134) + log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.data.detach() - 0.14) sig_q = torch.exp(log_sig_q) Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal pyro.sample("loc_latent", Normal(loc_q, sig_q).independent(1)) diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py index c846bd999f..c58d19d6b3 100644 --- a/tests/integration_tests/test_conjugate_gaussian_models.py +++ b/tests/integration_tests/test_conjugate_gaussian_models.py @@ -82,19 +82,14 @@ def model(self, reparameterized, difficulty=0.0): def guide(self, reparameterized, difficulty=0.0): previous_sample = None for k in reversed(range(1, self.N + 1)): - loc_q = pyro.param("loc_q_%d" % k, torch.tensor(self.target_mus[k].data + - difficulty * (0.1 * torch.randn(1) - 0.53), - requires_grad=True)) - log_sig_q = pyro.param("log_sig_q_%d" % k, - torch.tensor(-0.5 * torch.log(self.lambda_posts[k]).data + - difficulty * (0.1 * torch.randn(1) - 0.53), - requires_grad=True)) + loc_q = pyro.param("loc_q_%d" % k, self.target_mus[k].detach() + difficulty * (0.1 * torch.randn(1) - 0.53)) + log_sig_q = pyro.param("log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k]).data + + difficulty * (0.1 * torch.randn(1) - 0.53)) sig_q = torch.exp(log_sig_q) kappa_q = None if k != self.N: - kappa_q = pyro.param("kappa_q_%d" % k, torch.tensor(self.target_kappas[k].data + - difficulty * (0.1 * torch.randn(1) - 0.53), - requires_grad=True)) + kappa_q = pyro.param("kappa_q_%d" % k, self.target_kappas[k].data + + difficulty * (0.1 * torch.randn(1) - 0.53)) mean_function = loc_q if k == self.N else kappa_q * previous_sample + loc_q node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal From c6dd8d7ce070e91b0da2e7201010bb6776b2c0c0 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 14:21:19 -0700 Subject: [PATCH 076/157] fix test_valid_models --- tests/infer/test_valid_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py index be7e584411..c23816225c 100644 --- a/tests/infer/test_valid_models.py +++ b/tests/infer/test_valid_models.py @@ -689,8 +689,8 @@ def test_enum_discrete_parallel_nested_ok(max_iarange_nesting): iarange_shape = torch.Size([1] * max_iarange_nesting) def model(): - p2 = torch.tensor(torch.ones(2) / 2) - p3 = torch.tensor(torch.ones(3) / 3) + p2 = torch.ones(2) / 2 + p3 = torch.ones(3) / 3 x2 = pyro.sample("x2", dist.OneHotCategorical(p2)) x3 = pyro.sample("x3", dist.OneHotCategorical(p3)) assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape From d4ff53c812d4f52122039284f5e8c08157027ddd Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 15:02:04 -0700 Subject: [PATCH 077/157] fix dist tests --- tests/distributions/test_gaussian_mixtures.py | 15 ++++++------- tests/distributions/test_rejector.py | 22 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py index 4dba284055..35d3e7ca0f 100644 --- 
a/tests/distributions/test_gaussian_mixtures.py +++ b/tests/distributions/test_gaussian_mixtures.py @@ -25,23 +25,22 @@ def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): if mix_dist == GaussianScaleMixture: locs = torch.zeros(K, D, requires_grad=True) else: - locs = torch.tensor(torch.rand(K, D), requires_grad=True) + locs = torch.rand(K, D).requires_grad_(True) if mix_dist == GaussianScaleMixture: - component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K) - component_scale = torch.tensor(component_scale, requires_grad=True) + component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K).requires_grad_(True) else: component_scale = torch.ones(K, requires_grad=True) if mix_dist == MixtureOfDiagNormals: coord_scale = torch.ones(K, D) + 0.5 * torch.rand(K, D) - coord_scale = torch.tensor(coord_scale, requires_grad=True) + coord_scale.requires_grad_(True) else: coord_scale = torch.ones(D) + 0.5 * torch.rand(D) - coord_scale = torch.tensor(coord_scale, requires_grad=True) + coord_scale.requires_grad_(True) if not flat_logits: - component_logits = torch.tensor(1.5 * torch.rand(K), requires_grad=True) + component_logits = (1.5 * torch.rand(K)).requires_grad_(True) else: - component_logits = torch.tensor(0.1 * torch.rand(K), requires_grad=True) - omega = torch.tensor(0.2 * torch.ones(D) + 0.1 * torch.rand(D), requires_grad=False) + component_logits = (0.1 * torch.rand(K)).requires_grad_(True) + omega = (0.2 * torch.ones(D) + 0.1 * torch.rand(D)).requires_grad_(False) _pis = torch.exp(component_logits) pis = _pis / _pis.sum() diff --git a/tests/distributions/test_rejector.py b/tests/distributions/test_rejector.py index 9265957b5c..6baf363a9b 100644 --- a/tests/distributions/test_rejector.py +++ b/tests/distributions/test_rejector.py @@ -45,8 +45,8 @@ def compute_elbo_grad(model, guide, variables): @pytest.mark.parametrize('factor', [0.25, 0.5, 1.0]) def test_rejector(rate, factor): num_samples = 100000 - rates = torch.tensor(torch.tensor(rate).expand(num_samples, 1), requires_grad=True) - factors = torch.tensor(torch.tensor(factor).expand(num_samples, 1), requires_grad=True) + rates = torch.tensor(rate).expand(num_samples, 1).requires_grad_(True) + factors = torch.tensor(factor).expand(num_samples, 1).requires_grad_(True) dist1 = Exponential(rates) dist2 = RejectionExponential(rates, factors) # implemented using Rejector @@ -61,8 +61,8 @@ def test_rejector(rate, factor): @pytest.mark.parametrize('factor', [0.25, 0.5, 1.0]) def test_exponential_elbo(rate, factor): num_samples = 100000 - rates = torch.tensor(torch.tensor(rate).expand(num_samples, 1), requires_grad=True) - factors = torch.tensor(torch.tensor(factor).expand(num_samples, 1), requires_grad=True) + rates = torch.tensor(rate).expand(num_samples, 1).requires_grad_(True) + factors = torch.tensor(factor).expand(num_samples, 1).requires_grad_(True) model = Exponential(torch.ones(num_samples, 1)) guide1 = Exponential(rates) @@ -81,7 +81,7 @@ def test_exponential_elbo(rate, factor): @pytest.mark.parametrize('alpha', [1.0, 2.0, 5.0]) def test_standard_gamma_elbo(alpha): num_samples = 100000 - alphas = torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True) + alphas = torch.tensor(alpha).expand(num_samples, 1).requires_grad_(True) betas = torch.ones(num_samples, 1) model = Gamma(torch.ones(num_samples, 1), betas) @@ -99,8 +99,8 @@ def test_standard_gamma_elbo(alpha): @pytest.mark.parametrize('beta', [0.2, 0.5, 1.0, 2.0, 5.0]) def test_gamma_elbo(alpha, beta): num_samples = 100000 - alphas = 
torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True) - betas = torch.tensor(torch.tensor(beta).expand(num_samples, 1), requires_grad=True) + alphas = torch.tensor(alpha).expand(num_samples, 1).requires_grad_(True) + betas = torch.tensor(beta).expand(num_samples, 1).requires_grad_(True) model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1)) guide1 = Gamma(alphas, betas) @@ -121,8 +121,8 @@ def test_gamma_elbo(alpha, beta): @pytest.mark.parametrize('beta', [0.2, 0.5, 1.0, 2.0, 5.0]) def test_shape_augmented_gamma_elbo(alpha, beta): num_samples = 100000 - alphas = torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True) - betas = torch.tensor(torch.tensor(beta).expand(num_samples, 1), requires_grad=True) + alphas = torch.tensor(alpha).expand(num_samples, 1).requires_grad_(True) + betas = torch.tensor(beta).expand(num_samples, 1).requires_grad_(True) model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1)) guide1 = Gamma(alphas, betas) @@ -143,8 +143,8 @@ def test_shape_augmented_gamma_elbo(alpha, beta): @pytest.mark.parametrize('beta', [0.5, 1.0, 4.0]) def test_shape_augmented_beta(alpha, beta): num_samples = 10000 - alphas = torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True) - betas = torch.tensor(torch.tensor(beta).expand(num_samples, 1), requires_grad=True) + alphas = torch.tensor(alpha).expand(num_samples, 1).requires_grad_(True) + betas = torch.tensor(beta).expand(num_samples, 1).requires_grad_(True) dist = ShapeAugmentedBeta(alphas, betas) # implemented using Rejector z = dist.rsample() cost = z.sum() From 0064b6ae61a095b11375ff85a5e4e570ea8a708e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 20 Sep 2018 15:20:25 -0700 Subject: [PATCH 078/157] fix test_gaussian_mixtures --- tests/distributions/test_gaussian_mixtures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributions/test_gaussian_mixtures.py b/tests/distributions/test_gaussian_mixtures.py index 35d3e7ca0f..4101f733ca 100644 --- a/tests/distributions/test_gaussian_mixtures.py +++ b/tests/distributions/test_gaussian_mixtures.py @@ -27,7 +27,8 @@ def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode): else: locs = torch.rand(K, D).requires_grad_(True) if mix_dist == GaussianScaleMixture: - component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K).requires_grad_(True) + component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K) + component_scale.requires_grad_(True) else: component_scale = torch.ones(K, requires_grad=True) if mix_dist == MixtureOfDiagNormals: From 710406d6e70f1d470863752c4c262ddf5417ad35 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 13:17:59 -0700 Subject: [PATCH 079/157] Test fixes for compatibility with PyTorch master --- examples/air/main.py | 2 +- tests/infer/mcmc/test_hmc.py | 2 +- tests/infer/mcmc/test_nuts.py | 2 +- tests/infer/test_inference.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/air/main.py b/examples/air/main.py index aae26828de..d1f6c0b9f7 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -188,9 +188,9 @@ def z_pres_prior_p(opt_step, time_step): print('Loading parameters...') air.load_state_dict(torch.load(args.load)) - vis = visdom.Visdom(env=args.visdom_env) # Viz sample from prior. 
if args.viz: + vis = visdom.Visdom(env=args.visdom_env) z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0)) vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z)))) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index b3658e7dec..cb87d8005d 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -276,7 +276,7 @@ def test_gaussian_mixture_model(): @poutine.broadcast def gmm(data): with pyro.iarange("num_clusters", K): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor([1.]))) cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 649371ac8b..42a74df887 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -188,7 +188,7 @@ def test_gaussian_mixture_model(): @poutine.broadcast def gmm(data): with pyro.iarange("num_clusters", K): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor(1.))) + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor([1.]))) cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 78cd6998c0..2b5d1dd135 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -324,9 +324,9 @@ def model(): def guide(): alpha_q_log = pyro.param("alpha_q_log", - torch.tensor(self.log_alpha_n.data + 0.17, requires_grad=True)) + self.log_alpha_n.clone() + 0.17) beta_q_log = pyro.param("beta_q_log", - torch.tensor(self.log_beta_n.data - 0.143, requires_grad=True)) + self.log_beta_n.clone() - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("p_latent", Beta(alpha_q, beta_q)) From eef40ebb423a5c6737bf205fafe58852b5295841 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 13:43:39 -0700 Subject: [PATCH 080/157] address comments; more fixes --- tests/infer/mcmc/test_hmc.py | 4 ++-- tests/infer/mcmc/test_nuts.py | 2 +- tests/infer/test_inference.py | 4 ++-- .../test_conjugate_gaussian_models.py | 13 ++++--------- .../integration_tests/test_tracegraph_elbo.py | 19 +++++++++---------- 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index cb87d8005d..53dce68afe 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -275,8 +275,8 @@ def test_gaussian_mixture_model(): @poutine.broadcast def gmm(data): + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K))) with pyro.iarange("num_clusters", K): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor([1.]))) cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) @@ -288,7 +288,7 @@ def gmm(data): cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=1) - mcmc_run = 
MCMC(hmc_kernel, num_samples=600, warmup_steps=200).run(data) + mcmc_run = MCMC(hmc_kernel, num_samples=300, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 42a74df887..16b364a303 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -187,8 +187,8 @@ def test_gaussian_mixture_model(): @poutine.broadcast def gmm(data): + mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K))) with pyro.iarange("num_clusters", K): - mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.tensor([1.]))) cluster_means = pyro.sample("cluster_means", dist.Normal(torch.arange(float(K)), 1.)) with pyro.iarange("data", data.shape[0]): assignments = pyro.sample("assignments", dist.Categorical(mix_proportions)) diff --git a/tests/infer/test_inference.py b/tests/infer/test_inference.py index 2b5d1dd135..72b8c3b077 100644 --- a/tests/infer/test_inference.py +++ b/tests/infer/test_inference.py @@ -324,9 +324,9 @@ def model(): def guide(): alpha_q_log = pyro.param("alpha_q_log", - self.log_alpha_n.clone() + 0.17) + self.log_alpha_n + 0.17) beta_q_log = pyro.param("beta_q_log", - self.log_beta_n.clone() - 0.143) + self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("p_latent", Beta(alpha_q, beta_q)) diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py index c58d19d6b3..cb53ff998f 100644 --- a/tests/integration_tests/test_conjugate_gaussian_models.py +++ b/tests/integration_tests/test_conjugate_gaussian_models.py @@ -408,18 +408,13 @@ def guide(self, reparameterized, model_permutation, difficulty=0.0): deps = self.q_dag.predecessors(node) node_suffix = node[11:] log_sig_node = pyro.param("log_sig_" + node_suffix, - torch.tensor(-0.5 * torch.log(self.target_lambdas[node_suffix]).data + - difficulty * (torch.Tensor([-0.3]) - - 0.3 * (torch.randn(1) ** 2)), - requires_grad=True)) + -0.5 * torch.log(self.target_lambdas[node_suffix]) + + difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2))) mean_function_node = pyro.param("constant_term_" + node, - torch.tensor(self.loc0.data + - torch.Tensor([difficulty * i / n_nodes]), - requires_grad=True)) + self.loc0 + torch.Tensor([difficulty * i / n_nodes])) for dep in deps: kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[11:], - torch.tensor([0.5 + difficulty * i / n_nodes], - requires_grad=True)) + torch.tensor([0.5 + difficulty * i / n_nodes])) mean_function_node = mean_function_node + kappa_dep * latents_dict[dep] node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal diff --git a/tests/integration_tests/test_tracegraph_elbo.py b/tests/integration_tests/test_tracegraph_elbo.py index 64efa34156..83d9f1ee4a 100644 --- a/tests/integration_tests/test_tracegraph_elbo.py +++ b/tests/integration_tests/test_tracegraph_elbo.py @@ -71,9 +71,9 @@ def model(): return loc_latent def guide(): - loc_q = pyro.param("loc_q", torch.tensor(self.analytic_loc_n.expand(2) + 0.334, requires_grad=True)) + loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334) log_sig_q = pyro.param("log_sig_q", - torch.tensor(self.analytic_log_sig_n.expand(2) 
- 0.29, requires_grad=True)) + self.analytic_log_sig_n.expand(2) - 0.29) sig_q = torch.exp(log_sig_q) with pyro.iarange("iarange", 2): loc_latent = pyro.sample("loc_latent", Normal(loc_q, sig_q)) @@ -164,15 +164,14 @@ def model(): # note that the exact posterior is not mean field! def guide(): - loc_q = pyro.param("loc_q", torch.tensor(self.analytic_loc_n.expand(2) + 0.334, requires_grad=True)) + loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334) log_sig_q = pyro.param("log_sig_q", - torch.tensor(self.analytic_log_sig_n.expand(2) - 0.29, requires_grad=True)) + self.analytic_log_sig_n.expand(2) - 0.29) loc_q_prime = pyro.param("loc_q_prime", - torch.tensor([-0.34, 0.52], requires_grad=True)) - kappa_q = pyro.param("kappa_q", torch.tensor([0.74], - requires_grad=True)) + torch.tensor([-0.34, 0.52])) + kappa_q = pyro.param("kappa_q", torch.tensor([0.74])) log_sig_q_prime = pyro.param("log_sig_q_prime", - torch.tensor(-0.5 * torch.log(1.2 * self.lam0), requires_grad=True)) + -0.5 * torch.log(1.2 * self.lam0)) sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime) with pyro.iarange("iarange", 2): loc_latent = pyro.sample("loc_latent", Normal2(loc_q, sig_q), @@ -440,9 +439,9 @@ def model(): torch.nn.Linear(2, 2)]) def guide(): - loc_q = pyro.param("loc_q", torch.tensor(self.analytic_loc_n.expand(2) + 0.094, requires_grad=True)) + loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094) log_sig_q = pyro.param("log_sig_q", - torch.tensor(self.analytic_log_sig_n.expand(2) - 0.07, requires_grad=True)) + self.analytic_log_sig_n.expand(2) - 0.07) sig_q = torch.exp(log_sig_q) trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline) baseline_value = trivial_baseline(torch.ones(1)).squeeze() From 2325b1490ad291974e5729bd90ebdd017c0f2d98 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 14:01:20 -0700 Subject: [PATCH 081/157] more test fixes --- pyro/distributions/torch_patch.py | 28 +++++++++---------- tests/infer/mcmc/test_nuts.py | 2 +- .../test_conjugate_gaussian_models.py | 2 +- .../integration_tests/test_tracegraph_elbo.py | 14 +++++----- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index b7b3f6938c..36c00b4753 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -21,20 +21,20 @@ def decorator(new_fn): return decorator - -@_patch('torch._standard_gamma') -def _torch_standard_gamma(concentration): - unpatched_fn = _torch_standard_gamma._pyro_unpatched - if concentration.is_cuda: - return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) - return unpatched_fn(concentration) - - -@_patch('torch.distributions.gamma._standard_gamma') -def _standard_gamma(concentration): - if concentration.is_cuda: - return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) - return concentration._standard_gamma() +# +# @_patch('torch._standard_gamma') +# def _torch_standard_gamma(concentration): +# unpatched_fn = _torch_standard_gamma._pyro_unpatched +# if concentration.is_cuda: +# return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) +# return unpatched_fn(concentration) +# +# +# @_patch('torch.distributions.gamma._standard_gamma') +# def _standard_gamma(concentration): +# if concentration.is_cuda: +# return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) +# return concentration._standard_gamma() @_patch('torch._dirichlet_grad') diff --git 
a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 16b364a303..817e85e470 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -200,7 +200,7 @@ def gmm(data): cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() nuts_kernel = NUTS(gmm, adapt_step_size=True, max_iarange_nesting=1) - mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) + mcmc_run = MCMC(nuts_kernel, num_samples=300, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) assert_equal(posterior[1], true_cluster_means, prec=0.2) diff --git a/tests/integration_tests/test_conjugate_gaussian_models.py b/tests/integration_tests/test_conjugate_gaussian_models.py index cb53ff998f..9dd17a84c2 100644 --- a/tests/integration_tests/test_conjugate_gaussian_models.py +++ b/tests/integration_tests/test_conjugate_gaussian_models.py @@ -409,7 +409,7 @@ def guide(self, reparameterized, model_permutation, difficulty=0.0): node_suffix = node[11:] log_sig_node = pyro.param("log_sig_" + node_suffix, -0.5 * torch.log(self.target_lambdas[node_suffix]) + - difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2))) + difficulty * (torch.Tensor([-0.3]) - 0.3 * (torch.randn(1) ** 2))) mean_function_node = pyro.param("constant_term_" + node, self.loc0 + torch.Tensor([difficulty * i / n_nodes])) for dep in deps: diff --git a/tests/integration_tests/test_tracegraph_elbo.py b/tests/integration_tests/test_tracegraph_elbo.py index 83d9f1ee4a..a9776cda4a 100644 --- a/tests/integration_tests/test_tracegraph_elbo.py +++ b/tests/integration_tests/test_tracegraph_elbo.py @@ -243,9 +243,9 @@ def model(): def guide(): alpha_q_log = pyro.param("alpha_q_log", - torch.tensor(self.log_alpha_n + 0.17, requires_grad=True)) + self.log_alpha_n + 0.17) beta_q_log = pyro.param("beta_q_log", - torch.tensor(self.log_beta_n - 0.143, requires_grad=True)) + self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) p_latent = pyro.sample("p_latent", Beta(alpha_q, beta_q), infer=dict(baseline=dict(use_decaying_avg_baseline=True))) @@ -301,10 +301,10 @@ def model(): def guide(): alpha_q_log = pyro.param( "alpha_q_log", - torch.tensor(self.log_alpha_n + 0.17, requires_grad=True)) + self.log_alpha_n + 0.17) beta_q_log = pyro.param( "beta_q_log", - torch.tensor(self.log_beta_n - 0.143, requires_grad=True)) + self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q), infer=dict(baseline=dict(use_decaying_avg_baseline=True))) @@ -365,9 +365,9 @@ def model(): obs=self.data[i][j]) def guide(): - loc_q = pyro.param("loc_q", torch.tensor(self.analytic_loc_n.expand(2) + 0.234, requires_grad=True)) + loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.234) log_sig_q = pyro.param("log_sig_q", - torch.tensor(self.analytic_log_sig_n.expand(2) - 0.27, requires_grad=True)) + self.analytic_log_sig_n.expand(2) - 0.27) sig_q = torch.exp(log_sig_q) pyro.sample("loc_latent", fakes.NonreparameterizedNormal(loc_q, sig_q).independent(1), infer=dict(baseline=dict(use_decaying_avg_baseline=True))) @@ -456,7 +456,7 @@ def guide(): pt_superfluous_baselines[3 * k + i]) baseline_value = z_baseline(loc_latent.detach()) mean_i = pyro.param("mean_%d_%d" % (i, k), - torch.tensor(0.5 * torch.ones(4 
- i), requires_grad=True)) + 0.5 * torch.ones(4 - i)) z_i_k = pyro.sample("z_%d_%d" % (i, k), fakes.NonreparameterizedNormal(mean_i, 1), infer=dict(baseline=dict(baseline_value=baseline_value))) From a5a457ff82c7422b470a19ea72fd22ad1e669f91 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 14:05:13 -0700 Subject: [PATCH 082/157] uncomment torch_patch --- pyro/distributions/torch_patch.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 36c00b4753..b7b3f6938c 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -21,20 +21,20 @@ def decorator(new_fn): return decorator -# -# @_patch('torch._standard_gamma') -# def _torch_standard_gamma(concentration): -# unpatched_fn = _torch_standard_gamma._pyro_unpatched -# if concentration.is_cuda: -# return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) -# return unpatched_fn(concentration) -# -# -# @_patch('torch.distributions.gamma._standard_gamma') -# def _standard_gamma(concentration): -# if concentration.is_cuda: -# return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) -# return concentration._standard_gamma() + +@_patch('torch._standard_gamma') +def _torch_standard_gamma(concentration): + unpatched_fn = _torch_standard_gamma._pyro_unpatched + if concentration.is_cuda: + return unpatched_fn(concentration.cpu()).cuda(concentration.get_device()) + return unpatched_fn(concentration) + + +@_patch('torch.distributions.gamma._standard_gamma') +def _standard_gamma(concentration): + if concentration.is_cuda: + return concentration.cpu()._standard_gamma().cuda(concentration.get_device()) + return concentration._standard_gamma() @_patch('torch._dirichlet_grad') From 41164bbe84c65661edf33a1ddae70d1959795ad5 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 15:42:16 -0700 Subject: [PATCH 083/157] ignore jit warnings in hmc --- tests/infer/mcmc/test_hmc.py | 3 ++- tests/infer/mcmc/test_nuts.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 553dde4604..ea3be9e08a 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -287,7 +287,8 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, jit_compile=jit) + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index cad2aefe4a..1852b09176 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -234,8 +234,7 @@ def gmm(data): assert_equal(posterior[1], true_cluster_means, prec=0.2) -@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[ - pytest.mark.xfail(reason="FIXME: log not implemented for 'CPULongType'")])], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) def test_bernoulli_latent_model(jit): @poutine.broadcast def model(data): @@ -250,7 +249,8 @@ def model(data): y = dist.Bernoulli(y_prob).sample(torch.Size((N,))) z = dist.Bernoulli(0.65 * 
y + 0.1).sample() data = dist.Normal(2. * z, 1.0).sample() - nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1, jit_compile=jit) + nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=1, + jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=600, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="y_prob").mean assert_equal(posterior, y_prob, prec=0.05) From 21b32c026cabc1ccb20635a2b07f606f5d048045 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 15:44:13 -0700 Subject: [PATCH 084/157] remove default jit compilation in air --- examples/air/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/air/main.py b/examples/air/main.py index dc9ef9dcfe..d10efc276d 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -287,7 +287,7 @@ def per_param_optim_args(module_name, param_name): help='number of steps between parameter saves') parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') - parser.add_argument('--jit', action='store_true', default=True, + parser.add_argument('--jit', action='store_true', help='use PyTorch jit') parser.add_argument('-t', '--model-steps', type=int, default=3, help='number of time steps') From 865ddab98e7042f412cd09287bc47667e9f0b050 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 15:48:04 -0700 Subject: [PATCH 085/157] set args.jit default to false --- examples/air/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/air/main.py b/examples/air/main.py index d10efc276d..d1f6c0b9f7 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -287,7 +287,7 @@ def per_param_optim_args(module_name, param_name): help='number of steps between parameter saves') parser.add_argument('--cuda', action='store_true', default=False, help='use cuda') - parser.add_argument('--jit', action='store_true', + parser.add_argument('--jit', action='store_true', default=False, help='use PyTorch jit') parser.add_argument('-t', '--model-steps', type=int, default=3, help='number of time steps') From 6ddeb383dcb7f87fd84eda68b147dd09edf7689e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 16:56:34 -0700 Subject: [PATCH 086/157] ignore jit warnings in hmc tests --- tests/infer/mcmc/test_hmc.py | 2 +- tests/infer/mcmc/test_nuts.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index ea3be9e08a..ef46d25bb5 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -315,7 +315,7 @@ def gmm(data): cluster_assignments = dist.Categorical(true_mix_proportions).sample(torch.Size((N,))) data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample() hmc_kernel = HMC(gmm, trajectory_length=1, adapt_step_size=True, - max_iarange_nesting=1, jit_compile=jit) + max_iarange_nesting=1, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=300, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=["phi", "cluster_means"]).mean.sort()[0] assert_equal(posterior[0], true_mix_proportions, prec=0.05) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 1852b09176..fd43c6a198 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -99,7 +99,7 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - 
nuts_kernel = NUTS(model, step_size=0.0855, jit_compile=jit) + nuts_kernel = NUTS(model, step_size=0.0855, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) @@ -116,7 +116,7 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, step_size=0.02, jit_compile=jit) + nuts_kernel = NUTS(model, step_size=0.02, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) @@ -133,7 +133,7 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, step_size=0.01, jit_compile=jit) + nuts_kernel = NUTS(model, step_size=0.01, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) @@ -152,7 +152,7 @@ def model(data): y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) return y - nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='beta') assert_equal(rmse(true_coefs, posterior.mean).item(), 0.0, prec=0.1) @@ -169,7 +169,8 @@ def model(data): true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, + ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites="p_latent") assert_equal(posterior.mean, true_probs, prec=0.03) @@ -185,7 +186,7 @@ def model(data): true_probs = torch.tensor([0.1, 0.6, 0.3]) data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,)))) - nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_probs, prec=0.02) @@ -201,7 +202,7 @@ def model(data): true_alpha = torch.tensor(5.) true_beta = torch.tensor(1.) 
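Note the recurring recipe in these hunks: every kernel built with jit_compile=jit also receives ignore_jit_warnings=True, so that the warnings PyTorch's tracer emits while compiling the log-density cannot fail test runs that escalate warnings to errors. A minimal, self-contained sketch of the recipe — the toy model, site names, and sample sizes below are illustrative only, not taken from this diff:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import EmpiricalMarginal
    from pyro.infer.mcmc import MCMC
    from pyro.infer.mcmc.nuts import NUTS

    def toy_model(data):
        # latent success probability under a flat prior
        p = pyro.sample('p', dist.Beta(1., 1.))
        pyro.sample('obs', dist.Bernoulli(p), obs=data)

    data = dist.Bernoulli(0.7).sample(torch.Size((100,)))
    # jit_compile traces the log-density once up front;
    # ignore_jit_warnings keeps tracer warnings from aborting the run.
    kernel = NUTS(toy_model, adapt_step_size=True,
                  jit_compile=True, ignore_jit_warnings=True)
    mcmc_run = MCMC(kernel, num_samples=100, warmup_steps=50).run(data)
    posterior_mean = EmpiricalMarginal(mcmc_run, sites='p').mean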
data = dist.Beta(concentration1=true_alpha, concentration0=true_beta).sample(torch.Size((5000,))) - nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit) + nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=200).run(data) posterior = EmpiricalMarginal(mcmc_run, sites=['alpha', 'beta']) assert_equal(posterior.mean, torch.stack([true_alpha, true_beta]), prec=0.05) From 43c5e4e06e4762d0d02f6e1f29c6afaef3bf324d Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 26 Sep 2018 17:33:12 -0700 Subject: [PATCH 087/157] mark failing hmc tests --- tests/infer/mcmc/test_hmc.py | 4 +++- tests/infer/mcmc/test_nuts.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index ef46d25bb5..f3bf152425 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -221,7 +221,9 @@ def model(data): assert_equal(posterior.mean, true_std, prec=0.05) -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, + mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + ], ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index fd43c6a198..b9c9ffad85 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -176,7 +176,9 @@ def model(data): assert_equal(posterior.mean, true_probs, prec=0.03) -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, + mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + ], ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) From 68171b999568e504a7fe21811414f5c3fe1fd643 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 10:39:50 -0700 Subject: [PATCH 088/157] test against nightly build --- .travis.yml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5e58a95d8b..8cf3c05962 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,10 +9,19 @@ env: install: - pip install -U pip - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then - pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; + wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; else - pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; + wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; fi + - bash miniconda.sh -b -p $HOME/miniconda + - export PATH="$HOME/miniconda/bin:$PATH" + - hash -r + - conda config --set always_yes yes --set changeps1 no + - conda update -q conda + - conda info -a + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION + - source activate test-environment + - conda install -c pytorch-nightly pytorch - pip install .[test] - pip freeze From f4db7126cd623a8d4e34bf10ca85ca4a8e29946d Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 10:44:48 -0700 Subject: [PATCH 089/157] fix channel name --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 8cf3c05962..a31ec3bf85 100644 --- a/.travis.yml +++ 
b/.travis.yml @@ -21,7 +21,7 @@ install: - conda info -a - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION - source activate test-environment - - conda install -c pytorch-nightly pytorch + - conda install -c pytorch pytorch-nightly - pip install .[test] - pip freeze From fb290e8ebcc8e5813980089f79fd469434531352 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 11:24:11 -0700 Subject: [PATCH 090/157] downgrade ipython --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 32360e56cc..cd29d1139a 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,8 @@ 'pytest==3.7', 'pytest-cov', 'scipy>=0.19.0', + # https://github.com/ipython/ipython/issues/11335 + 'ipython==7.00rc1', ], 'profile': ['prettytable', 'pytest-benchmark', 'snakeviz'], 'dev': EXTRAS_REQUIRE + [ From ed4b360d5eda981ee87179daf064b7f13c21d68a Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 11:47:05 -0700 Subject: [PATCH 091/157] fix lapack issue --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index a31ec3bf85..b04ad0a0d8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ install: - conda config --set always_yes yes --set changeps1 no - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy - source activate test-environment - conda install -c pytorch pytorch-nightly - pip install .[test] From 214258987dab4a4815937851b59c08be45337bf6 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 12:23:42 -0700 Subject: [PATCH 092/157] include mkl --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index b04ad0a0d8..1843ee4baa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ install: - conda config --set always_yes yes --set changeps1 no - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include - source activate test-environment - conda install -c pytorch pytorch-nightly - pip install .[test] From 1efd552dc823899e847584755b64cb9ac6da7798 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 12:37:39 -0700 Subject: [PATCH 093/157] addons to .travis --- .travis.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1843ee4baa..418cff1110 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,8 +6,16 @@ env: global: - PYTHONPATH=$PWD:$PYTHONPATH +addons: + apt: + packages: + - libblas-dev + - liblapack-dev + - gfortran + install: - pip install -U pip + - sudo apt-get update - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; else @@ -16,12 +24,11 @@ install: - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r - - conda config --set always_yes yes --set changeps1 no + - conda config --set always_yes yes --set changeps1 no --add channels pytorch - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include pytorch-nightly - source activate test-environment - - conda 
install -c pytorch pytorch-nightly - pip install .[test] - pip freeze From b84122f90a2504411a32b127ffa72e3aef2f2843 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 12:44:26 -0700 Subject: [PATCH 094/157] add pytorch channel --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 418cff1110..f788412987 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,7 +24,8 @@ install: - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r - - conda config --set always_yes yes --set changeps1 no --add channels pytorch + - conda config --set always_yes yes --set changeps1 no + - conda config --add channels pytorch - conda update -q conda - conda info -a - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include pytorch-nightly From 85342efbd694af73a7d179d503699d3c9716857f Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 12:51:09 -0700 Subject: [PATCH 095/157] remove pythonpath --- .travis.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index f788412987..1453cc4f55 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,10 +2,6 @@ language: python sudo: true -env: - global: - - PYTHONPATH=$PWD:$PYTHONPATH - addons: apt: packages: From 464fe68bdc50ff6be2c53d9e5e01574a08507157 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 12:57:34 -0700 Subject: [PATCH 096/157] editable install --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1453cc4f55..b329ccaae1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ install: - conda info -a - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include pytorch-nightly - source activate test-environment - - pip install .[test] + - pip install -e .[test] - pip freeze branches: From 475f483cb9bd3206e2bd63842583013bd3ead0da Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 13:37:52 -0700 Subject: [PATCH 097/157] add ld_library_path --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index b329ccaae1..04321faa06 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ addons: install: - pip install -U pip - sudo apt-get update + - export LD_LIBRARY_PATH=/usr/lib:$LD_LIBRARY_PATH - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; else From 07e20d4ebcb1778d058e2beffef2bb4c1dc7aedb Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 14:04:24 -0700 Subject: [PATCH 098/157] conda install pip --- .travis.yml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 04321faa06..771933be96 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,13 +2,6 @@ language: python sudo: true -addons: - apt: - packages: - - libblas-dev - - liblapack-dev - - gfortran - install: - pip install -U pip - sudo apt-get update @@ -25,7 +18,7 @@ install: - conda config --add channels pytorch - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy mkl mkl-include pytorch-nightly + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch-nightly pip - source activate test-environment - pip install -e .[test] - pip freeze From a624b5775438ab7d988324f74cc3f7e439d1762b Mon Sep 17 00:00:00 2001 From: Neeraj 
Pradhan Date: Thu, 27 Sep 2018 14:14:53 -0700 Subject: [PATCH 099/157] debug build --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 771933be96..3055aecf59 100644 --- a/.travis.yml +++ b/.travis.yml @@ -43,6 +43,8 @@ jobs: python: 3.5 env: STAGE=docs script: + - which pip + - which python - pip install -r docs/requirements.txt - make docs - make doctest From 053163f75d5ec2b6c33af123c5b282883e879e4c Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 14:26:25 -0700 Subject: [PATCH 100/157] debug - revert to pytorch release --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 3055aecf59..f9864ccf59 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ install: - conda config --add channels pytorch - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch-nightly pip + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch pip - source activate test-environment - pip install -e .[test] - pip freeze From fb62e69d1198f702c99cf3d5114e789b6fe9dda5 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 14:43:43 -0700 Subject: [PATCH 101/157] add before install --- .travis.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index f9864ccf59..2999564ff6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,9 @@ language: python sudo: true +before_install: + - sudo apt-get install libblas-dev liblapack-dev + install: - pip install -U pip - sudo apt-get update @@ -18,7 +21,7 @@ install: - conda config --add channels pytorch - conda update -q conda - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch pip + - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch-nightly pip - source activate test-environment - pip install -e .[test] - pip freeze From 13be21f6ad1aea3a9cb566400d274bd56f2bdcc1 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 16:09:43 -0700 Subject: [PATCH 102/157] use nightly wheel --- .travis.yml | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2999564ff6..253ba6c9da 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,28 +2,14 @@ language: python sudo: true -before_install: - - sudo apt-get install libblas-dev liblapack-dev +env: + global: + - PYTHONPATH=$PWD:$PYTHONPATH install: - pip install -U pip - - sudo apt-get update - - export LD_LIBRARY_PATH=/usr/lib:$LD_LIBRARY_PATH - - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then - wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; - else - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; - fi - - bash miniconda.sh -b -p $HOME/miniconda - - export PATH="$HOME/miniconda/bin:$PATH" - - hash -r - - conda config --set always_yes yes --set changeps1 no - - conda config --add channels pytorch - - conda update -q conda - - conda info -a - - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch-nightly pip - - source activate test-environment - - pip install -e .[test] + - pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install .[test] - pip freeze branches: @@ -46,8 +32,6 @@ jobs: python: 3.5 env: STAGE=docs script: - - which pip - - which python - pip install -r 
docs/requirements.txt - make docs - make doctest From 31b5d636f95e2b47ccde0ca41a3a7128747bd96e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 16:16:22 -0700 Subject: [PATCH 103/157] Fix incompatible dependency between jupyter-console and ipython --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index efb404960f..50ec0766dc 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,7 @@ 'pytest==3.7', 'pytest-cov', 'scipy>=0.19.0', + 'ipython<=6.5.0', # https://github.com/jupyter/jupyter_console/issues/158 ], 'profile': ['prettytable', 'pytest-benchmark', 'snakeviz'], 'dev': EXTRAS_REQUIRE + [ @@ -105,6 +106,7 @@ 'pypandoc', 'pytest==3.7', 'pytest-xdist', + 'ipython<=6.5.0', # https://github.com/jupyter/jupyter_console/issues/158 'scipy>=0.19.0', 'sphinx', 'sphinx_rtd_theme', From f904d0ca61f5010b4a77c0bdb3184179b71d53f2 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 16:32:36 -0700 Subject: [PATCH 104/157] remove torch==0.4.1 from setup --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d9ede9ed49..5f2c8b51f9 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,8 @@ 'numpy>=1.7', 'opt_einsum>=2.2.0', 'six>=1.10.0', - 'torch>=0.4.1', + # TODO: uncomment on release; using torch-nightly build + # 'torch>=0.4.1', 'tqdm>=4.25', ], extras_require={ From 13d7b51769b6a370bd16ec92504cb724675345aa Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 16:40:27 -0700 Subject: [PATCH 105/157] remove torchvision temporarily --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5f2c8b51f9..d8f51f3261 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,8 @@ 'matplotlib>=1.3', 'observations>=0.1.4', 'pillow', - 'torchvision', + # TODO: uncomment on release; using torch-nightly build + # 'torchvision', 'visdom>=0.1.4', 'pandas', 'wget', From 0e17d3cee0b127d94689c7624e6e62c25b94ffe7 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 17:24:57 -0700 Subject: [PATCH 106/157] install torchvision without deps --- .travis.yml | 1 + setup.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 253ba6c9da..df8943a823 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,6 +9,7 @@ env: install: - pip install -U pip - pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torchvision --no-dependencies - pip install .[test] - pip freeze diff --git a/setup.py b/setup.py index d8f51f3261..5f2c8b51f9 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,7 @@ 'matplotlib>=1.3', 'observations>=0.1.4', 'pillow', - # TODO: uncomment on release; using torch-nightly build - # 'torchvision', + 'torchvision', 'visdom>=0.1.4', 'pandas', 'wget', From 92d6d4deb322174d7f727dda65690e0086551bda Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 17:32:57 -0700 Subject: [PATCH 107/157] remove torchvision from setup --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5f2c8b51f9..d8f51f3261 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,8 @@ 'matplotlib>=1.3', 'observations>=0.1.4', 'pillow', - 'torchvision', + # TODO: uncomment on release; using torch-nightly build + # 'torchvision', 'visdom>=0.1.4', 'pandas', 'wget', From 72c76fda4aa0e16fe7c90ba1c17ed8e05f8f435e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 19:33:34 -0700 Subject: [PATCH 108/157] 
update to contextlib2 --- pyro/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/util.py b/pyro/util.py index 6c8a94ff5a..1c6a7722fc 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -5,7 +5,7 @@ import random import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib2 import contextmanager import graphviz import torch From 0e60ce258d10c4eaca642ef7159d40e6978092ca Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 19:47:01 -0700 Subject: [PATCH 109/157] fix benchmark tests --- tests/perf/test_benchmark.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 7b2f60f744..2a969a98f3 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -66,18 +66,8 @@ def model(): return lambda_latent def guide(): - alpha_q_log = pyro.param( - "alpha_q_log", - torch.tensor( - log_alpha_n.data + - 0.17, - requires_grad=True)) - beta_q_log = pyro.param( - "beta_q_log", - torch.tensor( - log_beta_n.data - - 0.143, - requires_grad=True)) + alpha_q_log = pyro.param("alpha_q_log", log_alpha_n + 0.17) + beta_q_log = pyro.param("beta_q_log", log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q)) From b328d12a200e036a38b2e237a0c65cda1f95a625 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 22:05:50 -0700 Subject: [PATCH 110/157] add xfail markers for failing tests --- tests/contrib/oed/test_eig.py | 9 +++++---- tests/contrib/oed/test_ewma.py | 1 + tests/infer/test_enum.py | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/contrib/oed/test_eig.py b/tests/contrib/oed/test_eig.py index 922283a914..cd2df30061 100644 --- a/tests/contrib/oed/test_eig.py +++ b/tests/contrib/oed/test_eig.py @@ -21,6 +21,7 @@ ) from pyro.contrib.oed.util import linear_model_ground_truth from pyro.infer import TraceEnum_ELBO +from tests.common import xfail_param logger = logging.getLogger(__name__) @@ -147,7 +148,7 @@ def h(p): False, 0.3 ), - T( + xfail_param(*T( basic_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, "y", @@ -157,7 +158,7 @@ def h(p): optim.Adam({"lr": 0.025}), False, None, 500], True, 0.3 - ), + ), reason="https://github.com/uber/pyro/issues/1418"), T( basic_2p_linear_model_sds_10_2pt5, AB_test_2d_10n_2p, @@ -203,7 +204,7 @@ def h(p): 0.3, marks=pytest.mark.xfail ), - T( + xfail_param(*T( group_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, "y", @@ -213,7 +214,7 @@ def h(p): optim.Adam({"lr": 0.025}), False, None, 500], True, 0.3 - ), + ), reason="https://github.com/uber/pyro/issues/1418"), T( group_2p_linear_model_sds_10_2pt5, X_circle_5d_1n_2p, diff --git a/tests/contrib/oed/test_ewma.py b/tests/contrib/oed/test_ewma.py index ef606e22e7..514c94ddad 100644 --- a/tests/contrib/oed/test_ewma.py +++ b/tests/contrib/oed/test_ewma.py @@ -8,6 +8,7 @@ @pytest.mark.parametrize("alpha", [0.5, 0.9, 0.99]) +@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418") def test_ewma(alpha, NS=10000, D=1): ewma_log = EwmaLog(alpha=alpha) sigma = torch.tensor(1.0, requires_grad=True) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index e2e704f14a..c195ab2e49 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -2925,6 +2925,7 @@ def guide(data): (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), ]) 
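# pytest note: the marker added on the next line passes the issue URL as a
# bare positional argument, which pytest.mark.xfail interprets as a
# *condition* string to evaluate rather than a reason. Patch 112 below
# ("fix xfail marker") rewrites it into the keyword form pytest expects:
#
#     @pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418")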
+@pytest.mark.xfail("https://github.com/uber/pyro/issues/1418") def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) From 3161deb5ae62757421409a42ae53c94056938c89 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 27 Sep 2018 23:48:06 -0700 Subject: [PATCH 111/157] temporarily xfail ubersum_sizes test --- tests/ops/test_contract.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ops/test_contract.py b/tests/ops/test_contract.py index 7b4609b071..6b2389ff3b 100644 --- a/tests/ops/test_contract.py +++ b/tests/ops/test_contract.py @@ -288,6 +288,7 @@ def test_ubersum(equation, batch_dims): @pytest.mark.parametrize('b', [3, 1]) @pytest.mark.parametrize('c', [3, 1]) @pytest.mark.parametrize('d', [4, 1]) +@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418") def test_ubersum_sizes(a, b, c, d): X = torch.randn(a, b) Y = torch.randn(b, c) From e298cb4b81ffe86a6f4a6a1bedad5bd1c010f530 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 28 Sep 2018 11:02:31 -0700 Subject: [PATCH 112/157] fix xfail marker --- tests/infer/test_enum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index c195ab2e49..fecd9f7bc0 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -2925,7 +2925,7 @@ def guide(data): (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), ]) -@pytest.mark.xfail("https://github.com/uber/pyro/issues/1418") +@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418") def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) From feca15d8d5208c7877963f91c56996645dbfa24a Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Fri, 28 Sep 2018 16:53:10 -0700 Subject: [PATCH 113/157] remove xfail marker from test_enum --- tests/infer/test_enum.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 1accc6ecc0..6fb6c5091c 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -2925,7 +2925,6 @@ def guide(data): (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), ]) -@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1418") def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) From a7300e81d62a74b274d968289d70ee10720f6d85 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Sun, 30 Sep 2018 20:39:01 -0700 Subject: [PATCH 114/157] add xfail for mixture of diag normals --- tests/infer/test_enum.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index 6fb6c5091c..7cc736ee77 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -2925,6 +2925,7 @@ def guide(data): (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), ]) +@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1425") def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) From 379ffef56356ed0b68973f045404c6d2dc4fdf59 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 1 Oct 2018 20:33:36 -0700 Subject: [PATCH 115/157] fix mask fill on non contiguous tensor --- 
pyro/distributions/util.py | 3 ++- tests/infer/test_enum.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index fa8576bc60..adc9509aeb 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -176,7 +176,8 @@ def scale_and_mask(tensor, scale=1.0, mask=None): if mask is None: return tensor * scale tensor, mask = broadcast_all(tensor, mask) - tensor = tensor * scale # triggers a copy, avoiding in-place op errors + # TODO: Remove .contiguous once https://github.com/pytorch/pytorch/issues/12230 is fixed. + tensor = (tensor * scale).contiguous() if torch._C._get_tracing_state(): tensor[~mask] = 0. else: diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index c8a1dcb3fe..feb73d8b57 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -2929,7 +2929,6 @@ def guide(data): (dist.MixtureOfDiagNormals, [[2., 1.], [1., 2], [4., 4.]]), (dist.MixtureOfDiagNormalsSharedCovariance, [2., 1.]), ]) -@pytest.mark.xfail(reason="https://github.com/uber/pyro/issues/1425") def test_mixture_of_diag_normals(mixture, scale): # K = 3, D = 2 pyro.param("locs", torch.tensor([[0., 0.], [0., 1.], [0., 10.]])) From f948b7c32d92582211e17fbf78a65258c825cb3b Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Sun, 21 Oct 2018 17:50:54 -0700 Subject: [PATCH 116/157] fix imports --- pyro/infer/mcmc/hmc.py | 5 ++--- pyro/infer/mcmc/nuts.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 91029ffa7a..28927dd209 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -18,7 +18,6 @@ from pyro.ops.integrator import single_step_velocity_verlet, velocity_verlet from pyro.ops.welford import WelfordCovariance from pyro.poutine.subsample_messenger import _Subsample -from pyro.primitives import _Subsample from pyro.util import optional, torch_isinf, torch_isnan, ignore_jit_warnings @@ -102,8 +101,8 @@ def __init__(self, transforms=None, max_plate_nesting=float("inf"), max_iarange_nesting=None, # DEPRECATED - jit_compile = False, - ignore_jit_warnings = False, + jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): self.model = model if max_iarange_nesting is not None: diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index 8661eeeba6..2a8a3e9edb 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -107,8 +107,8 @@ def __init__(self, transforms=None, max_plate_nesting=float("inf"), max_iarange_nesting=None, # DEPRECATED - jit_compile = False, - ignore_jit_warnings = False, + jit_compile=False, + ignore_jit_warnings=False, experimental_use_einsum=False): if max_iarange_nesting is not None: warnings.warn("max_iarange_nesting is deprecated; use max_plate_nesting instead", @@ -121,7 +121,7 @@ def __init__(self, adapt_mass_matrix=adapt_mass_matrix, full_mass=full_mass, transforms=transforms, - max_plate_nesting = max_plate_nesting, + max_plate_nesting=max_plate_nesting, max_iarange_nesting=max_iarange_nesting, jit_compile=jit_compile, ignore_jit_warnings=ignore_jit_warnings, From 92305995aa40a92b9e0db08fcd402c3fe9a5993f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Oct 2018 13:56:20 -0400 Subject: [PATCH 117/157] Fix jit arg error in hmm example (#1445) --- examples/hmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/hmm.py b/examples/hmm.py index 33404cb1f9..3104ba1d25 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ 
-253,7 +253,7 @@ def main(args): # We'll train on small minibatches. logging.info('Step\tLoss') for step in range(args.num_steps): - loss = svi.step(sequences, lengths, args, batch_size=args.batch_size) + loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) logging.info('{: >5d}\t{}'.format(step, loss / num_observations)) # We evaluate on the entire training dataset, @@ -269,7 +269,7 @@ def main(args): if args.truncate: lengths.clamp_(max=args.truncate) num_observations = float(lengths.sum()) - test_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False) + test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) logging.info('test loss = {}'.format(test_loss / num_observations)) # We expect models with higher capacity to perform better, From 75ab233056c7b87aa65c887f0db89efc0dfcfb49 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 22 Oct 2018 10:53:13 -0700 Subject: [PATCH 118/157] Revert change to broadcast messenger --- pyro/poutine/broadcast_messenger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 4e238e5b73..62247aaaf3 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -12,8 +12,9 @@ class BroadcastMessenger(Messenger): broadcastable with the size of the :class:`~pyro.plate` contexts installed in the `cond_indep_stack`. """ + @staticmethod @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) - def _pyro_sample(self, msg): + def _pyro_sample(msg): """ :param msg: current message at a trace site. """ From 16ad65d3522e7fdb4d1b3dec0595b4bd0fb47fb6 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 22 Oct 2018 11:01:22 -0700 Subject: [PATCH 119/157] fix parametrize in test_nuts --- tests/infer/mcmc/test_nuts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index b887aefd68..948fa07a30 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -125,7 +125,6 @@ def test_nuts_conjugate_gaussian(fixture, assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol) -@pytest.mark.parametrize() @pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) @pytest.mark.parametrize("use_multinomial_sampling", [True, False]) def test_logistic_regression(jit, use_multinomial_sampling): From 776afcd8c60e23fa39ed6b5218c907d641dfcb51 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 22 Oct 2018 17:45:37 -0700 Subject: [PATCH 120/157] Fix tests --- pyro/poutine/indep_messenger.py | 12 ++++++------ pyro/poutine/subsample_messenger.py | 4 +++- tests/infer/test_jit.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 127f829633..e3a65b5a8e 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -53,7 +53,7 @@ class IndepMessenger(Messenger): """ def __init__(self, name=None, size=None, dim=None, device=None): - if size == 0: + if not torch._C._get_tracing_state() and size == 0: raise ZeroDivisionError("size cannot be zero") super(IndepMessenger, self).__init__() @@ -96,11 +96,11 @@ def __iter__(self): self._vectorized = False self.dim = None - - for i in self.indices: - self.next_context() - with self: - yield i if isinstance(i, numbers.Number) else i.item() + with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): + 
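# Iterating over a plain tensor trips the tracer warning named above because
# the loop is unrolled to a fixed trip count at trace time; it is silenced
# here, presumably because the subsample indices are intended to be baked
# into the trace as constants anyway.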
for i in self.indices: + self.next_context() + with self: + yield i if isinstance(i, numbers.Number) else i.item() def _reset(self): if self._vectorized: diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 2437ac3798..761fe50bd3 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -3,6 +3,7 @@ import torch from pyro.distributions.distribution import Distribution +from pyro.util import ignore_jit_warnings from .indep_messenger import CondIndepStackFrame, IndepMessenger from .runtime import apply_stack @@ -31,7 +32,8 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None): if self.use_cuda ^ (device != "cpu"): raise ValueError("Incompatible arg values use_cuda={}, device={}." .format(use_cuda, device)) - self.device = torch.Tensor().device if not device else device + with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): + self.device = torch.Tensor().device if not device else device def sample(self, sample_shape=torch.Size()): """ diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 64d421c9d0..8c426d6dec 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -267,7 +267,7 @@ def guide(data): @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential", "parallel"]) @pytest.mark.parametrize("irange_dim", [1, 2]) -@pytest.mark.parametrize('Elbo', [TraceEnum_ELBO, JitTraceEnum_ELBO]) +@pytest.mark.parametrize('Elbo', [JitTraceEnum_ELBO]) def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2): pyro.clear_param_store() num_particles = 10 From fdb9d20974eb239e40f36cd78b748bdd78707b30 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 22 Oct 2018 20:24:02 -0700 Subject: [PATCH 121/157] stash --- pyro/poutine/subsample_messenger.py | 13 ++++++------- tests/infer/test_jit.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 761fe50bd3..87bf2bb726 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -3,7 +3,7 @@ import torch from pyro.distributions.distribution import Distribution -from pyro.util import ignore_jit_warnings +from pyro.util import ignore_jit_warnings, jit_compatible_arange from .indep_messenger import CondIndepStackFrame, IndepMessenger from .runtime import apply_stack @@ -35,6 +35,7 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None): with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): self.device = torch.Tensor().device if not device else device + @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) def sample(self, sample_shape=torch.Size()): """ :returns: a random subsample of `range(size)` @@ -43,10 +44,8 @@ def sample(self, sample_shape=torch.Size()): if sample_shape: raise NotImplementedError subsample_size = self.subsample_size - if subsample_size is None or subsample_size > self.size: - subsample_size = self.size - if subsample_size >= self.size: - result = torch.arange(self.size, dtype=torch.long).to(self.device) + if subsample_size is None or subsample_size >= self.size: + result = jit_compatible_arange(self.size, device=self.device) else: result = torch.multinomial(torch.ones(self.size), self.subsample_size, replacement=False).to(self.device) @@ -107,10 +106,10 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No subsample = 
msg["value"] if subsample_size is None: - subsample_size = len(subsample) + subsample_size = subsample.size(0) elif subsample is not None and subsample_size != len(subsample): raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, len(subsample)) + + subsample_size, subsample.size(0)) + " Did you accidentally use different subsample_size in the model and guide?") return size, subsample_size, subsample diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 8c426d6dec..951ab8d545 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -360,7 +360,7 @@ def test_svi_irregular_batch_size(Elbo): def model(data): loc = pyro.param("loc", constant(0.0)) scale = pyro.param("scale", constant(1.0), constraint=constraints.positive) - with pyro.iarange("data", data.shape[0]): + with pyro.plate("data", data.shape[0]): pyro.sample("x", dist.Normal(loc, scale).expand([data.shape[0]]), obs=data) From f51076cb598af0bb2912e8f0cc53e1db017afaee Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 22 Oct 2018 22:24:41 -0700 Subject: [PATCH 122/157] Fix JIT tests --- pyro/poutine/broadcast_messenger.py | 2 +- pyro/poutine/subsample_messenger.py | 13 +++++++------ tests/infer/mcmc/test_hmc.py | 3 ++- tests/infer/mcmc/test_nuts.py | 6 ++++-- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 62247aaaf3..9a2d973341 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -31,7 +31,7 @@ def _pyro_sample(msg): continue assert f.dim < 0 target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape - if target_batch_shape[f.dim] not in (None, f.size): + if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size: raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format( f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim])) target_batch_shape[f.dim] = f.size diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 87bf2bb726..3d572b11d4 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -105,12 +105,13 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No apply_stack(msg) subsample = msg["value"] - if subsample_size is None: - subsample_size = subsample.size(0) - elif subsample is not None and subsample_size != len(subsample): - raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, subsample.size(0)) + - " Did you accidentally use different subsample_size in the model and guide?") + with ignore_jit_warnings(): + if subsample_size is None: + subsample_size = subsample.size(0) + elif subsample is not None and subsample_size != len(subsample): + raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( + subsample_size, subsample.size(0)) + + " Did you accidentally use different subsample_size in the model and guide?") return size, subsample_size, subsample diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 54f5992eb3..38f05ba616 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -237,7 +237,8 @@ def model(data): @pytest.mark.parametrize("jit", [False, - mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + mark_jit(True, marks=[pytest.mark.xfail( + 
reason="https://github.com/uber/pyro/issues/1418")]) ], ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index 948fa07a30..20617229ff 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -174,7 +174,8 @@ def model(data): assert_equal(posterior.mean, true_probs, prec=0.02) -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[pytest.mark.skip("Doesn't finish")])], + ids=jit_idfn) @pytest.mark.parametrize("use_multinomial_sampling", [True, False]) def test_gamma_normal(jit, use_multinomial_sampling): def model(data): @@ -233,7 +234,8 @@ def model(data): @pytest.mark.parametrize("jit", [False, - mark_jit(True, marks=[pytest.mark.xfail("https://github.com/uber/pyro/issues/1418")]) + mark_jit(True, marks=[pytest.mark.xfail( + reason="https://github.com/uber/pyro/issues/1418")]) ], ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): From 7200fc3ce88df57a61cd3df3e39c0ba1fc989ce7 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 23 Oct 2018 10:24:04 -0700 Subject: [PATCH 123/157] fix test_mapdata --- pyro/poutine/subsample_messenger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 3d572b11d4..1931d9c77e 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -107,10 +107,11 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=No with ignore_jit_warnings(): if subsample_size is None: - subsample_size = subsample.size(0) + subsample_size = subsample.size(0) if isinstance(subsample, torch.Tensor) \ + else len(subsample) elif subsample is not None and subsample_size != len(subsample): raise ValueError("subsample_size does not match len(subsample), {} vs {}.".format( - subsample_size, subsample.size(0)) + + subsample_size, len(subsample)) + " Did you accidentally use different subsample_size in the model and guide?") return size, subsample_size, subsample From a599587d9fabd8b8a62905e00b236cfa40dec09a Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 12 Nov 2018 16:59:56 -0800 Subject: [PATCH 124/157] Change torch.potrf usage to torch.cholesky (#1529) --- pyro/contrib/autoguide/__init__.py | 2 +- pyro/contrib/gp/models/gpr.py | 6 +++--- pyro/contrib/gp/models/sgpr.py | 6 +++--- pyro/contrib/gp/models/vgp.py | 2 +- pyro/contrib/gp/models/vsgp.py | 2 +- pyro/contrib/gp/util.py | 2 +- tests/contrib/gp/test_conditional.py | 4 ++-- tests/perf/test_benchmark.py | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyro/contrib/autoguide/__init__.py b/pyro/contrib/autoguide/__init__.py index 7e16645be3..116c94c2e7 100644 --- a/pyro/contrib/autoguide/__init__.py +++ b/pyro/contrib/autoguide/__init__.py @@ -639,7 +639,7 @@ def laplace_approximation(self, *args, **kwargs): loc = pyro.param("{}_loc".format(self.prefix)) H = hessian(loss, loc.unconstrained()) cov = H.inverse() - scale_tril = cov.potrf(upper=False) + scale_tril = cov.cholesky() # calculate scale_tril from self.guide() scale_tril_name = "{}_scale_tril".format(self.prefix) diff --git a/pyro/contrib/gp/models/gpr.py b/pyro/contrib/gp/models/gpr.py index 8b7ea0c758..263cea55f7 100644 --- a/pyro/contrib/gp/models/gpr.py +++ b/pyro/contrib/gp/models/gpr.py @@ -81,7 +81,7 @@ def model(self): N = self.X.shape[0] Kff = self.kernel(self.X) 
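# Migration note: torch.potrf is deprecated in the PyTorch 1.0 nightlies this
# branch now builds against, and torch.cholesky is its replacement. potrf
# returned the upper-triangular factor by default, while cholesky returns the
# lower one, so each potrf(upper=False) call in this patch maps onto a bare
# cholesky(). A quick equivalence check (illustrative, not part of this diff):
#
#     K = torch.eye(3) + 1e-6          # any symmetric positive-definite matrix
#     L = K.cholesky()                 # same factor as old K.potrf(upper=False)
#     assert torch.allclose(L.mm(L.t()), K, atol=1e-5)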
Kff.view(-1)[::N + 1] += noise # add noise to diagonal - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() zero_loc = self.X.new_zeros(self.X.shape[0]) f_loc = zero_loc + self.mean_function(self.X) @@ -129,7 +129,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): N = self.X.shape[0] Kff = self.kernel(self.X).contiguous() Kff.view(-1)[::N + 1] += noise # add noise to the diagonal - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() y_residual = self.y - self.mean_function(self.X) loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff, @@ -185,7 +185,7 @@ def sample_next(xnew, outside_vars): X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"] # Compute Cholesky decomposition of kernel matrix - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() y_residual = y - self.mean_function(X) # Compute conditional mean and variance diff --git a/pyro/contrib/gp/models/sgpr.py b/pyro/contrib/gp/models/sgpr.py index 837bb470da..27d4eb93c3 100644 --- a/pyro/contrib/gp/models/sgpr.py +++ b/pyro/contrib/gp/models/sgpr.py @@ -130,7 +130,7 @@ def model(self): M = Xu.shape[0] Kuu = self.kernel(Xu).contiguous() Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal - Luu = Kuu.potrf(upper=False) + Luu = Kuu.cholesky() Kuf = self.kernel(Xu, self.X) W = Kuf.trtrs(Luu, upper=False)[0].t() @@ -210,7 +210,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): Kuu = self.kernel(Xu).contiguous() Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal - Luu = Kuu.potrf(upper=False) + Luu = Kuu.cholesky() Kus = self.kernel(Xu, Xnew) Kuf = self.kernel(Xu, self.X) @@ -225,7 +225,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True): W_Dinv = W / D K = W_Dinv.matmul(W.t()).contiguous() K.view(-1)[::M + 1] += 1 # add identity matrix to K - L = K.potrf(upper=False) + L = K.cholesky() # get y_residual and convert it into 2D tensor for packing y_residual = self.y - self.mean_function(self.X) diff --git a/pyro/contrib/gp/models/vgp.py b/pyro/contrib/gp/models/vgp.py index 4a8f8b0a23..74d20c263f 100644 --- a/pyro/contrib/gp/models/vgp.py +++ b/pyro/contrib/gp/models/vgp.py @@ -91,7 +91,7 @@ def model(self): N = self.X.shape[0] Kff = self.kernel(self.X).contiguous() Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() zero_loc = self.X.new_zeros(f_loc.shape) f_name = param_with_module_name(self.name, "f") diff --git a/pyro/contrib/gp/models/vsgp.py b/pyro/contrib/gp/models/vsgp.py index 2d6586bad6..8e95f2d09d 100644 --- a/pyro/contrib/gp/models/vsgp.py +++ b/pyro/contrib/gp/models/vsgp.py @@ -116,7 +116,7 @@ def model(self): M = Xu.shape[0] Kuu = self.kernel(Xu).contiguous() Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal - Luu = Kuu.potrf(upper=False) + Luu = Kuu.cholesky() zero_loc = Xu.new_zeros(u_loc.shape) u_name = param_with_module_name(self.name, "u") diff --git a/pyro/contrib/gp/util.py b/pyro/contrib/gp/util.py index 487e6d824f..24a942e847 100644 --- a/pyro/contrib/gp/util.py +++ b/pyro/contrib/gp/util.py @@ -212,7 +212,7 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa if Lff is None: Kff = kernel(X).contiguous() Kff.view(-1)[::N + 1] += jitter # add jitter to diagonal - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() Kfs = kernel(X, Xnew) # convert f_loc_shape from latent_shape x N to N x latent_shape diff --git a/tests/contrib/gp/test_conditional.py b/tests/contrib/gp/test_conditional.py index 
b8481ce0fd..6ab03abced 100644 --- a/tests/contrib/gp/test_conditional.py +++ b/tests/contrib/gp/test_conditional.py @@ -17,7 +17,7 @@ X = torch.tensor([[1., 5.], [2., 1.], [3., 2.]]) kernel = Matern52(input_dim=2) Kff = kernel(X) + torch.eye(3) * 1e-6 -Lff = Kff.potrf(upper=False) +Lff = Kff.cholesky() pyro.set_rng_seed(123) f_loc = torch.rand(3) f_scale_tril = torch.rand(3, 3).tril(-1) + torch.rand(3).exp().diag() @@ -75,7 +75,7 @@ def test_conditional_whiten(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov): loc0, cov0 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True, whiten=False) Kff = kernel(X) + torch.eye(3) * 1e-6 - Lff = Kff.potrf(upper=False) + Lff = Kff.cholesky() whiten_f_loc = Lff.inverse().matmul(f_loc) whiten_f_scale_tril = Lff.inverse().matmul(f_scale_tril) loc1, cov1 = conditional(Xnew, X, kernel, whiten_f_loc, whiten_f_scale_tril, diff --git a/tests/perf/test_benchmark.py b/tests/perf/test_benchmark.py index 5477d7324e..7daa789c62 100644 --- a/tests/perf/test_benchmark.py +++ b/tests/perf/test_benchmark.py @@ -104,7 +104,7 @@ def svgp_multiclass(num_steps, whiten): pyro.set_rng_seed(0) X = torch.rand(100, 1) K = (-0.5 * (X - X.t()).pow(2) / 0.01).exp() + torch.eye(100) * 1e-6 - f = K.potrf(upper=False).matmul(torch.randn(100, 3)) + f = K.cholesky().matmul(torch.randn(100, 3)) y = f.argmax(dim=-1) kernel = gp.kernels.Matern32(1).add( From 6be5d06429e7da981fff22771f2f9e9e359b40af Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 12 Nov 2018 21:53:11 -0800 Subject: [PATCH 125/157] update test_beta_bernoulli to use pyro.plate --- tests/infer/mcmc/test_hmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 38f05ba616..b0be8df0a2 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -281,13 +281,13 @@ def model(data): alpha = torch.tensor([1.1, 1.1]) beta = torch.tensor([1.1, 1.1]) p_latent = pyro.sample('p_latent', dist.Beta(alpha, beta)) - with pyro.iarange("data", data.shape[0], dim=-2): + with pyro.plate("data", data.shape[0], dim=-2): pyro.sample('obs', dist.Bernoulli(p_latent), obs=data) return p_latent true_probs = torch.tensor([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,)))) - hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_iarange_nesting=2, + hmc_kernel = HMC(model, trajectory_length=1, adapt_step_size=True, max_plate_nesting=2, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') From c228cb603a7b0a4164639eb18aeb5c7e3e36b762 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 13 Nov 2018 16:24:34 -0800 Subject: [PATCH 126/157] log example output while running --- tests/test_examples.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 2daf37eaee..56c0bbc666 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -3,7 +3,7 @@ import logging import os import sys -from subprocess import check_call +from subprocess import check_output import pytest @@ -126,7 +126,7 @@ def test_cpu(example): example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_call([sys.executable, filename] + args) + check_output([sys.executable, filename] + args) @requires_cuda @@ -136,7 +136,7 @@ def test_cuda(example): example 
= example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_call([sys.executable, filename] + args) + check_output([sys.executable, filename] + args) @@ -145,4 +145,4 @@ def test_jit(example): example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_call([sys.executable, filename] + args) + check_output([sys.executable, filename] + args) From 8943517447545d477e20175e4ba16111e69a551e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 15 Nov 2018 10:57:15 -0800 Subject: [PATCH 127/157] fix error in broadcast_all --- tests/ops/test_packed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/test_packed.py b/tests/ops/test_packed.py index 4b6a452aca..b4a7a45d98 100644 --- a/tests/ops/test_packed.py +++ b/tests/ops/test_packed.py @@ -68,7 +68,7 @@ def test_broadcast_all(shapes): packed_inputs = [packed.pack(x, dim_to_symbol) for x in inputs] packed_outputs = packed.broadcast_all(*packed_inputs) actual = tuple(packed.unpack(x, symbol_to_dim) for x in packed_outputs) - expected = broadcast_all(*inputs) + expected = broadcast_all(*inputs) if inputs else [] assert len(actual) == len(expected) for a, e in zip(actual, expected): assert_equal(a, e) From 25a40da232c5fbae40cde162e8de2efeed936dba Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 15 Nov 2018 14:27:33 -0800 Subject: [PATCH 128/157] Fix test_memory to ignore UserWarning with gc.get_objects() --- tests/test_memory.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_memory.py b/tests/test_memory.py index 95199c9ec3..55a3180b1a 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import gc +import warnings import networkx as nx import pytest @@ -17,7 +18,9 @@ def count_objects_of_type(type_): - return sum(1 for obj in gc.get_objects() if isinstance(obj, type_)) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + return sum(1 for obj in gc.get_objects() if isinstance(obj, type_)) def test_trace(): From a8a49b19b123c8521ec57eb39a52d8015da7c3f8 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Thu, 15 Nov 2018 14:53:46 -0800 Subject: [PATCH 129/157] add xfail marker to failing jittraceenum test --- tests/infer/test_jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 86b69593bf..ee301c9705 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -351,7 +351,7 @@ def guide(data): TraceGraph_ELBO, JitTraceGraph_ELBO, TraceEnum_ELBO, - JitTraceEnum_ELBO, + xfail_param(JitTraceEnum_ELBO, reason="https://github.com/uber/pyro/issues/1418"), ]) def test_svi_irregular_batch_size(Elbo): pyro.clear_param_store() From e55dff861e6bfd71b9f7a4c97cfd989aa91b964e Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 14:17:18 -0800 Subject: [PATCH 130/157] fix test_svi_enum --- tests/infer/test_jit.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index 791bfe14b2..19bdd7dd1e 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -267,15 +267,7 @@ def guide(data): @pytest.mark.parametrize("enumerate2", ["sequential", "parallel"]) @pytest.mark.parametrize("enumerate1", ["sequential",
"parallel"]) @pytest.mark.parametrize("plate_dim", [1, 2]) -@pytest.mark.parametrize('Elbo', [ - Trace_ELBO, - JitTrace_ELBO, - TraceGraph_ELBO, - JitTraceGraph_ELBO, - TraceEnum_ELBO, - JitTraceEnum_ELBO, -]) -def test_svi_enum(Elbo, plate_dim, enumerate1, enumerate2): +def test_svi_enum(plate_dim, enumerate1, enumerate2): pyro.clear_param_store() num_particles = 10 q = pyro.param("q", constant(0.75), constraint=constraints.unit_interval) @@ -298,10 +290,10 @@ def guide(): inner_particles = 2 outer_particles = num_particles // inner_particles - elbo = Elbo(max_plate_nesting=0, - strict_enumeration_warning=any([enumerate1, enumerate2]), - num_particles=inner_particles, - ignore_jit_warnings=True) + elbo = TraceEnum_ELBO(max_plate_nesting=0, + strict_enumeration_warning=any([enumerate1, enumerate2]), + num_particles=inner_particles, + ignore_jit_warnings=True) actual_loss = sum(elbo.loss_and_grads(model, guide) for i in range(outer_particles)) / outer_particles actual_grad = q.unconstrained().grad / outer_particles From 8b6d40fa5da720260984158e7c744814acb8db05 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 16:18:44 -0800 Subject: [PATCH 131/157] add jit markers for hmc --- .travis.yml | 2 +- tests/infer/mcmc/test_hmc.py | 38 ++++++++++++++++++----------------- tests/infer/mcmc/test_nuts.py | 7 +++---- tests/test_examples.py | 8 ++++---- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/.travis.yml b/.travis.yml index e9602652ca..d6b39b3b6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: install: - pip install -U pip - - pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torch_nightly==1.0.0.dev20181126 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install torchvision --no-dependencies - pip install .[test] - pip freeze diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 4d8ef3776c..f12a8cef0c 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -20,8 +20,6 @@ def mark_jit(*args, **kwargs): jit_markers = kwargs.pop("marks", []) jit_markers += [ - pytest.mark.skipif(torch.__version__ <= "0.4.1", - reason="https://github.com/pytorch/pytorch/issues/10041#issuecomment-409057228"), pytest.mark.skipif('CI' in os.environ, reason='slow test') ] @@ -81,7 +79,7 @@ def rmse(t1, t2): 'num_steps': 4}, expected_means=[0.25, 0.50, 0.75], expected_precs=[1.33, 1, 1.33], - mean_tol=0.06, + mean_tol=0.08, std_tol=0.08, ), T( @@ -92,29 +90,29 @@ def rmse(t1, t2): 'num_steps': 5}, expected_means=[0.20, 0.40, 0.60, 0.80], expected_precs=[1.25, 0.83, 0.83, 1.25], - mean_tol=0.06, - std_tol=0.06, + mean_tol=0.08, + std_tol=0.08, ), T( GaussianChain(dim=5, chain_len=2, num_obs=100), num_samples=2000, - warmup_steps=500, - hmc_params={'num_steps': 25}, + warmup_steps=1000, + hmc_params={'num_steps': 15, 'step_size': 0.7}, expected_means=[0.5, 1.0], expected_precs=[2.0, 100], - mean_tol=0.06, - std_tol=0.06, + mean_tol=0.08, + std_tol=0.08, ), T( GaussianChain(dim=5, chain_len=9, num_obs=1), num_samples=3000, warmup_steps=500, - hmc_params={'step_size': 0.1, + hmc_params={'step_size': 0.2, 'num_steps': 15}, expected_means=[0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], expected_precs=[1.11, 0.63, 0.48, 0.42, 0.4, 0.42, 0.48, 0.63, 1.11], - mean_tol=0.1, - std_tol=0.1, + mean_tol=0.11, + std_tol=0.11, ) ] @@ -126,8 +124,7 @@ def rmse(t1, t2): 'fixture, num_samples, warmup_steps, hmc_params, expected_means, expected_precs, 
mean_tol, std_tol', TEST_CASES, ids=TEST_IDS) -@pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ, - reason='Slow test - skip on CI/CUDA') +@pytest.mark.skip(reason='Slow test (https://github.com/pytorch/pytorch/issues/12190)') @pytest.mark.disable_validation() def test_hmc_conjugate_gaussian(fixture, num_samples, @@ -211,7 +208,12 @@ def model(data): assert_equal(posterior.mean, true_probs, prec=0.05) -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, + mark_jit( + True, + marks=[pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1487")]), + ], + ids=jit_idfn) def test_gamma_normal(jit): def model(data): rate = torch.tensor([1.0, 1.0]) @@ -222,15 +224,15 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model, step_size=0.01, num_steps=3, jit_compile=jit, ignore_jit_warnings=True) + hmc_kernel = HMC(model,trajectory_length=1, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) @pytest.mark.parametrize("jit", [False, - mark_jit(True, marks=[pytest.mark.xfail( - reason="https://github.com/uber/pyro/issues/1418")]) + mark_jit(True, marks=[pytest.mark.skip( + reason="https://github.com/uber/pyro/issues/1487")]) ], ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): diff --git a/tests/infer/mcmc/test_nuts.py b/tests/infer/mcmc/test_nuts.py index f926a2c701..ffee6ce757 100644 --- a/tests/infer/mcmc/test_nuts.py +++ b/tests/infer/mcmc/test_nuts.py @@ -39,8 +39,8 @@ warmup_steps=200, expected_means=[0.25, 0.50, 0.75], expected_precs=[1.33, 1, 1.33], - mean_tol=0.08, - std_tol=0.08, + mean_tol=0.09, + std_tol=0.09, ), T( GaussianChain(dim=10, chain_len=4, num_obs=1), @@ -96,8 +96,7 @@ def jit_idfn(param): 'fixture, num_samples, warmup_steps, expected_means, expected_precs, mean_tol, std_tol', TEST_CASES, ids=TEST_IDS) -@pytest.mark.skipif('CI' in os.environ or 'CUDA_TEST' in os.environ, - reason='Slow test - skip on CI/CUDA') +@pytest.mark.skip(reason='Slow test (https://github.com/pytorch/pytorch/issues/12190)') @pytest.mark.disable_validation() def test_nuts_conjugate_gaussian(fixture, num_samples, diff --git a/tests/test_examples.py b/tests/test_examples.py index 1a43a57712..c8da422991 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -3,7 +3,7 @@ import logging import os import sys -from subprocess import check_output +from subprocess import check_call import pytest @@ -127,7 +127,7 @@ def test_cpu(example): example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_output([sys.executable, filename] + args) + check_call([sys.executable, filename] + args) @requires_cuda @@ -137,7 +137,7 @@ def test_cuda(example): example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_output([sys.executable, filename] + args) + check_call([sys.executable, filename] + args) @pytest.mark.parametrize('example', JIT_EXAMPLES) @@ -146,4 +146,4 @@ def test_jit(example): example = example.split() filename, args = example[0], example[1:] filename = os.path.join(EXAMPLES_DIR, filename) - check_output([sys.executable, filename] + args) + check_call([sys.executable, filename] + args) From 
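Replacing `check_output` with `check_call` means each example's stdout/stderr streams straight to the test log instead of being buffered in memory until the process exits; both raise `CalledProcessError` on a nonzero status, so failure detection is unchanged. A short sketch of the idiom (the example path is illustrative):

```python
import sys
from subprocess import CalledProcessError, check_call

try:
    # Output streams live to the console, which keeps CI logs
    # responsive for long-running examples.
    check_call([sys.executable, 'examples/vae/vae.py', '--num-epochs=1'])
except CalledProcessError as e:
    print('example failed with exit code %d' % e.returncode)
```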
e5fdb18962c144e9189163b1289e0058997c4201 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 16:41:37 -0800 Subject: [PATCH 132/157] skip slow jit tests --- tests/infer/mcmc/test_hmc.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index f12a8cef0c..1bbe043e52 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -208,11 +208,8 @@ def model(data): assert_equal(posterior.mean, true_probs, prec=0.05) -@pytest.mark.parametrize("jit", [False, - mark_jit( - True, - marks=[pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1487")]), - ], +@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[ + pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1487")])], ids=jit_idfn) def test_gamma_normal(jit): def model(data): @@ -230,10 +227,9 @@ def model(data): assert_equal(posterior.mean, true_std, prec=0.05) -@pytest.mark.parametrize("jit", [False, - mark_jit(True, marks=[pytest.mark.skip( - reason="https://github.com/uber/pyro/issues/1487")]) - ], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[ + pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1487")])], + ids=jit_idfn) def test_dirichlet_categorical(jit): def model(data): concentration = torch.tensor([1.0, 1.0, 1.0]) @@ -288,7 +284,9 @@ def model(data): assert_equal(posterior.mean, true_probs, prec=0.05) -@pytest.mark.parametrize("jit", [False, mark_jit(True)], ids=jit_idfn) +@pytest.mark.parametrize("jit", [False, mark_jit(True, marks=[ + pytest.mark.skip(reason="https://github.com/uber/pyro/issues/1487")])], + ids=jit_idfn) def test_gamma_normal_with_dual_averaging(jit): def model(data): rate = torch.tensor([1.0, 1.0]) From f154fa3c6a658da183dd940e850653c6c6d5fc51 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 16:57:17 -0800 Subject: [PATCH 133/157] add gp to jit test --- tests/test_examples.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_examples.py b/tests/test_examples.py index ecec106de0..d2e30a9ad2 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -99,6 +99,7 @@ def xfail_jit(*args): xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --jit'), xfail_jit('vae/vae.py --num-epochs=1 --jit'), xfail_jit('vae/vae_comparison.py --num-epochs=1 --jit'), + xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), ] From 68b88b22788aa125cd509e0ec8941ba58f23721d Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 17:06:28 -0800 Subject: [PATCH 134/157] remove low rank mvn docs --- docs/source/distributions.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index d9e8d25c9e..ae72939ad9 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -77,13 +77,6 @@ GaussianScaleMixture :undoc-members: :show-inheritance: -LowRankMultivariateNormal -------------------------- -.. autoclass:: pyro.distributions.LowRankMultivariateNormal - :members: - :undoc-members: - :show-inheritance: - MaskedMixture ------------- .. 
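The consolidated `mark_jit(...)` parametrizations in the HMC tests above keep the JIT variants runnable locally while letting them carry skip marks on CI. A minimal self-contained sketch of the pattern (the test body is a stand-in):

```python
import os

import pytest


def mark_jit(*args, **kwargs):
    """Wrap a parametrized value so its JIT variant is skipped on CI."""
    marks = list(kwargs.pop("marks", []))
    marks.append(pytest.mark.skipif('CI' in os.environ, reason='slow test'))
    return pytest.param(*args, marks=marks, **kwargs)


@pytest.mark.parametrize("jit", [False, mark_jit(True)])
def test_jit_variant(jit):
    assert jit in (False, True)
```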
autoclass:: pyro.distributions.MaskedMixture From 364e2a10a8e4d5f95da730e575d6e51f9337870f Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Mon, 26 Nov 2018 17:16:05 -0800 Subject: [PATCH 135/157] fix lint --- tests/infer/mcmc/test_hmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/infer/mcmc/test_hmc.py b/tests/infer/mcmc/test_hmc.py index 1bbe043e52..6af12b6386 100644 --- a/tests/infer/mcmc/test_hmc.py +++ b/tests/infer/mcmc/test_hmc.py @@ -221,7 +221,7 @@ def model(data): true_std = torch.tensor([0.5, 2]) data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,)))) - hmc_kernel = HMC(model,trajectory_length=1, jit_compile=jit, ignore_jit_warnings=True) + hmc_kernel = HMC(model, trajectory_length=1, jit_compile=jit, ignore_jit_warnings=True) mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data) posterior = EmpiricalMarginal(mcmc_run, sites='p_latent') assert_equal(posterior.mean, true_std, prec=0.05) From 820493e6058c0cf79438ec9c8e3d0a4366d7ec67 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 10:39:07 -0800 Subject: [PATCH 136/157] address comment --- tests/test_examples.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_examples.py b/tests/test_examples.py index d2e30a9ad2..e42791011d 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -92,6 +92,7 @@ def xfail_jit(*args): xfail_jit('hmm.py --num-steps=1 --truncate=65 --model=2 --jit'), xfail_jit('hmm.py --num-steps=1 --truncate=65 --model=3 --jit'), xfail_jit('hmm.py --num-steps=1 --truncate=65 --model=4 --jit'), + xfail_jit('hmm.py --num-steps=1 --truncate=65 --model=5 --jit'), xfail_jit('lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit'), xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit'), From 16b51494850062dd550b2cbcc0cb270170dd3b23 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 10:40:50 -0800 Subject: [PATCH 137/157] update travis build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d6b39b3b6a..1ac61f72fb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: install: - pip install -U pip - - pip install torch_nightly==1.0.0.dev20181126 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install torchvision --no-dependencies - pip install .[test] - pip freeze From f110c01cbee378cbc1bc3822235323bf5c39a365 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 11:22:05 -0800 Subject: [PATCH 138/157] Fix build timeout due to dataset download issues (#1571) --- pyro/contrib/examples/util.py | 14 ++++++++++++++ pyro/distributions/torch_patch.py | 14 +++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index c769c7bdd5..12a6b94942 100644 --- a/pyro/contrib/examples/util.py +++ b/pyro/contrib/examples/util.py @@ -2,9 +2,23 @@ import sys +import torchvision import torchvision.datasets as datasets from torch.utils.data import DataLoader from torchvision import transforms +from torchvision.datasets import MNIST + +from pyro.distributions.torch_patch import patch_dependency + + +@patch_dependency('torchvision.datasets.MNIST', torchvision) +class _MNIST(getattr(MNIST, 
'_pyro_unpatched', MNIST)): + urls = [ + "https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz", + "https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz", + "https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz", + "https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz", + ] def get_data_loader(dataset_name, diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index ebc6d71a84..a4935f2d52 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -3,10 +3,10 @@ import torch -def _patch(target): +def patch_dependency(target, root_module=torch): parts = target.split('.') - assert parts[0] == 'torch' - module = torch + assert parts[0] == root_module.__name__ + module = root_module for part in parts[1:-1]: module = getattr(module, part) name = parts[-1] @@ -22,7 +22,7 @@ def decorator(new_fn): return decorator -@_patch('torch._dirichlet_grad') +@patch_dependency('torch._dirichlet_grad') def _torch_dirichlet_grad(x, concentration, total): unpatched_fn = _torch_dirichlet_grad._pyro_unpatched if x.is_cuda: @@ -31,7 +31,7 @@ def _torch_dirichlet_grad(x, concentration, total): # This can be removed when super(...).__init__() is added upstream -@_patch('torch.distributions.transforms.Transform.__init__') +@patch_dependency('torch.distributions.transforms.Transform.__init__') def _Transform__init__(self, cache_size=0): self._cache_size = cache_size self._inv = None @@ -44,7 +44,7 @@ def _Transform__init__(self, cache_size=0): super(torch.distributions.transforms.Transform, self).__init__() -@_patch('torch.linspace') +@patch_dependency('torch.linspace') def _torch_linspace(*args, **kwargs): unpatched_fn = _torch_linspace._pyro_unpatched template = torch.Tensor() @@ -57,7 +57,7 @@ def _torch_linspace(*args, **kwargs): return ret -@_patch('torch.einsum') +@patch_dependency('torch.einsum') def _einsum(equation, operands): # work around torch.einsum performance issues # see https://github.com/pytorch/pytorch/issues/10661 From f838a34f81a33620074fb51f21c0452904827338 Mon Sep 17 00:00:00 2001 From: JP Date: Tue, 27 Nov 2018 12:00:52 -0800 Subject: [PATCH 139/157] rtd install pytorch 1.0 (#1572) * rtd install pytorch 1.0 * Update readme --- README.md | 25 ++++++------------------- docs/source/conf.py | 6 ++---- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 1ca866ba7f..8be869ee5c 100644 --- a/README.md +++ b/README.md @@ -62,27 +62,14 @@ Make sure that the models come from the same release version of the [Pyro source For recent features you can install Pyro from source. -To install a compatible CPU version of PyTorch on OSX / Linux, you -could use the PyTorch install helper script. +To install a compatible version of PyTorch, use the PyTorch nightly +[build](https://pytorch.org/). We recommend pinning to the specific +nightly build below that has been well tested. -``` -bash scripts/install_pytorch.sh -``` - -Alternatively, build PyTorch following instructions in the PyTorch -[README](https://github.com/pytorch/pytorch/blob/master/README.md). 
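The generalized `patch_dependency` above accepts any importable root module (defaulting to `torch`), which is what lets the MNIST patch target `torchvision.datasets.MNIST`. A toy usage sketch, assuming — as the dirichlet-grad patch above does — that the decorator saves the original attribute on the replacement as `_pyro_unpatched`:

```python
import torch

from pyro.distributions.torch_patch import patch_dependency


@patch_dependency('torch.manual_seed')
def _manual_seed_logged(seed):
    # Log the seed, then delegate to the saved, unpatched implementation.
    print('setting global seed to %d' % seed)
    return _manual_seed_logged._pyro_unpatched(seed)


torch.manual_seed(0)  # prints the seed, then seeds as usual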
-```sh -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch -git checkout 200fb22 # <---- a well-tested commit -``` -On Linux: -```sh -python setup.py install -``` -On OSX: ```sh -MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install +build_ver=1.0.0.dev20181127 # <---- a well-tested PyTorch build +pip install torch_nightly==${build_ver} -f \ + https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html ``` Finally install Pyro using pip or from source as shown below. diff --git a/docs/source/conf.py b/docs/source/conf.py index 883c040407..fe336ce2cb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -205,7 +205,5 @@ def setup(app): # @jpchen's hack to get rtd builder to install latest pytorch if 'READTHEDOCS' in os.environ: - os.system('curl -o install.sh https://raw.githubusercontent.com/uber/pyro/dev/scripts/install_pytorch.sh') - os.system('curl https://raw.githubusercontent.com/uber/pyro/dev/README.md > README.md') - os.system('bash install.sh') - os.system('rm -f install.sh') + os.system("grep 'build_ver=.*' README.md | cut -f1 -d' ' | cut -f2 -d= | xargs -I {} pip install " + "torch_nightly=={} -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html") From 26b0797e5c84a60e52cdadc46cf1f7c02d524a9a Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 12:07:31 -0800 Subject: [PATCH 140/157] change data directory for gp --- examples/contrib/gp/sv-dkl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 1553d755ee..21dd0cf19c 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -130,7 +130,7 @@ def cnn_fn(x): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') - parser.add_argument('--data-dir', type=str, default='../data', metavar='PATH', + parser.add_argument('--data-dir', type=str, default='./data', metavar='PATH', help='default directory to cache MNIST data') parser.add_argument('--num-inducing', type=int, default=70, metavar='N', help='number of inducing input (default: 70)') From b303af0584f80741b21630592f79af3a2783e1be Mon Sep 17 00:00:00 2001 From: jpchen Date: Tue, 27 Nov 2018 12:43:49 -0800 Subject: [PATCH 141/157] revert rtd command change --- docs/source/conf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index fe336ce2cb..37c4e38572 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -205,5 +205,4 @@ def setup(app): # @jpchen's hack to get rtd builder to install latest pytorch if 'READTHEDOCS' in os.environ: - os.system("grep 'build_ver=.*' README.md | cut -f1 -d' ' | cut -f2 -d= | xargs -I {} pip install " - "torch_nightly=={} -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html") + os.system('pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html') From d6653effd3be193c2f6bc1a10a78d85f6d1e2a44 Mon Sep 17 00:00:00 2001 From: jpchen Date: Tue, 27 Nov 2018 13:23:21 -0800 Subject: [PATCH 142/157] use pytorch 0.4 for rtd --- docs/source/conf.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 37c4e38572..d3bc48e159 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -205,4 +205,8 @@ def setup(app): # @jpchen's hack to get rtd builder to install latest pytorch if 'READTHEDOCS' in os.environ: - os.system('pip install 
torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html') + os.system('pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl') + # for pytorch 1.0 (currently fails with OOM + # https://readthedocs.org/projects/pyro-ppl/builds/8159615/ +# os.system('pip install torch_nightly==1.0.0.dev20181127 -f ' +# 'https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html') From 1da48c14fd3b54f7244d96488b311c1924e3c6cf Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 14:10:24 -0800 Subject: [PATCH 143/157] Cache dataset directory on CI build --- .gitignore | 1 + .travis.yml | 4 ++++ examples/air/main.py | 4 +++- examples/contrib/gp/sv-dkl.py | 9 +++++---- examples/dmm/polyphonic_data_loader.py | 5 ++++- examples/sparse_gamma_def.py | 9 ++++++--- examples/vae/utils/mnist_cached.py | 8 ++++++-- pyro/contrib/examples/util.py | 8 ++++++++ 8 files changed, 37 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index d11ffaa4bd..5664e405e4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ run_outputs* .DS_Store .benchmarks data +.data results examples/*/processed examples/*/results diff --git a/.travis.yml b/.travis.yml index 5903b85ff8..ba81d24292 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,10 @@ env: global: - PYTHONPATH=$PWD:$PYTHONPATH +cache: + directories: + - $HOME/.data + install: - pip install -U pip - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then diff --git a/examples/air/main.py b/examples/air/main.py index d1f6c0b9f7..f1e0c4bb39 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -22,6 +22,8 @@ import pyro.optim as optim import pyro.poutine as poutine from air import AIR, latents_to_tensor + +from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO from viz import draw_many, tensor_to_objs @@ -110,7 +112,7 @@ def exp_decay(initial, final, begin, duration, t): def load_data(): - inpath = './data' + inpath = get_data_directory(__file__) (X_np, Y), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42) X_np = X_np.astype(np.float32) X_np /= 255.0 diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 1553d755ee..9b4f295c2b 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -27,7 +27,7 @@ import pyro.contrib.gp as gp import pyro.infer as infer import pyro.optim as optim -from pyro.contrib.examples.util import get_data_loader +from pyro.contrib.examples.util import get_data_loader, get_data_directory class CNN(nn.Module): @@ -77,14 +77,15 @@ def test(args, test_loader, gpmodel): def main(args): + data_dir = args.data_dir if args.data_dir is not None else get_data_directory(__file__) train_loader = get_data_loader(dataset_name='MNIST', - data_dir=args.data_dir, + data_dir=data_dir, batch_size=args.batch_size, dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], is_training_set=True, shuffle=True) test_loader = get_data_loader(dataset_name='MNIST', - data_dir=args.data_dir, + data_dir=data_dir, batch_size=args.batch_size, dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], is_training_set=False, @@ -130,7 +131,7 @@ def cnn_fn(x): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') - parser.add_argument('--data-dir', type=str, default='../data', metavar='PATH', + parser.add_argument('--data-dir', type=str, default="", metavar='PATH', help='default directory to 
cache MNIST data') parser.add_argument('--num-inducing', type=int, default=70, metavar='N', help='number of inducing input (default: 70)') diff --git a/examples/dmm/polyphonic_data_loader.py b/examples/dmm/polyphonic_data_loader.py index de7c05aff9..e57ff4dd03 100644 --- a/examples/dmm/polyphonic_data_loader.py +++ b/examples/dmm/polyphonic_data_loader.py @@ -23,6 +23,9 @@ # this function processes the raw data; in particular it unsparsifies it +from pyro.contrib.examples.util import get_data_directory + + def process_data(base_path, filename, T_max=160, min_note=21, note_range=88): output = os.path.join(base_path, filename) if os.path.exists(output): @@ -55,7 +58,7 @@ def process_data(base_path, filename, T_max=160, min_note=21, note_range=88): # this logic will be initiated upon import -base_path = './data' +base_path = get_data_directory(__file__) process_data(base_path, "jsb_processed.pkl") jsb_file_loc = "./data/jsb_processed.pkl" diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index 06187a67c3..adaffb55a4 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -18,6 +18,8 @@ import pyro import pyro.optim as optim import wget + +from pyro.contrib.examples.util import get_data_directory from pyro.distributions import Gamma, Poisson from pyro.infer import SVI, Trace_ELBO @@ -122,9 +124,10 @@ def clip_params(self): def main(args): # load data print('loading training data...') - if not os.path.exists('faces_training.csv'): - wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', 'faces_training.csv') - data = torch.tensor(np.loadtxt('faces_training.csv', delimiter=',')).float() + dataset_path = os.path.join(get_data_directory(__file__), 'faces_training.csv') + if not os.path.exists(dataset_path): + wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', dataset_path) + data = torch.tensor(np.loadtxt(dataset_path, delimiter=',')).float() sparse_gamma_def = SparseGammaDEF() opt = optim.AdagradRMSProp({"eta": 4.5, "t": 0.1}) diff --git a/examples/vae/utils/mnist_cached.py b/examples/vae/utils/mnist_cached.py index a36bba761b..44ae8804af 100644 --- a/examples/vae/utils/mnist_cached.py +++ b/examples/vae/utils/mnist_cached.py @@ -7,6 +7,9 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST +from pyro.contrib.examples.util import get_data_directory + + # This file contains utilities for caching, transforming and splitting MNIST data # efficiently. 
By default, a PyTorch DataLoader will apply the transform every epoch # we avoid this by caching the data early on in MNISTCached class @@ -191,20 +194,21 @@ def __getitem__(self, index): return img, target -def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root='./data', download=True, **kwargs): +def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, download=True, **kwargs): """ helper function for setting up pytorch data loaders for a semi-supervised dataset :param dataset: the data to use :param use_cuda: use GPU(s) for training :param batch_size: size of a batch of data to output when iterating over the data loaders :param sup_num: number of supervised data examples - :param root: where on the filesystem should the dataset be :param download: download the dataset (if it doesn't exist already) :param kwargs: other params for the pytorch data loader :return: three data loaders: (supervised data for training, un-supervised data for training, supervised data for testing) """ # instantiate the dataset as training/testing sets + if root is None: + root = get_data_directory(__file__) if 'num_workers' not in kwargs: kwargs = {'num_workers': 0, 'pin_memory': False} diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index c769c7bdd5..e5ffe21342 100644 --- a/pyro/contrib/examples/util.py +++ b/pyro/contrib/examples/util.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +import os import sys import torchvision.datasets as datasets @@ -34,3 +35,10 @@ def print_and_log(logger, msg): if logger is not None: logger.write("{}\n".format(msg)) logger.flush() + + +def get_data_directory(filepath=None): + if 'CI' in os.environ: + return os.path.expanduser('~/.data') + return os.path.abspath(os.path.join(os.path.dirname(filepath), + '.data')) From ea76e3f7b55a473949b4f55f78b9cd3c8b19d35b Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 14:13:03 -0800 Subject: [PATCH 144/157] fix default --- examples/contrib/gp/sv-dkl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 9b4f295c2b..381424258a 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -131,7 +131,7 @@ def cnn_fn(x): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') - parser.add_argument('--data-dir', type=str, default="", metavar='PATH', + parser.add_argument('--data-dir', type=str, default=None, metavar='PATH', help='default directory to cache MNIST data') parser.add_argument('--num-inducing', type=int, default=70, metavar='N', help='number of inducing input (default: 70)') From 68c26dc3566ab07fdb5390c9a01dd0b1762b3ec6 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 14:34:20 -0800 Subject: [PATCH 145/157] fix example --- examples/sparse_gamma_def.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index adaffb55a4..20602c78b0 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -124,8 +124,10 @@ def clip_params(self): def main(args): # load data print('loading training data...') - dataset_path = os.path.join(get_data_directory(__file__), 'faces_training.csv') + dataset_directory = get_data_directory(__file__) + dataset_path = os.path.join(dataset_directory, 'faces_training.csv') if not os.path.exists(dataset_path): + 
os.makedirs(dataset_directory, exist_ok=True) wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', dataset_path) data = torch.tensor(np.loadtxt(dataset_path, delimiter=',')).float() From 7819ceae030c95e22cbcc0c4fd067c2d6e3c58e0 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 14:39:05 -0800 Subject: [PATCH 146/157] fix dmm path --- examples/dmm/polyphonic_data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dmm/polyphonic_data_loader.py b/examples/dmm/polyphonic_data_loader.py index e57ff4dd03..775e60296d 100644 --- a/examples/dmm/polyphonic_data_loader.py +++ b/examples/dmm/polyphonic_data_loader.py @@ -60,7 +60,7 @@ def process_data(base_path, filename, T_max=160, min_note=21, note_range=88): # this logic will be initiated upon import base_path = get_data_directory(__file__) process_data(base_path, "jsb_processed.pkl") -jsb_file_loc = "./data/jsb_processed.pkl" +jsb_file_loc = os.path.join(base_path, "jsb_processed.pkl") # ingest training/validation/test data from disk From 9433f6734b61598023e26add13b2d3c850719297 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 15:29:25 -0800 Subject: [PATCH 147/157] address lack of exists_ok in makedirs in python 2 --- examples/sparse_gamma_def.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index 20602c78b0..e69d164a5b 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -10,6 +10,7 @@ from __future__ import absolute_import, division, print_function import argparse +import errno import os import numpy as np @@ -127,7 +128,12 @@ def main(args): dataset_directory = get_data_directory(__file__) dataset_path = os.path.join(dataset_directory, 'faces_training.csv') if not os.path.exists(dataset_path): - os.makedirs(dataset_directory, exist_ok=True) + try: + os.makedirs(dataset_directory) + except OSError as e: + if e.errno != errno.EXIST: + raise + pass wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', dataset_path) data = torch.tensor(np.loadtxt(dataset_path, delimiter=',')).float() From bcffd2e7956115610330053ebb62e62a20571125 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 16:20:52 -0800 Subject: [PATCH 148/157] add debug info --- pyro/contrib/examples/util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index a0c60766f2..6e15cb0026 100644 --- a/pyro/contrib/examples/util.py +++ b/pyro/contrib/examples/util.py @@ -32,11 +32,14 @@ def get_data_loader(dataset_name, dataset_transforms = [] trans = transforms.Compose([transforms.ToTensor()] + dataset_transforms) dataset = getattr(datasets, dataset_name) - return DataLoader( - dataset(root=data_dir, + print("downloading data") + dset = dataset(root=data_dir, train=is_training_set, transform=trans, - download=True), + download=True) + print("download complete.") + return DataLoader( + dset, batch_size=batch_size, shuffle=shuffle ) From 7f947b8612de8e6bda95bce47d64e4e05c04e0e1 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 16:52:13 -0800 Subject: [PATCH 149/157] fix lint --- pyro/contrib/examples/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index 6e15cb0026..c61e675701 100644 --- a/pyro/contrib/examples/util.py +++ 
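Python 2's `os.makedirs` has no `exist_ok` flag, hence the `errno` guard adopted above (the intended constant is `errno.EEXIST`). A version-agnostic sketch of the same pattern:

```python
import errno
import os


def makedirs_exist_ok(path):
    """Behave like os.makedirs(path, exist_ok=True) on Python 2 and 3."""
    try:
        os.makedirs(path)
    except OSError as e:
        # Only swallow "already exists"; re-raise anything else,
        # e.g. permission errors.
        if e.errno != errno.EEXIST:
            raise
```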
b/pyro/contrib/examples/util.py @@ -34,9 +34,9 @@ def get_data_loader(dataset_name, dataset = getattr(datasets, dataset_name) print("downloading data") dset = dataset(root=data_dir, - train=is_training_set, - transform=trans, - download=True) + train=is_training_set, + transform=trans, + download=True) print("download complete.") return DataLoader( dset, From ed4e11365d7a1900eee9f9a82f08f4075de092e8 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 17:10:39 -0800 Subject: [PATCH 150/157] fix errno --- examples/sparse_gamma_def.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index e69d164a5b..cff770fcd7 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -131,7 +131,7 @@ def main(args): try: os.makedirs(dataset_directory) except OSError as e: - if e.errno != errno.EXIST: + if e.errno != errno.EEXIST: raise pass wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', dataset_path) From a59321b11111d1c6c9e9eaa21b087d32a6af98a2 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 22:38:52 -0800 Subject: [PATCH 151/157] skip gp example to see which others fail --- tests/test_examples.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index e42791011d..43f33a8496 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -7,7 +7,7 @@ import pytest -from tests.common import EXAMPLES_DIR, requires_cuda, xfail_param +from tests.common import EXAMPLES_DIR, requires_cuda, xfail_param, skipif_param logger = logging.getLogger(__name__) pytestmark = pytest.mark.stage('test_examples') @@ -21,7 +21,8 @@ 'contrib/autoname/scoping_mixture.py --num-epochs=1', 'contrib/autoname/mixture.py --num-epochs=1', 'contrib/autoname/tree_data.py --num-epochs=1', - 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4', + skipif_param('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4', + 'CI' in os.environ, reason='https://github.com/uber/pyro/issues/1540'), 'contrib/oed/ab_test.py --num-vi-steps=1000', 'contrib/oed/item_response.py -N=1000 -M=1000', 'contrib/oed/sequential_oed_sigmoid_lm.py --num-experiments=2 --num-runs=2 --no-plot', From d0271ca810836fb5bc85b5f779a77b0a99c1a105 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 27 Nov 2018 22:57:09 -0800 Subject: [PATCH 152/157] fix pytest param --- tests/test_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 43f33a8496..201f8fb739 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -22,7 +22,8 @@ 'contrib/autoname/mixture.py --num-epochs=1', 'contrib/autoname/tree_data.py --num-epochs=1', skipif_param('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4', - 'CI' in os.environ, reason='https://github.com/uber/pyro/issues/1540'), + condition='CI' in os.environ, + reason='https://github.com/uber/pyro/issues/1540'), 'contrib/oed/ab_test.py --num-vi-steps=1000', 'contrib/oed/item_response.py -N=1000 -M=1000', 'contrib/oed/sequential_oed_sigmoid_lm.py --num-experiments=2 --num-runs=2 --no-plot', From 5dfd731f62bfa72887469f352cd39c860147f42b Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 28 Nov 2018 01:13:41 -0800 Subject: [PATCH 153/157] revert changes from #1573 --- .gitignore | 1 - .travis.yml | 4 ---- examples/air/main.py | 3 +-- examples/contrib/gp/sv-dkl.py | 9 ++++----- examples/dmm/polyphonic_data_loader.py | 7 
++----- examples/sparse_gamma_def.py | 17 +++-------------- examples/vae/utils/mnist_cached.py | 6 +----- pyro/contrib/examples/util.py | 8 -------- 8 files changed, 11 insertions(+), 44 deletions(-) diff --git a/.gitignore b/.gitignore index 5664e405e4..d11ffaa4bd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ run_outputs* .DS_Store .benchmarks data -.data results examples/*/processed examples/*/results diff --git a/.travis.yml b/.travis.yml index beca6fc09f..1ac61f72fb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,10 +6,6 @@ env: global: - PYTHONPATH=$PWD:$PYTHONPATH -cache: - directories: - - $HOME/.data - install: - pip install -U pip - pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html diff --git a/examples/air/main.py b/examples/air/main.py index f1e0c4bb39..b1984e9d8e 100644 --- a/examples/air/main.py +++ b/examples/air/main.py @@ -23,7 +23,6 @@ import pyro.poutine as poutine from air import AIR, latents_to_tensor -from pyro.contrib.examples.util import get_data_directory from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO from viz import draw_many, tensor_to_objs @@ -112,7 +111,7 @@ def exp_decay(initial, final, begin, duration, t): def load_data(): - inpath = get_data_directory(__file__) + inpath = './data' (X_np, Y), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42) X_np = X_np.astype(np.float32) X_np /= 255.0 diff --git a/examples/contrib/gp/sv-dkl.py b/examples/contrib/gp/sv-dkl.py index 381424258a..1553d755ee 100644 --- a/examples/contrib/gp/sv-dkl.py +++ b/examples/contrib/gp/sv-dkl.py @@ -27,7 +27,7 @@ import pyro.contrib.gp as gp import pyro.infer as infer import pyro.optim as optim -from pyro.contrib.examples.util import get_data_loader, get_data_directory +from pyro.contrib.examples.util import get_data_loader class CNN(nn.Module): @@ -77,15 +77,14 @@ def test(args, test_loader, gpmodel): def main(args): - data_dir = args.data_dir if args.data_dir is not None else get_data_directory(__file__) train_loader = get_data_loader(dataset_name='MNIST', - data_dir=data_dir, + data_dir=args.data_dir, batch_size=args.batch_size, dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], is_training_set=True, shuffle=True) test_loader = get_data_loader(dataset_name='MNIST', - data_dir=data_dir, + data_dir=args.data_dir, batch_size=args.batch_size, dataset_transforms=[transforms.Normalize((0.1307,), (0.3081,))], is_training_set=False, @@ -131,7 +130,7 @@ def cnn_fn(x): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Pyro GP MNIST Example') - parser.add_argument('--data-dir', type=str, default=None, metavar='PATH', + parser.add_argument('--data-dir', type=str, default='../data', metavar='PATH', help='default directory to cache MNIST data') parser.add_argument('--num-inducing', type=int, default=70, metavar='N', help='number of inducing input (default: 70)') diff --git a/examples/dmm/polyphonic_data_loader.py b/examples/dmm/polyphonic_data_loader.py index 775e60296d..de7c05aff9 100644 --- a/examples/dmm/polyphonic_data_loader.py +++ b/examples/dmm/polyphonic_data_loader.py @@ -23,9 +23,6 @@ # this function processes the raw data; in particular it unsparsifies it -from pyro.contrib.examples.util import get_data_directory - - def process_data(base_path, filename, T_max=160, min_note=21, note_range=88): output = os.path.join(base_path, filename) if os.path.exists(output): @@ -58,9 +55,9 @@ def process_data(base_path, filename, T_max=160, min_note=21, 
note_range=88): # this logic will be initiated upon import -base_path = get_data_directory(__file__) +base_path = './data' process_data(base_path, "jsb_processed.pkl") -jsb_file_loc = os.path.join(base_path, "jsb_processed.pkl") +jsb_file_loc = "./data/jsb_processed.pkl" # ingest training/validation/test data from disk diff --git a/examples/sparse_gamma_def.py b/examples/sparse_gamma_def.py index cff770fcd7..06187a67c3 100644 --- a/examples/sparse_gamma_def.py +++ b/examples/sparse_gamma_def.py @@ -10,7 +10,6 @@ from __future__ import absolute_import, division, print_function import argparse -import errno import os import numpy as np @@ -19,8 +18,6 @@ import pyro import pyro.optim as optim import wget - -from pyro.contrib.examples.util import get_data_directory from pyro.distributions import Gamma, Poisson from pyro.infer import SVI, Trace_ELBO @@ -125,17 +122,9 @@ def clip_params(self): def main(args): # load data print('loading training data...') - dataset_directory = get_data_directory(__file__) - dataset_path = os.path.join(dataset_directory, 'faces_training.csv') - if not os.path.exists(dataset_path): - try: - os.makedirs(dataset_directory) - except OSError as e: - if e.errno != errno.EEXIST: - raise - pass - wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', dataset_path) - data = torch.tensor(np.loadtxt(dataset_path, delimiter=',')).float() + if not os.path.exists('faces_training.csv'): + wget.download('https://d2fefpcigoriu7.cloudfront.net/datasets/faces_training.csv', 'faces_training.csv') + data = torch.tensor(np.loadtxt('faces_training.csv', delimiter=',')).float() sparse_gamma_def = SparseGammaDEF() opt = optim.AdagradRMSProp({"eta": 4.5, "t": 0.1}) diff --git a/examples/vae/utils/mnist_cached.py b/examples/vae/utils/mnist_cached.py index 44ae8804af..167b75af4b 100644 --- a/examples/vae/utils/mnist_cached.py +++ b/examples/vae/utils/mnist_cached.py @@ -7,8 +7,6 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST -from pyro.contrib.examples.util import get_data_directory - # This file contains utilities for caching, transforming and splitting MNIST data # efficiently. 
By default, a PyTorch DataLoader will apply the transform every epoch @@ -194,7 +192,7 @@ def __getitem__(self, index): return img, target -def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, download=True, **kwargs): +def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root='./data', download=True, **kwargs): """ helper function for setting up pytorch data loaders for a semi-supervised dataset :param dataset: the data to use @@ -207,8 +205,6 @@ def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, d supervised data for testing) """ # instantiate the dataset as training/testing sets - if root is None: - root = get_data_directory(__file__) if 'num_workers' not in kwargs: kwargs = {'num_workers': 0, 'pin_memory': False} diff --git a/pyro/contrib/examples/util.py b/pyro/contrib/examples/util.py index c61e675701..b417448625 100644 --- a/pyro/contrib/examples/util.py +++ b/pyro/contrib/examples/util.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import os import sys import torchvision @@ -52,10 +51,3 @@ def print_and_log(logger, msg): if logger is not None: logger.write("{}\n".format(msg)) logger.flush() - - -def get_data_directory(filepath=None): - if 'CI' in os.environ: - return os.path.expanduser('~/.data') - return os.path.abspath(os.path.join(os.path.dirname(filepath), - '.data')) From aeadda0b9d425be7e57ca8a7820b4a5f74a93da9 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 28 Nov 2018 01:15:46 -0800 Subject: [PATCH 154/157] update pytorch build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1ac61f72fb..87da8214e3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: install: - pip install -U pip - - pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torch_nightly==1.0.0.dev20181128 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install torchvision --no-dependencies - pip install .[test] - pip freeze From 499185a5d02f28aff1f68c4762610a71f59635dd Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 28 Nov 2018 01:29:26 -0800 Subject: [PATCH 155/157] revert to nov 27 build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 87da8214e3..1ac61f72fb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ env: install: - pip install -U pip - - pip install torch_nightly==1.0.0.dev20181128 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install torchvision --no-dependencies - pip install .[test] - pip freeze From 7b2a0144e4c8227ffbb0fd6894d3d3fed20f2f12 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 28 Nov 2018 10:58:37 -0800 Subject: [PATCH 156/157] Update travis with 11/28 build --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index beca6fc09f..50fb3e8645 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ cache: install: - pip install -U pip - - pip install torch_nightly==1.0.0.dev20181127 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - pip install torch_nightly==1.0.0.dev20181128 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - pip install torchvision --no-dependencies - pip 
install .[test] - pip freeze From 59c73042820491c309da73418add19663b8fc5e7 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 28 Nov 2018 11:30:29 -0800 Subject: [PATCH 157/157] add typing to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 03bed0b7b7..c9f7515214 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ 'six>=1.10.0', # TODO: uncomment on release; using torch-nightly build # 'torch>=0.4.1', + 'typing>=3.6.4', # required by torch wheel 'tqdm>=4.27', ], extras_require={
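Pinning `typing>=3.6.4` unconditionally also installs the backport on Python 3.5+, where the module already ships with the standard library. If that ever becomes a problem, a PEP 508 environment marker can restrict the pin; a hypothetical refinement, not part of the patch:

```python
from setuptools import setup

setup(
    name='example-package',
    version='0.0.1',
    install_requires=[
        # Only pull in the typing backport on interpreters that lack
        # the stdlib module; recent pip/setuptools evaluate this marker.
        'typing>=3.6.4; python_version < "3.5"',
    ],
)
```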